diff --git a/include/libbz3.h b/include/libbz3.h index 7da48cb..65854fc 100644 --- a/include/libbz3.h +++ b/include/libbz3.h @@ -52,6 +52,7 @@ extern "C" { #define BZ3_ERR_TRUNCATED_DATA -5 #define BZ3_ERR_DATA_TOO_BIG -6 #define BZ3_ERR_INIT -7 +#define BZ3_ERR_DATA_SIZE_TOO_SMALL -8 struct bz3_state; @@ -173,8 +174,16 @@ BZIP3_API int32_t bz3_encode_block(struct bz3_state * state, uint8_t * buffer, i /** * @brief Decode a single block. - * `buffer' must be able to hold at least `bz3_bound(orig_size)' bytes. The size must not exceed the block size - * associated with the state. + * + * `buffer' must be able to hold at least `bz3_bound(orig_size)' bytes + * in order to ensure decompression will succeed for all possible bzip3 blocks. + * + * In most (but not all) cases, `orig_size` should usually be sufficient. + * If it is not sufficient, you must allocate a buffer of size `bz3_bound(orig_size)` temporarily. + * + * If `buffer` is too small, `BZ3_ERR_DATA_SIZE_TOO_SMALL` will be returned. + * The size must not exceed the block size associated with the state. + * * @param size The size of the compressed data in `buffer' * @param orig_size The original size of the data before compression. */ diff --git a/src/libbz3.c b/src/libbz3.c index 7057dc1..d73facf 100644 --- a/src/libbz3.c +++ b/src/libbz3.c @@ -489,6 +489,8 @@ BZIP3_API const char * bz3_strerror(struct bz3_state * state) { return "Truncated data"; case BZ3_ERR_DATA_TOO_BIG: return "Too much data"; + case BZ3_ERR_DATA_SIZE_TOO_SMALL: + return "Size of buffer (data_size) passed to the block decoder (bz3_decode_block) is too small."; default: return "Unknown error"; } @@ -662,6 +664,34 @@ BZIP3_API s32 bz3_decode_block(struct bz3_state * state, u8 * buffer, s32 data_s return -1; } + // Size that undoing BWT+BCM should decompress into. + s32 size_before_bwt; + + if (model & 2) + size_before_bwt = lzp_size; + else if (model & 4) + size_before_bwt = rle_size; + else + size_before_bwt = orig_size; + + // Note(sewer): It's technically valid within the spec to create a bzip3 block + // where the size after LZP/RLE is larger than the original input. Some earlier encoders + // even (mistakenly?) were able to do this. + // + // SAFETY: Data passed to the BWT+BCM step can be one of the following: + // - original data + // - original data + LZP + // - original data + RLE + // - original data + RLE + LZP + // + // We must ensure `data_size` is large enough to store the data at every step of the way + // when we walk backwards from undoing BWT+BCM. The size required may be stored in either `lzp_size`, + // `rle_size` OR `orig_size`. We therefore simply check all possible sizes. + if ((lzp_size > data_size) || (rle_size > data_size)) { + state->last_error = BZ3_ERR_DATA_SIZE_TOO_SMALL; + return -1; + } + // Decode the data. u8 *b1 = buffer, *b2 = state->swap_buffer; @@ -670,32 +700,25 @@ BZIP3_API s32 bz3_decode_block(struct bz3_state * state, u8 * buffer, s32 data_s state->cm_state->input_ptr = 0; state->cm_state->input_max = data_size; - s32 size_src; - - if (model & 2) - size_src = lzp_size; - else if (model & 4) - size_src = rle_size; - else - size_src = orig_size; - - decode_bytes(state->cm_state, b2, size_src); + decode_bytes(state->cm_state, b2, size_before_bwt); swap(b1, b2); - if (bwt_idx > size_src) { + if (bwt_idx > size_before_bwt) { state->last_error = BZ3_ERR_MALFORMED_HEADER; return -1; } // Undo BWT memset(state->sais_array, 0, sizeof(s32) * BWT_BOUND(state->block_size)); - memset(b2, 0, size_src); - if (libsais_unbwt(b1, b2, state->sais_array, size_src, NULL, bwt_idx) < 0) { + memset(b2, 0, size_before_bwt); // buffer b2, swap b1 + if (libsais_unbwt(b1, b2, state->sais_array, size_before_bwt, NULL, bwt_idx) < 0) { state->last_error = BZ3_ERR_BWT; return -1; } swap(b1, b2); + s32 size_src = size_before_bwt; + // Undo LZP if (model & 2) { size_src = lzp_decompress(b1, b2, lzp_size, bz3_bound(state->block_size), state->lzp_lut); @@ -706,7 +729,7 @@ BZIP3_API s32 bz3_decode_block(struct bz3_state * state, u8 * buffer, s32 data_s swap(b1, b2); } - if (model & 4) { + if (model & 4) { int err = mrled(b1, b2, orig_size, size_src); if (err) { state->last_error = BZ3_ERR_CRC;