[NUT-devel] [nut]: r206 - trunk/libnut/demuxer.c

ods15 subversion at mplayerhq.hu
Wed Nov 15 11:04:00 CET 2006


Author: ods15
Date: Wed Nov 15 11:04:00 2006
New Revision: 206

Modified:
   trunk/libnut/demuxer.c

Log:
kill all EAGAIN issues in main header reading which are dangerous to
malloc stuff by verifying all headers before parsing them
make SAFE_CALLOC and SAFE_REALLOC macros
add some helpful comments


Modified: trunk/libnut/demuxer.c
==============================================================================
--- trunk/libnut/demuxer.c	(original)
+++ trunk/libnut/demuxer.c	Wed Nov 15 11:04:00 2006
@@ -185,6 +185,17 @@
 #define ERROR(expr, code) do { if (expr) { err = code; goto err_out; } } while(0)
 #define GET_V(bc, v) do { uint64_t _tmp; CHECK(get_v_((bc), &_tmp, #v)); (v) = _tmp; } while(0)
 #define GET_S(bc, v) do {  int64_t _tmp; CHECK(get_s_((bc), &_tmp, #v)); (v) = _tmp; } while(0)
+#define SAFE_CALLOC(alloc, var, a, b) do { \
+	ERROR(SIZE_MAX/(a) < (b), -ERR_OUT_OF_MEM); \
+	ERROR(!((var) = (alloc)->malloc((a) * (b))), -ERR_OUT_OF_MEM); \
+	memset((var), 0, (a) * (b)); \
+} while(0)
+#define SAFE_REALLOC(alloc, var, a, b) do { \
+	void * _tmp; \
+	ERROR(SIZE_MAX/(a) < (b), -ERR_OUT_OF_MEM); \
+	ERROR(!((_tmp) = (alloc)->realloc((var), (a) * (b))), -ERR_OUT_OF_MEM); \
+	(var) = _tmp; \
+} while(0)
 
 static int get_data(input_buffer_t * bc, int len, uint8_t * buf) {
 	int tmp;
@@ -347,12 +358,7 @@
 	GET_V(tmp, info->chapter_len);
 
 	GET_V(tmp, info->count);
-	if (!info->fields) {
-		ERROR(SIZE_MAX/sizeof(nut_info_field_t) < info->count, -ERR_OUT_OF_MEM);
-		info->fields = nut->alloc->malloc(info->count * sizeof(nut_info_field_t));
-		ERROR(!nut->tb, -ERR_OUT_OF_MEM);
-		memset(info->fields, 0, info->count * sizeof(nut_info_field_t)); // initialize pointer to NULL...
-	}
+	SAFE_CALLOC(nut->alloc, info->fields, sizeof(nut_info_field_t), info->count);
 
 	for (i = 0; i < info->count; i++) {
 		int len;
@@ -395,9 +401,59 @@
 	return err;
 }
 
+static int skip_reserved_headers(nut_context_t * nut, uint64_t stop_startcode) {
+	int err;
+	uint64_t tmp;
+	CHECK(get_bytes(nut->i, 8, &tmp));
+	while (tmp >> 56 != 'N') {
+		if (tmp == stop_startcode || tmp == SYNCPOINT_STARTCODE) break;
+		CHECK(get_header(nut->i, NULL));
+		CHECK(get_bytes(nut->i, 8, &tmp));
+	}
+	nut->i->buf_ptr -= 8;
+err_out:
+	return err;
+}
+
+static int get_headers(nut_context_t * nut, int read_info) {
+	int i, err;
+	uint64_t tmp;
+
+	CHECK(get_bytes(nut->i, 8, &tmp));
+	assert(tmp == MAIN_STARTCODE); // sanity, get_headers should only be called in this situation
+	CHECK(get_main_header(nut));
+
+	SAFE_CALLOC(nut->alloc, nut->sc, sizeof(stream_context_t), nut->stream_count);
+
+	for (i = 0; i < nut->stream_count; i++) {
+		int j;
+		CHECK(skip_reserved_headers(nut, STREAM_STARTCODE));
+		CHECK(get_bytes(nut->i, 8, &tmp));
+		ERROR(tmp != STREAM_STARTCODE, -ERR_NOSTREAM_STARTCODE);
+		CHECK(get_stream_header(nut, i));
+		SAFE_CALLOC(nut->alloc, nut->sc[i].pts_cache, sizeof(int64_t), nut->sc[i].sh.decode_delay);
+		for (j = 0; j < nut->sc[i].sh.decode_delay; j++) nut->sc[i].pts_cache[j] = -1;
+	}
+	if (read_info) {
+		CHECK(get_bytes(nut->i, 8, &tmp));
+		while (tmp == INFO_STARTCODE) {
+			nut->info_count++;
+			SAFE_REALLOC(nut->alloc, nut->info, sizeof(nut_info_packet_t), nut->info_count + 1);
+			memset(&nut->info[nut->info_count - 1], 0, sizeof(nut_info_packet_t));
+			CHECK(get_info_header(nut, &nut->info[nut->info_count - 1]));
+			CHECK(get_bytes(nut->i, 8, &tmp));
+		}
+		nut->info[nut->info_count].count = -1;
+		nut->i->buf_ptr -= 8;
+	}
+err_out:
+	assert(err != 2); // EAGAIN is illegal here!!
+	return err;
+}
+
 static int add_syncpoint(nut_context_t * nut, syncpoint_t sp, uint64_t * pts, uint64_t * eor, int * out) {
 	syncpoint_list_t * sl = &nut->syncpoints;
-	int i, j;
+	int i, j, err = 0;
 
 	assert(nut->dopts.cache_syncpoints & 1 || !pts); // pts information is never stored with no syncpoint cache
 	for (i = sl->len; i--; ) { // more often than not, we're adding at end of list
@@ -422,21 +478,11 @@
 	}
 	i++;
 	if (sl->len + 1 > sl->alloc_len) {
-		void * a;
 		sl->alloc_len += PREALLOC_SIZE/4;
-		if (SIZE_MAX/sl->alloc_len < sizeof(syncpoint_t) ||
-		    SIZE_MAX/sl->alloc_len < sizeof(uint64_t) * nut->stream_count)
-			return -ERR_OUT_OF_MEM;
-		a = nut->alloc->realloc(sl->s, sl->alloc_len * sizeof(syncpoint_t));
-		if (!a) return -ERR_OUT_OF_MEM;
-		sl->s = a;
+		SAFE_REALLOC(nut->alloc, sl->s, sizeof(syncpoint_t), sl->alloc_len);
 		if (nut->dopts.cache_syncpoints & 1) {
-			a = nut->alloc->realloc(sl->pts, sl->alloc_len * nut->stream_count * sizeof(uint64_t));
-			if (!a) return -ERR_OUT_OF_MEM;
-			sl->pts = a;
-			a = nut->alloc->realloc(sl->eor, sl->alloc_len * nut->stream_count * sizeof(uint64_t));
-			if (!a) return -ERR_OUT_OF_MEM;
-			sl->eor = a;
+			SAFE_REALLOC(nut->alloc, sl->pts, nut->stream_count * sizeof(uint64_t), sl->alloc_len);
+			SAFE_REALLOC(nut->alloc, sl->eor, nut->stream_count * sizeof(uint64_t), sl->alloc_len);
 		}
 	}
 	memmove(sl->s + i + 1, sl->s + i, (sl->len - i) * sizeof(syncpoint_t));
@@ -454,7 +500,8 @@
 
 	sl->len++;
 	if (out) *out = i;
-	return 0;
+err_out:
+	return err;
 }
 
 static void set_global_pts(nut_context_t * nut, uint64_t pts) {
@@ -518,7 +565,6 @@
 	syncpoint_list_t * sl = &nut->syncpoints;
 	uint64_t x;
 	int i;
-	void * a, * b, * c;
 
 	CHECK(get_bytes(nut->i, 8, &x));
 	ERROR(x != INDEX_STARTCODE, -ERR_GENERAL_ERROR);
@@ -532,16 +578,10 @@
 	}
 
 	GET_V(tmp, x);
-	ERROR(SIZE_MAX/x < sizeof(syncpoint_t) || SIZE_MAX/x < sizeof(uint64_t) * nut->stream_count, -ERR_OUT_OF_MEM);
-
 	sl->alloc_len = sl->len = x;
-	a = nut->alloc->realloc(sl->s, sl->alloc_len * sizeof(syncpoint_t));
-	b = nut->alloc->realloc(sl->pts, sl->alloc_len * nut->stream_count * sizeof(uint64_t));
-	c = nut->alloc->realloc(sl->eor, sl->alloc_len * nut->stream_count * sizeof(uint64_t));
-	ERROR(!a || !b || !c, -ERR_OUT_OF_MEM);
-	sl->s = a;
-	sl->pts = b;
-	sl->eor = c;
+	SAFE_REALLOC(nut->alloc, sl->s, sizeof(syncpoint_t), sl->alloc_len);
+	SAFE_REALLOC(nut->alloc, sl->pts, nut->stream_count * sizeof(uint64_t), sl->alloc_len);
+	SAFE_REALLOC(nut->alloc, sl->eor, nut->stream_count * sizeof(uint64_t), sl->alloc_len);
 
 	for (i = 0; i < sl->len; i++) {
 		GET_V(tmp, sl->s[i].pos);
@@ -710,6 +750,31 @@
 	else sc->eor = 0;
 }
 
+static int find_main_headers(nut_context_t * nut) {
+	int err = 0;
+	uint64_t tmp;
+	off_t start = bctello(nut->i);
+	if (start < strlen(ID_STRING) + 1) {
+		int n = strlen(ID_STRING) + 1 - start;
+		ERROR(ready_read_buf(nut->i, n) < n, buf_eof(nut->i));
+		if (memcmp(get_buf(nut->i, start), ID_STRING + start, n)) nut->i->buf_ptr = nut->i->buf; // rewind
+		else fprintf(stderr, "NUT file_id checks out\n");
+	}
+
+	CHECK(get_bytes(nut->i, 7, &tmp));
+	ERROR(ready_read_buf(nut->i, 4096) < 4096, buf_eof(nut->i));
+	while (bctello(nut->i) < 4096) {
+		tmp = (tmp << 8) | *(nut->i->buf_ptr++);
+		if (tmp == MAIN_STARTCODE) break;
+	}
+	ERROR(tmp != MAIN_STARTCODE, -ERR_NO_HEADERS);
+	nut->i->buf_ptr -= 8;
+	nut->last_headers = bctello(nut->i);
+	flush_buf(nut->i);
+err_out:
+	return err;
+}
+
 static int find_syncpoint(nut_context_t * nut, int backwards, syncpoint_t * res, off_t stop) {
 	int read;
 	int err = 0;
@@ -806,82 +871,26 @@
 int nut_read_headers(nut_context_t * nut, nut_stream_header_t * s [], nut_info_packet_t * info []) {
 	int i, err = 0;
 	uint64_t tmp;
-	*s = NULL;
-	if (!nut->seek_status) { // we already have headers, we were called just for index
-		if (!nut->last_headers) {
-			off_t start = bctello(nut->i);
-			if (start < strlen(ID_STRING) + 1) {
-				int n = strlen(ID_STRING) + 1 - start;
-				ERROR(ready_read_buf(nut->i, n) < n, buf_eof(nut->i));
-				if (memcmp(get_buf(nut->i, start), ID_STRING + start, n)) nut->i->buf_ptr = nut->i->buf; // rewind
-				fprintf(stderr, "NUT file_id checks out\n");
-			}
+	if (!nut->sc) { // we already have headers, we were called just for index
+		if (!nut->last_headers) CHECK(find_main_headers(nut));
 
-			CHECK(get_bytes(nut->i, 7, &tmp));
-			ERROR(ready_read_buf(nut->i, 4096) < 4096, buf_eof(nut->i));
-			while (bctello(nut->i) < 4096) {
-				tmp = (tmp << 8) | *(nut->i->buf_ptr++);
-				if (tmp == MAIN_STARTCODE) break;
-			}
-			ERROR(tmp != MAIN_STARTCODE, -ERR_NO_HEADERS);
-			nut->last_headers = bctello(nut->i) - 8;
-			flush_buf(nut->i);
-		}
+		// load all headers into memory so they can be cleanly decoded without EAGAIN issues
+		CHECK(skip_reserved_headers(nut, SYNCPOINT_STARTCODE));
 
-		CHECK(get_main_header(nut));
+		// rewind to where the headers were found
+		nut->i->buf_ptr = get_buf(nut->i, nut->last_headers);
+		CHECK(get_headers(nut, !!info));
 
-		if (!nut->sc) {
-			ERROR(SIZE_MAX/sizeof(stream_context_t) < nut->stream_count+1, -ERR_OUT_OF_MEM);
-			nut->sc = nut->alloc->malloc(sizeof(stream_context_t) * nut->stream_count);
-			ERROR(!nut->sc, -ERR_OUT_OF_MEM);
-			memset(nut->sc, 0, sizeof(stream_context_t) * nut->stream_count);
-		}
-
-		for (i = 0; i < nut->stream_count; i++) {
-			int j;
-			CHECK(get_bytes(nut->i, 8, &tmp));
-			while (tmp != STREAM_STARTCODE) {
-				ERROR(tmp >> 56 != 'N', -ERR_NOSTREAM_STARTCODE);
-				CHECK(get_header(nut->i, NULL));
-				CHECK(get_bytes(nut->i, 8, &tmp));
-			}
-			CHECK(get_stream_header(nut, i));
-			if (!nut->sc[i].pts_cache) {
-				ERROR(SIZE_MAX/sizeof(int64_t) < nut->sc[i].sh.decode_delay, -ERR_OUT_OF_MEM);
-				nut->sc[i].pts_cache = nut->alloc->malloc(nut->sc[i].sh.decode_delay * sizeof(int64_t));
-				ERROR(!nut->sc[i].pts_cache, -ERR_OUT_OF_MEM);
-				for (j = 0; j < nut->sc[i].sh.decode_delay; j++)
-					nut->sc[i].pts_cache[j] = -1;
-			}
-		}
-		if (info) {
-			CHECK(get_bytes(nut->i, 8, &tmp));
-			while (tmp == INFO_STARTCODE) {
-				nut->info_count++;
-				ERROR(SIZE_MAX/sizeof(nut_info_packet_t) < nut->info_count + 1, -ERR_OUT_OF_MEM);
-				nut->info = nut->alloc->realloc(nut->info, sizeof(nut_info_packet_t) * (nut->info_count + 1));
-				ERROR(!nut->info, -ERR_OUT_OF_MEM);
-				memset(&nut->info[nut->info_count - 1], 0, sizeof(nut_info_packet_t));
-				nut->info[nut->info_count].count = -1;
-				CHECK(get_info_header(nut, &nut->info[nut->info_count - 1]));
-				CHECK(get_bytes(nut->i, 8, &tmp));
-			}
-			nut->i->buf_ptr -= 8;
-		}
-		if (nut->dopts.read_index) {
+		if (nut->dopts.read_index) { // check for index right after main headers
+			CHECK(skip_reserved_headers(nut, INDEX_STARTCODE));
 			CHECK(get_bytes(nut->i, 8, &tmp));
-			while (tmp >> 56 == 'N') {
-				if (tmp == INDEX_STARTCODE || tmp == SYNCPOINT_STARTCODE) break;
-				CHECK(get_header(nut->i, NULL));
-				CHECK(get_bytes(nut->i, 8, &tmp));
-			}
-			if (tmp == INDEX_STARTCODE) nut->seek_status = 2;
 			nut->i->buf_ptr -= 8;
+			if (tmp == INDEX_STARTCODE) nut->seek_status = 2; // signals to not seek to find index
 			flush_buf(nut->i);
 		}
 	}
 
-	if (nut->dopts.read_index & 1) {
+	if (nut->dopts.read_index & 1) { // we already have index, we were called just for the final syncpoint search
 		uint64_t idx_ptr;
 		if (nut->seek_status <= 1) {
 			if (nut->seek_status == 0) {
@@ -906,38 +915,16 @@
 		nut->before_seek = 0;
 	}
 
+	CHECK(skip_reserved_headers(nut, SYNCPOINT_STARTCODE));
 	CHECK(get_bytes(nut->i, 8, &tmp));
-	while (tmp >> 56 == 'N') {
-		if (tmp == SYNCPOINT_STARTCODE) break;
-		if ((err = get_header(nut->i, NULL)) == 2) goto err_out;
-		if (err) break;
-		CHECK(get_bytes(nut->i, 8, &tmp));
-	}
 	nut->i->buf_ptr -= 8;
-	if (err || tmp != SYNCPOINT_STARTCODE) {
-		nut->seek_status = 1; // enter error mode
-		nut->i->buf_ptr = nut->i->buf; // rewind as much as possible
-		err = 0;
-	} else {
-		nut->seek_status = 0;
-	}
+	nut->seek_status = (tmp != SYNCPOINT_STARTCODE); // enter error mode if we're not at a syncpoint
 
-	*s = nut->alloc->malloc(sizeof(nut_stream_header_t) * (nut->stream_count + 1));
-	ERROR(!*s, -ERR_OUT_OF_MEM);
+	SAFE_CALLOC(nut->alloc, *s, sizeof(nut_stream_header_t), nut->stream_count + 1);
 	for (i = 0; i < nut->stream_count; i++) (*s)[i] = nut->sc[i].sh;
 	(*s)[i].type = -1;
 	if (info) *info = nut->info;
 err_out:
-	if (err && err != 2 && !nut->seek_status) {
-		if (nut->sc) for (i = 0; i < nut->stream_count; i++) {
-			nut->alloc->free(nut->sc[i].sh.fourcc);
-			nut->alloc->free(nut->sc[i].sh.codec_specific);
-			nut->alloc->free(nut->sc[i].pts_cache);
-		}
-		nut->alloc->free(nut->sc);
-		nut->sc = NULL;
-		nut->stream_count = 0;
-	}
 	if (err != 2) flush_buf(nut->i); // unless EAGAIN
 	else nut->i->buf_ptr = nut->i->buf; // rewind
 	return err;



More information about the NUT-devel mailing list