
Author: ods15 Date: Wed Nov 15 11:04:00 2006 New Revision: 206 Modified: trunk/libnut/demuxer.c Log: kill all EAGAIN issues in main header reading, which are dangerous to malloc stuff, by verifying all headers before parsing them; make SAFE_CALLOC and SAFE_REALLOC macros; add some helpful comments. Modified: trunk/libnut/demuxer.c ============================================================================== --- trunk/libnut/demuxer.c (original) +++ trunk/libnut/demuxer.c Wed Nov 15 11:04:00 2006 @@ -185,6 +185,17 @@ #define ERROR(expr, code) do { if (expr) { err = code; goto err_out; } } while(0) #define GET_V(bc, v) do { uint64_t _tmp; CHECK(get_v_((bc), &_tmp, #v)); (v) = _tmp; } while(0) #define GET_S(bc, v) do { int64_t _tmp; CHECK(get_s_((bc), &_tmp, #v)); (v) = _tmp; } while(0) +#define SAFE_CALLOC(alloc, var, a, b) do { \ + ERROR(SIZE_MAX/(a) < (b), -ERR_OUT_OF_MEM); \ + ERROR(!((var) = (alloc)->malloc((a) * (b))), -ERR_OUT_OF_MEM); \ + memset((var), 0, (a) * (b)); \ +} while(0) +#define SAFE_REALLOC(alloc, var, a, b) do { \ + void * _tmp; \ + ERROR(SIZE_MAX/(a) < (b), -ERR_OUT_OF_MEM); \ + ERROR(!((_tmp) = (alloc)->realloc((var), (a) * (b))), -ERR_OUT_OF_MEM); \ + (var) = _tmp; \ +} while(0) static int get_data(input_buffer_t * bc, int len, uint8_t * buf) { int tmp; @@ -347,12 +358,7 @@ GET_V(tmp, info->chapter_len); GET_V(tmp, info->count); - if (!info->fields) { - ERROR(SIZE_MAX/sizeof(nut_info_field_t) < info->count, -ERR_OUT_OF_MEM); - info->fields = nut->alloc->malloc(info->count * sizeof(nut_info_field_t)); - ERROR(!nut->tb, -ERR_OUT_OF_MEM); - memset(info->fields, 0, info->count * sizeof(nut_info_field_t)); // initialize pointer to NULL... 
- } + SAFE_CALLOC(nut->alloc, info->fields, sizeof(nut_info_field_t), info->count); for (i = 0; i < info->count; i++) { int len; @@ -395,9 +401,59 @@ return err; } +static int skip_reserved_headers(nut_context_t * nut, uint64_t stop_startcode) { + int err; + uint64_t tmp; + CHECK(get_bytes(nut->i, 8, &tmp)); + while (tmp >> 56 != 'N') { + if (tmp == stop_startcode || tmp == SYNCPOINT_STARTCODE) break; + CHECK(get_header(nut->i, NULL)); + CHECK(get_bytes(nut->i, 8, &tmp)); + } + nut->i->buf_ptr -= 8; +err_out: + return err; +} + +static int get_headers(nut_context_t * nut, int read_info) { + int i, err; + uint64_t tmp; + + CHECK(get_bytes(nut->i, 8, &tmp)); + assert(tmp == MAIN_STARTCODE); // sanity, get_headers should only be called in this situation + CHECK(get_main_header(nut)); + + SAFE_CALLOC(nut->alloc, nut->sc, sizeof(stream_context_t), nut->stream_count); + + for (i = 0; i < nut->stream_count; i++) { + int j; + CHECK(skip_reserved_headers(nut, STREAM_STARTCODE)); + CHECK(get_bytes(nut->i, 8, &tmp)); + ERROR(tmp != STREAM_STARTCODE, -ERR_NOSTREAM_STARTCODE); + CHECK(get_stream_header(nut, i)); + SAFE_CALLOC(nut->alloc, nut->sc[i].pts_cache, sizeof(int64_t), nut->sc[i].sh.decode_delay); + for (j = 0; j < nut->sc[i].sh.decode_delay; j++) nut->sc[i].pts_cache[j] = -1; + } + if (read_info) { + CHECK(get_bytes(nut->i, 8, &tmp)); + while (tmp == INFO_STARTCODE) { + nut->info_count++; + SAFE_REALLOC(nut->alloc, nut->info, sizeof(nut_info_packet_t), nut->info_count + 1); + memset(&nut->info[nut->info_count - 1], 0, sizeof(nut_info_packet_t)); + CHECK(get_info_header(nut, &nut->info[nut->info_count - 1])); + CHECK(get_bytes(nut->i, 8, &tmp)); + } + nut->info[nut->info_count].count = -1; + nut->i->buf_ptr -= 8; + } +err_out: + assert(err != 2); // EAGAIN is illegal here!! 
+ return err; +} + static int add_syncpoint(nut_context_t * nut, syncpoint_t sp, uint64_t * pts, uint64_t * eor, int * out) { syncpoint_list_t * sl = &nut->syncpoints; - int i, j; + int i, j, err = 0; assert(nut->dopts.cache_syncpoints & 1 || !pts); // pts information is never stored with no syncpoint cache for (i = sl->len; i--; ) { // more often than not, we're adding at end of list @@ -422,21 +478,11 @@ } i++; if (sl->len + 1 > sl->alloc_len) { - void * a; sl->alloc_len += PREALLOC_SIZE/4; - if (SIZE_MAX/sl->alloc_len < sizeof(syncpoint_t) || - SIZE_MAX/sl->alloc_len < sizeof(uint64_t) * nut->stream_count) - return -ERR_OUT_OF_MEM; - a = nut->alloc->realloc(sl->s, sl->alloc_len * sizeof(syncpoint_t)); - if (!a) return -ERR_OUT_OF_MEM; - sl->s = a; + SAFE_REALLOC(nut->alloc, sl->s, sizeof(syncpoint_t), sl->alloc_len); if (nut->dopts.cache_syncpoints & 1) { - a = nut->alloc->realloc(sl->pts, sl->alloc_len * nut->stream_count * sizeof(uint64_t)); - if (!a) return -ERR_OUT_OF_MEM; - sl->pts = a; - a = nut->alloc->realloc(sl->eor, sl->alloc_len * nut->stream_count * sizeof(uint64_t)); - if (!a) return -ERR_OUT_OF_MEM; - sl->eor = a; + SAFE_REALLOC(nut->alloc, sl->pts, nut->stream_count * sizeof(uint64_t), sl->alloc_len); + SAFE_REALLOC(nut->alloc, sl->eor, nut->stream_count * sizeof(uint64_t), sl->alloc_len); } } memmove(sl->s + i + 1, sl->s + i, (sl->len - i) * sizeof(syncpoint_t)); @@ -454,7 +500,8 @@ sl->len++; if (out) *out = i; - return 0; +err_out: + return err; } static void set_global_pts(nut_context_t * nut, uint64_t pts) { @@ -518,7 +565,6 @@ syncpoint_list_t * sl = &nut->syncpoints; uint64_t x; int i; - void * a, * b, * c; CHECK(get_bytes(nut->i, 8, &x)); ERROR(x != INDEX_STARTCODE, -ERR_GENERAL_ERROR); @@ -532,16 +578,10 @@ } GET_V(tmp, x); - ERROR(SIZE_MAX/x < sizeof(syncpoint_t) || SIZE_MAX/x < sizeof(uint64_t) * nut->stream_count, -ERR_OUT_OF_MEM); - sl->alloc_len = sl->len = x; - a = nut->alloc->realloc(sl->s, sl->alloc_len * sizeof(syncpoint_t)); - b 
= nut->alloc->realloc(sl->pts, sl->alloc_len * nut->stream_count * sizeof(uint64_t)); - c = nut->alloc->realloc(sl->eor, sl->alloc_len * nut->stream_count * sizeof(uint64_t)); - ERROR(!a || !b || !c, -ERR_OUT_OF_MEM); - sl->s = a; - sl->pts = b; - sl->eor = c; + SAFE_REALLOC(nut->alloc, sl->s, sizeof(syncpoint_t), sl->alloc_len); + SAFE_REALLOC(nut->alloc, sl->pts, nut->stream_count * sizeof(uint64_t), sl->alloc_len); + SAFE_REALLOC(nut->alloc, sl->eor, nut->stream_count * sizeof(uint64_t), sl->alloc_len); for (i = 0; i < sl->len; i++) { GET_V(tmp, sl->s[i].pos); @@ -710,6 +750,31 @@ else sc->eor = 0; } +static int find_main_headers(nut_context_t * nut) { + int err = 0; + uint64_t tmp; + off_t start = bctello(nut->i); + if (start < strlen(ID_STRING) + 1) { + int n = strlen(ID_STRING) + 1 - start; + ERROR(ready_read_buf(nut->i, n) < n, buf_eof(nut->i)); + if (memcmp(get_buf(nut->i, start), ID_STRING + start, n)) nut->i->buf_ptr = nut->i->buf; // rewind + else fprintf(stderr, "NUT file_id checks out\n"); + } + + CHECK(get_bytes(nut->i, 7, &tmp)); + ERROR(ready_read_buf(nut->i, 4096) < 4096, buf_eof(nut->i)); + while (bctello(nut->i) < 4096) { + tmp = (tmp << 8) | *(nut->i->buf_ptr++); + if (tmp == MAIN_STARTCODE) break; + } + ERROR(tmp != MAIN_STARTCODE, -ERR_NO_HEADERS); + nut->i->buf_ptr -= 8; + nut->last_headers = bctello(nut->i); + flush_buf(nut->i); +err_out: + return err; +} + static int find_syncpoint(nut_context_t * nut, int backwards, syncpoint_t * res, off_t stop) { int read; int err = 0; @@ -806,82 +871,26 @@ int nut_read_headers(nut_context_t * nut, nut_stream_header_t * s [], nut_info_packet_t * info []) { int i, err = 0; uint64_t tmp; - *s = NULL; - if (!nut->seek_status) { // we already have headers, we were called just for index - if (!nut->last_headers) { - off_t start = bctello(nut->i); - if (start < strlen(ID_STRING) + 1) { - int n = strlen(ID_STRING) + 1 - start; - ERROR(ready_read_buf(nut->i, n) < n, buf_eof(nut->i)); - if (memcmp(get_buf(nut->i, 
start), ID_STRING + start, n)) nut->i->buf_ptr = nut->i->buf; // rewind - fprintf(stderr, "NUT file_id checks out\n"); - } + if (!nut->sc) { // we already have headers, we were called just for index + if (!nut->last_headers) CHECK(find_main_headers(nut)); - CHECK(get_bytes(nut->i, 7, &tmp)); - ERROR(ready_read_buf(nut->i, 4096) < 4096, buf_eof(nut->i)); - while (bctello(nut->i) < 4096) { - tmp = (tmp << 8) | *(nut->i->buf_ptr++); - if (tmp == MAIN_STARTCODE) break; - } - ERROR(tmp != MAIN_STARTCODE, -ERR_NO_HEADERS); - nut->last_headers = bctello(nut->i) - 8; - flush_buf(nut->i); - } + // load all headers into memory so they can be cleanly decoded without EAGAIN issues + CHECK(skip_reserved_headers(nut, SYNCPOINT_STARTCODE)); - CHECK(get_main_header(nut)); + // rewind to where the headers were found + nut->i->buf_ptr = get_buf(nut->i, nut->last_headers); + CHECK(get_headers(nut, !!info)); - if (!nut->sc) { - ERROR(SIZE_MAX/sizeof(stream_context_t) < nut->stream_count+1, -ERR_OUT_OF_MEM); - nut->sc = nut->alloc->malloc(sizeof(stream_context_t) * nut->stream_count); - ERROR(!nut->sc, -ERR_OUT_OF_MEM); - memset(nut->sc, 0, sizeof(stream_context_t) * nut->stream_count); - } - - for (i = 0; i < nut->stream_count; i++) { - int j; - CHECK(get_bytes(nut->i, 8, &tmp)); - while (tmp != STREAM_STARTCODE) { - ERROR(tmp >> 56 != 'N', -ERR_NOSTREAM_STARTCODE); - CHECK(get_header(nut->i, NULL)); - CHECK(get_bytes(nut->i, 8, &tmp)); - } - CHECK(get_stream_header(nut, i)); - if (!nut->sc[i].pts_cache) { - ERROR(SIZE_MAX/sizeof(int64_t) < nut->sc[i].sh.decode_delay, -ERR_OUT_OF_MEM); - nut->sc[i].pts_cache = nut->alloc->malloc(nut->sc[i].sh.decode_delay * sizeof(int64_t)); - ERROR(!nut->sc[i].pts_cache, -ERR_OUT_OF_MEM); - for (j = 0; j < nut->sc[i].sh.decode_delay; j++) - nut->sc[i].pts_cache[j] = -1; - } - } - if (info) { - CHECK(get_bytes(nut->i, 8, &tmp)); - while (tmp == INFO_STARTCODE) { - nut->info_count++; - ERROR(SIZE_MAX/sizeof(nut_info_packet_t) < nut->info_count + 1, 
-ERR_OUT_OF_MEM); - nut->info = nut->alloc->realloc(nut->info, sizeof(nut_info_packet_t) * (nut->info_count + 1)); - ERROR(!nut->info, -ERR_OUT_OF_MEM); - memset(&nut->info[nut->info_count - 1], 0, sizeof(nut_info_packet_t)); - nut->info[nut->info_count].count = -1; - CHECK(get_info_header(nut, &nut->info[nut->info_count - 1])); - CHECK(get_bytes(nut->i, 8, &tmp)); - } - nut->i->buf_ptr -= 8; - } - if (nut->dopts.read_index) { + if (nut->dopts.read_index) { // check for index right after main headers + CHECK(skip_reserved_headers(nut, INDEX_STARTCODE)); CHECK(get_bytes(nut->i, 8, &tmp)); - while (tmp >> 56 == 'N') { - if (tmp == INDEX_STARTCODE || tmp == SYNCPOINT_STARTCODE) break; - CHECK(get_header(nut->i, NULL)); - CHECK(get_bytes(nut->i, 8, &tmp)); - } - if (tmp == INDEX_STARTCODE) nut->seek_status = 2; nut->i->buf_ptr -= 8; + if (tmp == INDEX_STARTCODE) nut->seek_status = 2; // signals to not seek to find index flush_buf(nut->i); } } - if (nut->dopts.read_index & 1) { + if (nut->dopts.read_index & 1) { // we already have index, we were called just for the final syncpoint search uint64_t idx_ptr; if (nut->seek_status <= 1) { if (nut->seek_status == 0) { @@ -906,38 +915,16 @@ nut->before_seek = 0; } + CHECK(skip_reserved_headers(nut, SYNCPOINT_STARTCODE)); CHECK(get_bytes(nut->i, 8, &tmp)); - while (tmp >> 56 == 'N') { - if (tmp == SYNCPOINT_STARTCODE) break; - if ((err = get_header(nut->i, NULL)) == 2) goto err_out; - if (err) break; - CHECK(get_bytes(nut->i, 8, &tmp)); - } nut->i->buf_ptr -= 8; - if (err || tmp != SYNCPOINT_STARTCODE) { - nut->seek_status = 1; // enter error mode - nut->i->buf_ptr = nut->i->buf; // rewind as much as possible - err = 0; - } else { - nut->seek_status = 0; - } + nut->seek_status = (tmp != SYNCPOINT_STARTCODE); // enter error mode if we're not at a syncpoint - *s = nut->alloc->malloc(sizeof(nut_stream_header_t) * (nut->stream_count + 1)); - ERROR(!*s, -ERR_OUT_OF_MEM); + SAFE_CALLOC(nut->alloc, *s, sizeof(nut_stream_header_t), 
nut->stream_count + 1); for (i = 0; i < nut->stream_count; i++) (*s)[i] = nut->sc[i].sh; (*s)[i].type = -1; if (info) *info = nut->info; err_out: - if (err && err != 2 && !nut->seek_status) { - if (nut->sc) for (i = 0; i < nut->stream_count; i++) { - nut->alloc->free(nut->sc[i].sh.fourcc); - nut->alloc->free(nut->sc[i].sh.codec_specific); - nut->alloc->free(nut->sc[i].pts_cache); - } - nut->alloc->free(nut->sc); - nut->sc = NULL; - nut->stream_count = 0; - } if (err != 2) flush_buf(nut->i); // unless EAGAIN else nut->i->buf_ptr = nut->i->buf; // rewind return err;