fuzztest.c 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. /* Fuzz testing for the nanopb core.
  2. * Attempts to verify all the properties defined in the security model document.
  3. *
  4. * This program can run in three configurations:
  5. * - Standalone fuzzer, generating its own inputs and testing against them.
  6. * - Fuzzing target, reading input on stdin.
  7. * - LLVM libFuzzer target, taking input as a function argument.
  8. */
  9. #include <pb_decode.h>
  10. #include <pb_encode.h>
  11. #include <stdio.h>
  12. #include <stdlib.h>
  13. #include <string.h>
  14. #include <assert.h>
  15. #include <malloc_wrappers.h>
  16. #include "random_data.h"
  17. #include "validation.h"
  18. #include "flakystream.h"
  19. #include "test_helpers.h"
  20. #include "alltypes_static.pb.h"
  21. #include "alltypes_pointer.pb.h"
  22. #include "alltypes_callback.pb.h"
  23. #include "alltypes_proto3_static.pb.h"
  24. #include "alltypes_proto3_pointer.pb.h"
  /* Longer buffer size allows hitting more branches, but lowers performance.
   * The default value is parenthesized so that the macro expands correctly
   * when it is used inside a larger expression.
   */
  #ifndef FUZZTEST_BUFSIZE
  #define FUZZTEST_BUFSIZE (256*1024)
  #endif

  /* Upper limit that main() applies to the buffer size in stand-alone mode. */
  #ifndef FUZZTEST_MAX_STANDALONE_BUFSIZE
  #define FUZZTEST_MAX_STANDALONE_BUFSIZE 16384
  #endif

  /* Size in bytes of the input and re-encoding buffers used by the tests. */
  static size_t g_bufsize = FUZZTEST_BUFSIZE;

  /* Focusing on a single test case at a time improves fuzzing performance.
   * If no test case is specified, enable all tests.
   */
  #if !defined(FUZZTEST_PROTO2_STATIC) && \
      !defined(FUZZTEST_PROTO3_STATIC) && \
      !defined(FUZZTEST_PROTO2_POINTER) && \
      !defined(FUZZTEST_PROTO3_POINTER) && \
      !defined(FUZZTEST_IO_ERRORS)
  #define FUZZTEST_PROTO2_STATIC
  #define FUZZTEST_PROTO3_STATIC
  #define FUZZTEST_PROTO2_POINTER
  #define FUZZTEST_PROTO3_POINTER
  #define FUZZTEST_IO_ERRORS
  #endif
  /* Compute a 32-bit checksum over a byte buffer.
   *
   * The running value starts at 1234. For every input byte it is mixed with
   * three shift-and-xor steps (<<13, >>17, <<5) and the byte is then added.
   *
   * data: pointer to the first byte; never dereferenced when len is 0.
   * len:  number of bytes to process.
   *
   * Returns the final checksum, which is 1234 for an empty buffer.
   */
  static uint32_t xor32_checksum(const void *data, size_t len)
  {
      const uint8_t *bytes = (const uint8_t*)data;
      uint32_t state = 1234;
      size_t i;

      for (i = 0; i < len; i++)
      {
          state ^= state << 13;
          state ^= state >> 17;
          state ^= state << 5;
          state += bytes[i];
      }

      return state;
  }
  60. static bool do_decode(const uint8_t *buffer, size_t msglen, size_t structsize, const pb_msgdesc_t *msgtype, unsigned flags, bool assert_success)
  61. {
  62. bool status;
  63. pb_istream_t stream;
  64. size_t initial_alloc_count = get_alloc_count();
  65. uint8_t *buf2 = malloc_with_check(g_bufsize); /* This is just to match the amount of memory allocations in do_roundtrips(). */
  66. void *msg = malloc_with_check(structsize);
  67. alltypes_static_TestExtension extmsg = alltypes_static_TestExtension_init_zero;
  68. pb_extension_t ext = pb_extension_init_zero;
  69. assert(msg);
  70. memset(msg, 0, structsize);
  71. ext.type = &alltypes_static_TestExtension_testextension;
  72. ext.dest = &extmsg;
  73. ext.next = NULL;
  74. if (msgtype == alltypes_static_AllTypes_fields)
  75. {
  76. ((alltypes_static_AllTypes*)msg)->extensions = &ext;
  77. }
  78. else if (msgtype == alltypes_pointer_AllTypes_fields)
  79. {
  80. ((alltypes_pointer_AllTypes*)msg)->extensions = &ext;
  81. }
  82. stream = pb_istream_from_buffer(buffer, msglen);
  83. status = pb_decode_ex(&stream, msgtype, msg, flags);
  84. if (status)
  85. {
  86. validate_message(msg, structsize, msgtype);
  87. }
  88. if (assert_success)
  89. {
  90. if (!status) fprintf(stderr, "pb_decode: %s\n", PB_GET_ERROR(&stream));
  91. assert(status);
  92. }
  93. pb_release(msgtype, msg);
  94. free_with_check(msg);
  95. free_with_check(buf2);
  96. assert(get_alloc_count() == initial_alloc_count);
  97. return status;
  98. }
  99. static bool do_stream_decode(const uint8_t *buffer, size_t msglen, size_t fail_after, size_t structsize, const pb_msgdesc_t *msgtype, bool assert_success)
  100. {
  101. bool status;
  102. flakystream_t stream;
  103. size_t initial_alloc_count = get_alloc_count();
  104. void *msg = malloc_with_check(structsize);
  105. assert(msg);
  106. memset(msg, 0, structsize);
  107. flakystream_init(&stream, buffer, msglen, fail_after);
  108. status = pb_decode(&stream.stream, msgtype, msg);
  109. if (status)
  110. {
  111. validate_message(msg, structsize, msgtype);
  112. }
  113. if (assert_success)
  114. {
  115. if (!status) fprintf(stderr, "pb_decode: %s\n", PB_GET_ERROR(&stream.stream));
  116. assert(status);
  117. }
  118. pb_release(msgtype, msg);
  119. free_with_check(msg);
  120. assert(get_alloc_count() == initial_alloc_count);
  121. return status;
  122. }
  123. static int g_sentinel;
  124. static bool field_callback(pb_istream_t *stream, const pb_field_t *field, void **arg)
  125. {
  126. assert(stream);
  127. assert(field);
  128. assert(*arg == &g_sentinel);
  129. return pb_read(stream, NULL, stream->bytes_left);
  130. }
  131. static bool submsg_callback(pb_istream_t *stream, const pb_field_t *field, void **arg)
  132. {
  133. assert(stream);
  134. assert(field);
  135. assert(*arg == &g_sentinel);
  136. return true;
  137. }
  138. bool do_callback_decode(const uint8_t *buffer, size_t msglen, bool assert_success)
  139. {
  140. bool status;
  141. pb_istream_t stream;
  142. size_t initial_alloc_count = get_alloc_count();
  143. alltypes_callback_AllTypes *msg = malloc_with_check(sizeof(alltypes_callback_AllTypes));
  144. assert(msg);
  145. memset(msg, 0, sizeof(alltypes_callback_AllTypes));
  146. stream = pb_istream_from_buffer(buffer, msglen);
  147. msg->rep_int32.funcs.decode = &field_callback;
  148. msg->rep_int32.arg = &g_sentinel;
  149. msg->rep_string.funcs.decode = &field_callback;
  150. msg->rep_string.arg = &g_sentinel;
  151. msg->rep_farray.funcs.decode = &field_callback;
  152. msg->rep_farray.arg = &g_sentinel;
  153. msg->req_limits.int64_min.funcs.decode = &field_callback;
  154. msg->req_limits.int64_min.arg = &g_sentinel;
  155. msg->cb_oneof.funcs.decode = &submsg_callback;
  156. msg->cb_oneof.arg = &g_sentinel;
  157. status = pb_decode(&stream, alltypes_callback_AllTypes_fields, msg);
  158. if (assert_success)
  159. {
  160. if (!status) fprintf(stderr, "pb_decode: %s\n", PB_GET_ERROR(&stream));
  161. assert(status);
  162. }
  163. pb_release(alltypes_callback_AllTypes_fields, msg);
  164. free_with_check(msg);
  165. assert(get_alloc_count() == initial_alloc_count);
  166. return status;
  167. }
  168. /* Do a decode -> encode -> decode -> encode roundtrip */
  169. void do_roundtrip(const uint8_t *buffer, size_t msglen, size_t structsize, const pb_msgdesc_t *msgtype)
  170. {
  171. bool status;
  172. uint32_t checksum2, checksum3;
  173. size_t msglen2, msglen3;
  174. uint8_t *buf2 = malloc_with_check(g_bufsize);
  175. void *msg = malloc_with_check(structsize);
  176. /* For proto2 types, we also test extension fields */
  177. alltypes_static_TestExtension extmsg = alltypes_static_TestExtension_init_zero;
  178. pb_extension_t ext = pb_extension_init_zero;
  179. pb_extension_t **ext_field = NULL;
  180. ext.type = &alltypes_static_TestExtension_testextension;
  181. ext.dest = &extmsg;
  182. ext.next = NULL;
  183. assert(buf2 && msg);
  184. if (msgtype == alltypes_static_AllTypes_fields)
  185. {
  186. ext_field = &((alltypes_static_AllTypes*)msg)->extensions;
  187. }
  188. else if (msgtype == alltypes_pointer_AllTypes_fields)
  189. {
  190. ext_field = &((alltypes_pointer_AllTypes*)msg)->extensions;
  191. }
  192. /* Decode and encode the input data.
  193. * This will bring it into canonical format.
  194. */
  195. {
  196. pb_istream_t stream = pb_istream_from_buffer(buffer, msglen);
  197. memset(msg, 0, structsize);
  198. if (ext_field) *ext_field = &ext;
  199. status = pb_decode(&stream, msgtype, msg);
  200. if (!status) fprintf(stderr, "pb_decode: %s\n", PB_GET_ERROR(&stream));
  201. assert(status);
  202. validate_message(msg, structsize, msgtype);
  203. }
  204. {
  205. pb_ostream_t stream = pb_ostream_from_buffer(buf2, g_bufsize);
  206. status = pb_encode(&stream, msgtype, msg);
  207. /* Some messages expand when re-encoding and might no longer fit
  208. * in the buffer. */
  209. if (!status && strcmp(PB_GET_ERROR(&stream), "stream full") != 0)
  210. {
  211. fprintf(stderr, "pb_encode: %s\n", PB_GET_ERROR(&stream));
  212. assert(status);
  213. }
  214. msglen2 = stream.bytes_written;
  215. checksum2 = xor32_checksum(buf2, msglen2);
  216. }
  217. pb_release(msgtype, msg);
  218. /* Then decode from canonical format and re-encode. Result should remain the same. */
  219. if (status)
  220. {
  221. pb_istream_t stream = pb_istream_from_buffer(buf2, msglen2);
  222. memset(msg, 0, structsize);
  223. if (ext_field) *ext_field = &ext;
  224. status = pb_decode(&stream, msgtype, msg);
  225. if (!status) fprintf(stderr, "pb_decode: %s\n", PB_GET_ERROR(&stream));
  226. assert(status);
  227. validate_message(msg, structsize, msgtype);
  228. }
  229. if (status)
  230. {
  231. pb_ostream_t stream = pb_ostream_from_buffer(buf2, g_bufsize);
  232. status = pb_encode(&stream, msgtype, msg);
  233. if (!status) fprintf(stderr, "pb_encode: %s\n", PB_GET_ERROR(&stream));
  234. assert(status);
  235. msglen3 = stream.bytes_written;
  236. checksum3 = xor32_checksum(buf2, msglen3);
  237. assert(msglen2 == msglen3);
  238. assert(checksum2 == checksum3);
  239. }
  240. pb_release(msgtype, msg);
  241. free_with_check(msg);
  242. free_with_check(buf2);
  243. }
  /* Run all enabled test cases for a given input.
   *
   * data, size:   encoded input data.
   * expect_valid: when true, the initial decode of every enabled message type
   *               is asserted to succeed.
   *
   * For each enabled message type the input is first decoded on its own.
   * Only if that succeeds are the roundtrip and stream decoding tests run,
   * and those are then asserted to succeed as well.
   *
   * The error-injection tests, the pb_decode_ex() mode tests and the callback
   * test on possibly invalid input are all compiled only when
   * FUZZTEST_IO_ERRORS is defined.
   *
   * On return the allocation count must equal its value on entry.
   */
  void do_roundtrips(const uint8_t *data, size_t size, bool expect_valid)
  {
      size_t initial_alloc_count = get_alloc_count();
      PB_UNUSED(expect_valid); /* Potentially unused depending on configuration */

  #ifdef FUZZTEST_PROTO2_STATIC
      if (do_decode(data, size, sizeof(alltypes_static_AllTypes), alltypes_static_AllTypes_fields, 0, expect_valid))
      {
          do_roundtrip(data, size, sizeof(alltypes_static_AllTypes), alltypes_static_AllTypes_fields);
          do_stream_decode(data, size, SIZE_MAX, sizeof(alltypes_static_AllTypes), alltypes_static_AllTypes_fields, true);
          do_callback_decode(data, size, true);
      }
  #endif

  #ifdef FUZZTEST_PROTO3_STATIC
      if (do_decode(data, size, sizeof(alltypes_proto3_static_AllTypes), alltypes_proto3_static_AllTypes_fields, 0, expect_valid))
      {
          do_roundtrip(data, size, sizeof(alltypes_proto3_static_AllTypes), alltypes_proto3_static_AllTypes_fields);
          do_stream_decode(data, size, SIZE_MAX, sizeof(alltypes_proto3_static_AllTypes), alltypes_proto3_static_AllTypes_fields, true);
      }
  #endif

  #ifdef FUZZTEST_PROTO2_POINTER
      if (do_decode(data, size, sizeof(alltypes_pointer_AllTypes), alltypes_pointer_AllTypes_fields, 0, expect_valid))
      {
          do_roundtrip(data, size, sizeof(alltypes_pointer_AllTypes), alltypes_pointer_AllTypes_fields);
          do_stream_decode(data, size, SIZE_MAX, sizeof(alltypes_pointer_AllTypes), alltypes_pointer_AllTypes_fields, true);
      }
  #endif

  #ifdef FUZZTEST_PROTO3_POINTER
      if (do_decode(data, size, sizeof(alltypes_proto3_pointer_AllTypes), alltypes_proto3_pointer_AllTypes_fields, 0, expect_valid))
      {
          do_roundtrip(data, size, sizeof(alltypes_proto3_pointer_AllTypes), alltypes_proto3_pointer_AllTypes_fields);
          do_stream_decode(data, size, SIZE_MAX, sizeof(alltypes_proto3_pointer_AllTypes), alltypes_proto3_pointer_AllTypes_fields, true);
      }
  #endif

  #ifdef FUZZTEST_IO_ERRORS
      {
          size_t orig_max_alloc_bytes = get_max_alloc_bytes();

          /* Place the failure point 16 bytes before the end of the data.
           * size is unsigned, so "size - 16" would wrap around to a huge
           * value for inputs shorter than 16 bytes and put the failure point
           * far beyond the data. Clamp to zero instead. */
          size_t fail_after = (size > 16) ? (size - 16) : 0;

          /* Test decoding when error conditions occur.
           * The decoding will end either when running out of memory or when stream returns IO error.
           * Testing proto2 is enough for good coverage here, as it has a superset of the field types of proto3.
           */
          set_max_alloc_bytes(get_alloc_bytes() + 4096);
          do_stream_decode(data, size, fail_after, sizeof(alltypes_static_AllTypes), alltypes_static_AllTypes_fields, false);
          do_stream_decode(data, size, fail_after, sizeof(alltypes_pointer_AllTypes), alltypes_pointer_AllTypes_fields, false);
          set_max_alloc_bytes(orig_max_alloc_bytes);
      }

      /* Test pb_decode_ex() modes */
      do_decode(data, size, sizeof(alltypes_static_AllTypes), alltypes_static_AllTypes_fields, PB_DECODE_NOINIT | PB_DECODE_DELIMITED, false);
      do_decode(data, size, sizeof(alltypes_static_AllTypes), alltypes_static_AllTypes_fields, PB_DECODE_NULLTERMINATED, false);

      /* Test callbacks also when message is not valid */
      do_callback_decode(data, size, false);
  #endif

      /* Each test case cleans up after itself, so nothing may be left over. */
      assert(get_alloc_count() == initial_alloc_count);
  }
  298. /* Fuzzer stub for Google OSSFuzz integration */
  299. int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
  300. {
  301. if (size > g_bufsize)
  302. return 0;
  303. do_roundtrips(data, size, false);
  304. return 0;
  305. }
  306. #ifndef LLVMFUZZER
  307. static bool generate_base_message(uint8_t *buffer, size_t *msglen)
  308. {
  309. pb_ostream_t stream;
  310. bool status;
  311. static const alltypes_static_AllTypes initval = alltypes_static_AllTypes_init_default;
  312. /* Allocate a message and fill it with defaults */
  313. alltypes_static_AllTypes *msg = malloc_with_check(sizeof(alltypes_static_AllTypes));
  314. memcpy(msg, &initval, sizeof(initval));
  315. /* Apply randomness to the data before encoding */
  316. while (rand_int(0, 7))
  317. rand_mess((uint8_t*)msg, sizeof(alltypes_static_AllTypes));
  318. msg->extensions = NULL;
  319. stream = pb_ostream_from_buffer(buffer, g_bufsize);
  320. status = pb_encode(&stream, alltypes_static_AllTypes_fields, msg);
  321. assert(stream.bytes_written <= g_bufsize);
  322. assert(stream.bytes_written <= alltypes_static_AllTypes_size);
  323. *msglen = stream.bytes_written;
  324. pb_release(alltypes_static_AllTypes_fields, msg);
  325. free_with_check(msg);
  326. return status;
  327. }
  328. /* Stand-alone fuzzer iteration, generates random data itself */
  329. static void run_iteration()
  330. {
  331. uint8_t *buffer = malloc_with_check(g_bufsize);
  332. size_t msglen;
  333. /* Fill the whole buffer with noise, to prepare for length modifications */
  334. rand_fill(buffer, g_bufsize);
  335. if (generate_base_message(buffer, &msglen))
  336. {
  337. rand_protobuf_noise(buffer, g_bufsize, &msglen);
  338. /* At this point the message should always be valid */
  339. do_roundtrips(buffer, msglen, true);
  340. /* Apply randomness to the encoded data */
  341. while (rand_bool())
  342. rand_mess(buffer, g_bufsize);
  343. /* Apply randomness to encoded data length */
  344. if (rand_bool())
  345. msglen = rand_int(0, g_bufsize);
  346. /* In this step the message may be valid or invalid */
  347. do_roundtrips(buffer, msglen, false);
  348. }
  349. free_with_check(buffer);
  350. assert(get_alloc_count() == 0);
  351. }
  352. int main(int argc, char **argv)
  353. {
  354. int i;
  355. int iterations;
  356. if (argc >= 2)
  357. {
  358. /* Run in stand-alone mode */
  359. if (g_bufsize > FUZZTEST_MAX_STANDALONE_BUFSIZE)
  360. g_bufsize = FUZZTEST_MAX_STANDALONE_BUFSIZE;
  361. random_set_seed(strtoul(argv[1], NULL, 0));
  362. iterations = (argc >= 3) ? atol(argv[2]) : 10000;
  363. for (i = 0; i < iterations; i++)
  364. {
  365. printf("Iteration %d/%d, seed %lu\n", i, iterations, (unsigned long)random_get_seed());
  366. run_iteration();
  367. }
  368. }
  369. else
  370. {
  371. /* Run as a stub for afl-fuzz and similar */
  372. uint8_t *buffer;
  373. size_t msglen;
  374. buffer = malloc_with_check(g_bufsize);
  375. SET_BINARY_MODE(stdin);
  376. msglen = fread(buffer, 1, g_bufsize, stdin);
  377. LLVMFuzzerTestOneInput(buffer, msglen);
  378. if (!feof(stdin))
  379. {
  380. /* Read any leftover input data if our buffer is smaller than
  381. * message size. */
  382. fprintf(stderr, "Warning: input message too long\n");
  383. while (fread(buffer, 1, g_bufsize, stdin) == g_bufsize);
  384. }
  385. free_with_check(buffer);
  386. }
  387. return 0;
  388. }
  389. #endif