[NUT-devel] [nut]: r206 - trunk/libnut/demuxer.c
ods15
subversion at mplayerhq.hu
Wed Nov 15 11:04:00 CET 2006
Author: ods15
Date: Wed Nov 15 11:04:00 2006
New Revision: 206
Modified:
trunk/libnut/demuxer.c
Log:
Kill all EAGAIN issues in main header reading, which are dangerous when
memory is being allocated, by verifying all headers before parsing them.
Make SAFE_CALLOC and SAFE_REALLOC macros.
Add some helpful comments.
Modified: trunk/libnut/demuxer.c
==============================================================================
--- trunk/libnut/demuxer.c (original)
+++ trunk/libnut/demuxer.c Wed Nov 15 11:04:00 2006
@@ -185,6 +185,17 @@
#define ERROR(expr, code) do { if (expr) { err = code; goto err_out; } } while(0)
#define GET_V(bc, v) do { uint64_t _tmp; CHECK(get_v_((bc), &_tmp, #v)); (v) = _tmp; } while(0)
#define GET_S(bc, v) do { int64_t _tmp; CHECK(get_s_((bc), &_tmp, #v)); (v) = _tmp; } while(0)
+// SAFE_CALLOC(alloc, var, a, b): allocate a zeroed array of b elements of
+// a bytes each with alloc->malloc and store it in var. Jumps to err_out with
+// -ERR_OUT_OF_MEM if a * b overflows size_t or the allocation fails.
+// var's previous value is overwritten without being freed.
+// a and b are evaluated exactly once. a may be 0 (e.g. when it is computed as
+// stream_count * sizeof(x) with no streams); the overflow test then cannot
+// divide by zero. A zero-byte request allocates 1 byte, because malloc(0) may
+// legally return NULL, which would otherwise be reported as out-of-memory.
+// NOTE(review): if a is itself a product, that product is computed by the
+// caller and is not overflow-checked here.
+#define SAFE_CALLOC(alloc, var, a, b) do { \
+	size_t _sa = (a), _sb = (b), _sn; \
+	ERROR(_sa && SIZE_MAX/_sa < _sb, -ERR_OUT_OF_MEM); \
+	_sn = _sa * _sb; \
+	ERROR(!((var) = (alloc)->malloc(_sn ? _sn : 1)), -ERR_OUT_OF_MEM); \
+	memset((var), 0, _sn); \
+} while(0)
+// SAFE_REALLOC(alloc, var, a, b): resize var to b elements of a bytes each
+// with alloc->realloc. On failure (overflow or allocation error) var is left
+// untouched and control jumps to err_out with -ERR_OUT_OF_MEM.
+// a and b are evaluated exactly once, and a == 0 is safe.
+// A zero-byte request asks for 1 byte: realloc(p, 0) may free p and return
+// NULL, which would report a false error and leave var dangling.
+#define SAFE_REALLOC(alloc, var, a, b) do { \
+	size_t _sa = (a), _sb = (b), _sn; \
+	void * _tmp; \
+	ERROR(_sa && SIZE_MAX/_sa < _sb, -ERR_OUT_OF_MEM); \
+	_sn = _sa * _sb; \
+	ERROR(!(_tmp = (alloc)->realloc((var), _sn ? _sn : 1)), -ERR_OUT_OF_MEM); \
+	(var) = _tmp; \
+} while(0)
static int get_data(input_buffer_t * bc, int len, uint8_t * buf) {
int tmp;
@@ -347,12 +358,7 @@
GET_V(tmp, info->chapter_len);
GET_V(tmp, info->count);
- if (!info->fields) {
- ERROR(SIZE_MAX/sizeof(nut_info_field_t) < info->count, -ERR_OUT_OF_MEM);
- info->fields = nut->alloc->malloc(info->count * sizeof(nut_info_field_t));
- ERROR(!nut->tb, -ERR_OUT_OF_MEM);
- memset(info->fields, 0, info->count * sizeof(nut_info_field_t)); // initialize pointer to NULL...
- }
+ SAFE_CALLOC(nut->alloc, info->fields, sizeof(nut_info_field_t), info->count);
for (i = 0; i < info->count; i++) {
int len;
@@ -395,9 +401,59 @@
return err;
}
+// Read and discard consecutive NUT headers (8-byte words whose top byte is
+// 'N') until reaching stop_startcode, SYNCPOINT_STARTCODE, or a word that is
+// not a NUT startcode. On success the buffer is left positioned at that
+// stopping word, so the caller can re-read it with get_bytes().
+// Called with SYNCPOINT_STARTCODE right after the main startcode, this walks
+// over every header up to the first syncpoint, forcing them all into the
+// buffer before any of them is parsed.
+// Returns 0 on success, or the error from get_bytes()/get_header(), which may
+// be 2 (EAGAIN) when more input is needed.
+static int skip_reserved_headers(nut_context_t * nut, uint64_t stop_startcode) {
+	int err = 0;
+	uint64_t tmp;
+	CHECK(get_bytes(nut->i, 8, &tmp));
+	// keep going only while we are looking at a NUT startcode ('N' high byte)
+	while (tmp >> 56 == 'N') {
+		if (tmp == stop_startcode || tmp == SYNCPOINT_STARTCODE) break;
+		CHECK(get_header(nut->i, NULL)); // read and discard this header
+		CHECK(get_bytes(nut->i, 8, &tmp));
+	}
+	nut->i->buf_ptr -= 8; // un-read the word we stopped at
+err_out:
+	return err;
+}
+
+// Parse the main header, every stream header and (if read_info) the info
+// packets that follow them, starting at a MAIN_STARTCODE in the input buffer.
+// Allocates nut->sc (one stream_context_t per stream) and each stream's
+// pts_cache (initialized to -1). When read_info is set, it also grows
+// nut->info, an array whose last element has count == -1.
+// nut->info stays untouched (possibly NULL) if no info packet is present.
+// The caller is expected to have buffered all headers beforehand, so EAGAIN
+// (2) must never be returned from here; this is asserted.
+static int get_headers(nut_context_t * nut, int read_info) {
+	int i, err = 0;
+	uint64_t tmp;
+
+	CHECK(get_bytes(nut->i, 8, &tmp));
+	assert(tmp == MAIN_STARTCODE); // sanity, get_headers should only be called in this situation
+	CHECK(get_main_header(nut));
+
+	SAFE_CALLOC(nut->alloc, nut->sc, sizeof(stream_context_t), nut->stream_count);
+
+	for (i = 0; i < nut->stream_count; i++) {
+		int j;
+		CHECK(skip_reserved_headers(nut, STREAM_STARTCODE));
+		CHECK(get_bytes(nut->i, 8, &tmp));
+		ERROR(tmp != STREAM_STARTCODE, -ERR_NOSTREAM_STARTCODE);
+		CHECK(get_stream_header(nut, i));
+		SAFE_CALLOC(nut->alloc, nut->sc[i].pts_cache, sizeof(int64_t), nut->sc[i].sh.decode_delay);
+		for (j = 0; j < nut->sc[i].sh.decode_delay; j++) nut->sc[i].pts_cache[j] = -1;
+	}
+	if (read_info) {
+		CHECK(get_bytes(nut->i, 8, &tmp));
+		while (tmp == INFO_STARTCODE) {
+			nut->info_count++;
+			// one extra slot holds the count == -1 terminator
+			SAFE_REALLOC(nut->alloc, nut->info, sizeof(nut_info_packet_t), nut->info_count + 1);
+			memset(&nut->info[nut->info_count - 1], 0, sizeof(nut_info_packet_t));
+			// terminate here rather than after the loop: the list stays
+			// well-formed if get_info_header() fails, and nut->info is never
+			// dereferenced when no info packet exists (it may still be NULL)
+			nut->info[nut->info_count].count = -1;
+			CHECK(get_info_header(nut, &nut->info[nut->info_count - 1]));
+			CHECK(get_bytes(nut->i, 8, &tmp));
+		}
+		nut->i->buf_ptr -= 8; // un-read the first non-info startcode
+	}
+err_out:
+	assert(err != 2); // EAGAIN is illegal here!!
+	return err;
+}
+
static int add_syncpoint(nut_context_t * nut, syncpoint_t sp, uint64_t * pts, uint64_t * eor, int * out) {
syncpoint_list_t * sl = &nut->syncpoints;
- int i, j;
+ int i, j, err = 0;
assert(nut->dopts.cache_syncpoints & 1 || !pts); // pts information is never stored with no syncpoint cache
for (i = sl->len; i--; ) { // more often than not, we're adding at end of list
@@ -422,21 +478,11 @@
}
i++;
if (sl->len + 1 > sl->alloc_len) {
- void * a;
sl->alloc_len += PREALLOC_SIZE/4;
- if (SIZE_MAX/sl->alloc_len < sizeof(syncpoint_t) ||
- SIZE_MAX/sl->alloc_len < sizeof(uint64_t) * nut->stream_count)
- return -ERR_OUT_OF_MEM;
- a = nut->alloc->realloc(sl->s, sl->alloc_len * sizeof(syncpoint_t));
- if (!a) return -ERR_OUT_OF_MEM;
- sl->s = a;
+ SAFE_REALLOC(nut->alloc, sl->s, sizeof(syncpoint_t), sl->alloc_len);
if (nut->dopts.cache_syncpoints & 1) {
- a = nut->alloc->realloc(sl->pts, sl->alloc_len * nut->stream_count * sizeof(uint64_t));
- if (!a) return -ERR_OUT_OF_MEM;
- sl->pts = a;
- a = nut->alloc->realloc(sl->eor, sl->alloc_len * nut->stream_count * sizeof(uint64_t));
- if (!a) return -ERR_OUT_OF_MEM;
- sl->eor = a;
+ SAFE_REALLOC(nut->alloc, sl->pts, nut->stream_count * sizeof(uint64_t), sl->alloc_len);
+ SAFE_REALLOC(nut->alloc, sl->eor, nut->stream_count * sizeof(uint64_t), sl->alloc_len);
}
}
memmove(sl->s + i + 1, sl->s + i, (sl->len - i) * sizeof(syncpoint_t));
@@ -454,7 +500,8 @@
sl->len++;
if (out) *out = i;
- return 0;
+err_out:
+ return err;
}
static void set_global_pts(nut_context_t * nut, uint64_t pts) {
@@ -518,7 +565,6 @@
syncpoint_list_t * sl = &nut->syncpoints;
uint64_t x;
int i;
- void * a, * b, * c;
CHECK(get_bytes(nut->i, 8, &x));
ERROR(x != INDEX_STARTCODE, -ERR_GENERAL_ERROR);
@@ -532,16 +578,10 @@
}
GET_V(tmp, x);
- ERROR(SIZE_MAX/x < sizeof(syncpoint_t) || SIZE_MAX/x < sizeof(uint64_t) * nut->stream_count, -ERR_OUT_OF_MEM);
-
sl->alloc_len = sl->len = x;
- a = nut->alloc->realloc(sl->s, sl->alloc_len * sizeof(syncpoint_t));
- b = nut->alloc->realloc(sl->pts, sl->alloc_len * nut->stream_count * sizeof(uint64_t));
- c = nut->alloc->realloc(sl->eor, sl->alloc_len * nut->stream_count * sizeof(uint64_t));
- ERROR(!a || !b || !c, -ERR_OUT_OF_MEM);
- sl->s = a;
- sl->pts = b;
- sl->eor = c;
+ SAFE_REALLOC(nut->alloc, sl->s, sizeof(syncpoint_t), sl->alloc_len);
+ SAFE_REALLOC(nut->alloc, sl->pts, nut->stream_count * sizeof(uint64_t), sl->alloc_len);
+ SAFE_REALLOC(nut->alloc, sl->eor, nut->stream_count * sizeof(uint64_t), sl->alloc_len);
for (i = 0; i < sl->len; i++) {
GET_V(tmp, sl->s[i].pos);
@@ -710,6 +750,31 @@
else sc->eor = 0;
}
+// Locate the first main header near the beginning of the input.
+// If the current position is still inside the file_id string, the remaining
+// bytes are compared against ID_STRING (including its terminating NUL); on a
+// mismatch the buffer is rewound to its start, on a match a diagnostic is
+// printed to stderr. Then the input is scanned byte by byte, up to file
+// offset 4096, for MAIN_STARTCODE.
+// On success nut->last_headers is set to the file offset of the first byte of
+// the main startcode and flush_buf() is called.
+// Returns 0 on success, -ERR_NO_HEADERS if no main startcode is found before
+// offset 4096, or buf_eof()'s result if the needed bytes could not be buffered.
+// NOTE(review): the ready_read_buf(4096) check means an input shorter than
+// 4096 bytes fails with buf_eof()'s result even if it contains a main
+// header - confirm this is intended.
+static int find_main_headers(nut_context_t * nut) {
+ int err = 0;
+ uint64_t tmp;
+ off_t start = bctello(nut->i);
+ if (start < strlen(ID_STRING) + 1) {
+ int n = strlen(ID_STRING) + 1 - start;
+ ERROR(ready_read_buf(nut->i, n) < n, buf_eof(nut->i));
+ if (memcmp(get_buf(nut->i, start), ID_STRING + start, n)) nut->i->buf_ptr = nut->i->buf; // rewind
+ else fprintf(stderr, "NUT file_id checks out\n");
+ }
+
+ // prime tmp with 7 bytes; each loop iteration shifts in one more byte so
+ // tmp always holds the last 8 bytes read
+ CHECK(get_bytes(nut->i, 7, &tmp));
+ ERROR(ready_read_buf(nut->i, 4096) < 4096, buf_eof(nut->i));
+ while (bctello(nut->i) < 4096) {
+ tmp = (tmp << 8) | *(nut->i->buf_ptr++);
+ if (tmp == MAIN_STARTCODE) break;
+ }
+ ERROR(tmp != MAIN_STARTCODE, -ERR_NO_HEADERS);
+ // step back over the 8-byte startcode so last_headers is its first byte
+ nut->i->buf_ptr -= 8;
+ nut->last_headers = bctello(nut->i);
+ flush_buf(nut->i);
+err_out:
+ return err;
+}
+
static int find_syncpoint(nut_context_t * nut, int backwards, syncpoint_t * res, off_t stop) {
int read;
int err = 0;
@@ -806,82 +871,26 @@
int nut_read_headers(nut_context_t * nut, nut_stream_header_t * s [], nut_info_packet_t * info []) {
int i, err = 0;
uint64_t tmp;
- *s = NULL;
- if (!nut->seek_status) { // we already have headers, we were called just for index
- if (!nut->last_headers) {
- off_t start = bctello(nut->i);
- if (start < strlen(ID_STRING) + 1) {
- int n = strlen(ID_STRING) + 1 - start;
- ERROR(ready_read_buf(nut->i, n) < n, buf_eof(nut->i));
- if (memcmp(get_buf(nut->i, start), ID_STRING + start, n)) nut->i->buf_ptr = nut->i->buf; // rewind
- fprintf(stderr, "NUT file_id checks out\n");
- }
+ if (!nut->sc) { // we already have headers, we were called just for index
+ if (!nut->last_headers) CHECK(find_main_headers(nut));
- CHECK(get_bytes(nut->i, 7, &tmp));
- ERROR(ready_read_buf(nut->i, 4096) < 4096, buf_eof(nut->i));
- while (bctello(nut->i) < 4096) {
- tmp = (tmp << 8) | *(nut->i->buf_ptr++);
- if (tmp == MAIN_STARTCODE) break;
- }
- ERROR(tmp != MAIN_STARTCODE, -ERR_NO_HEADERS);
- nut->last_headers = bctello(nut->i) - 8;
- flush_buf(nut->i);
- }
+ // load all headers into memory so they can be cleanly decoded without EAGAIN issues
+ CHECK(skip_reserved_headers(nut, SYNCPOINT_STARTCODE));
- CHECK(get_main_header(nut));
+ // rewind to where the headers were found
+ nut->i->buf_ptr = get_buf(nut->i, nut->last_headers);
+ CHECK(get_headers(nut, !!info));
- if (!nut->sc) {
- ERROR(SIZE_MAX/sizeof(stream_context_t) < nut->stream_count+1, -ERR_OUT_OF_MEM);
- nut->sc = nut->alloc->malloc(sizeof(stream_context_t) * nut->stream_count);
- ERROR(!nut->sc, -ERR_OUT_OF_MEM);
- memset(nut->sc, 0, sizeof(stream_context_t) * nut->stream_count);
- }
-
- for (i = 0; i < nut->stream_count; i++) {
- int j;
- CHECK(get_bytes(nut->i, 8, &tmp));
- while (tmp != STREAM_STARTCODE) {
- ERROR(tmp >> 56 != 'N', -ERR_NOSTREAM_STARTCODE);
- CHECK(get_header(nut->i, NULL));
- CHECK(get_bytes(nut->i, 8, &tmp));
- }
- CHECK(get_stream_header(nut, i));
- if (!nut->sc[i].pts_cache) {
- ERROR(SIZE_MAX/sizeof(int64_t) < nut->sc[i].sh.decode_delay, -ERR_OUT_OF_MEM);
- nut->sc[i].pts_cache = nut->alloc->malloc(nut->sc[i].sh.decode_delay * sizeof(int64_t));
- ERROR(!nut->sc[i].pts_cache, -ERR_OUT_OF_MEM);
- for (j = 0; j < nut->sc[i].sh.decode_delay; j++)
- nut->sc[i].pts_cache[j] = -1;
- }
- }
- if (info) {
- CHECK(get_bytes(nut->i, 8, &tmp));
- while (tmp == INFO_STARTCODE) {
- nut->info_count++;
- ERROR(SIZE_MAX/sizeof(nut_info_packet_t) < nut->info_count + 1, -ERR_OUT_OF_MEM);
- nut->info = nut->alloc->realloc(nut->info, sizeof(nut_info_packet_t) * (nut->info_count + 1));
- ERROR(!nut->info, -ERR_OUT_OF_MEM);
- memset(&nut->info[nut->info_count - 1], 0, sizeof(nut_info_packet_t));
- nut->info[nut->info_count].count = -1;
- CHECK(get_info_header(nut, &nut->info[nut->info_count - 1]));
- CHECK(get_bytes(nut->i, 8, &tmp));
- }
- nut->i->buf_ptr -= 8;
- }
- if (nut->dopts.read_index) {
+ if (nut->dopts.read_index) { // check for index right after main headers
+ CHECK(skip_reserved_headers(nut, INDEX_STARTCODE));
CHECK(get_bytes(nut->i, 8, &tmp));
- while (tmp >> 56 == 'N') {
- if (tmp == INDEX_STARTCODE || tmp == SYNCPOINT_STARTCODE) break;
- CHECK(get_header(nut->i, NULL));
- CHECK(get_bytes(nut->i, 8, &tmp));
- }
- if (tmp == INDEX_STARTCODE) nut->seek_status = 2;
nut->i->buf_ptr -= 8;
+ if (tmp == INDEX_STARTCODE) nut->seek_status = 2; // signals to not seek to find index
flush_buf(nut->i);
}
}
- if (nut->dopts.read_index & 1) {
+ if (nut->dopts.read_index & 1) { // we already have index, we were called just for the final syncpoint search
uint64_t idx_ptr;
if (nut->seek_status <= 1) {
if (nut->seek_status == 0) {
@@ -906,38 +915,16 @@
nut->before_seek = 0;
}
+ CHECK(skip_reserved_headers(nut, SYNCPOINT_STARTCODE));
CHECK(get_bytes(nut->i, 8, &tmp));
- while (tmp >> 56 == 'N') {
- if (tmp == SYNCPOINT_STARTCODE) break;
- if ((err = get_header(nut->i, NULL)) == 2) goto err_out;
- if (err) break;
- CHECK(get_bytes(nut->i, 8, &tmp));
- }
nut->i->buf_ptr -= 8;
- if (err || tmp != SYNCPOINT_STARTCODE) {
- nut->seek_status = 1; // enter error mode
- nut->i->buf_ptr = nut->i->buf; // rewind as much as possible
- err = 0;
- } else {
- nut->seek_status = 0;
- }
+ nut->seek_status = (tmp != SYNCPOINT_STARTCODE); // enter error mode if we're not at a syncpoint
- *s = nut->alloc->malloc(sizeof(nut_stream_header_t) * (nut->stream_count + 1));
- ERROR(!*s, -ERR_OUT_OF_MEM);
+ SAFE_CALLOC(nut->alloc, *s, sizeof(nut_stream_header_t), nut->stream_count + 1);
for (i = 0; i < nut->stream_count; i++) (*s)[i] = nut->sc[i].sh;
(*s)[i].type = -1;
if (info) *info = nut->info;
err_out:
- if (err && err != 2 && !nut->seek_status) {
- if (nut->sc) for (i = 0; i < nut->stream_count; i++) {
- nut->alloc->free(nut->sc[i].sh.fourcc);
- nut->alloc->free(nut->sc[i].sh.codec_specific);
- nut->alloc->free(nut->sc[i].pts_cache);
- }
- nut->alloc->free(nut->sc);
- nut->sc = NULL;
- nut->stream_count = 0;
- }
if (err != 2) flush_buf(nut->i); // unless EAGAIN
else nut->i->buf_ptr = nut->i->buf; // rewind
return err;
More information about the NUT-devel
mailing list