11#include "cbc.h"
2+ #include "padding.h"
23#include <string.h>
34#include <stdlib.h>
45
5- /* PKCS#7 padding helper: returns number of padding bytes appended (1..16) */
6- static size_t pkcs7_pad (const uint8_t * in , size_t in_len , uint8_t * out ) {
7- size_t pad_len = AES_BLOCK - (in_len % AES_BLOCK );
8- if (pad_len == 0 ) pad_len = AES_BLOCK ;
9- /* copy input */
10- memcpy (out , in , in_len );
11- /* append pad bytes */
12- for (size_t i = 0 ; i < pad_len ; ++ i ) out [in_len + i ] = (uint8_t )pad_len ;
13- return pad_len ;
14- }
15-
16- /* remove PKCS#7 pad; returns 0 on success and sets *out_len to unpadded length.
17- * Returns non-zero on invalid padding.
18- */
19- static int pkcs7_unpad (uint8_t * buf , size_t buf_len , size_t * out_len ) {
20- if (buf_len == 0 || (buf_len % AES_BLOCK ) != 0 ) return -1 ;
21- uint8_t pad = buf [buf_len - 1 ];
22- if (pad == 0 || pad > AES_BLOCK ) return -2 ;
23- /* check that last pad bytes equal pad */
24- for (size_t i = 0 ; i < pad ; ++ i ) {
25- if (buf [buf_len - 1 - i ] != pad ) return -3 ;
26- }
27- * out_len = buf_len - pad ;
28- return 0 ;
29- }
30-
316/* XOR helper */
32- static inline void xor_block (uint8_t out [AES_BLOCK ], const uint8_t a [AES_BLOCK ], const uint8_t b [AES_BLOCK ]) {
33- for (int i = 0 ; i < AES_BLOCK ; ++ i ) out [i ] = a [i ] ^ b [i ];
7+ static inline void xor_block (uint8_t out [AES_BLOCK ],
8+ const uint8_t a [AES_BLOCK ],
9+ const uint8_t b [AES_BLOCK ])
10+ {
11+ for (int i = 0 ; i < AES_BLOCK ; ++ i )
12+ out [i ] = a [i ] ^ b [i ];
3413}
3514
3615int aes_cbc_encrypt (const uint8_t * in , size_t in_len ,
3716 uint8_t * out , size_t * out_len ,
3817 const uint8_t iv [AES_BLOCK ],
3918 encrypt_block_fn encrypt , const void * ctx )
4019{
41- if (!in || !out || !out_len || !iv || !encrypt ) return -1 ;
42-
43- /* padded buffer size = ceil(in_len/16)*16 + 16 (if in_len %16 ==0, add a full block) */
44- size_t padded_len = ((in_len + AES_BLOCK - 1 ) / AES_BLOCK ) * AES_BLOCK ;
45- if (padded_len == in_len ) padded_len += AES_BLOCK ;
46-
47- uint8_t * buf = (uint8_t * )malloc (padded_len );
48- if (!buf ) return -2 ;
20+ if (!out || !out_len || !iv || !encrypt ) return -1 ;
21+ if (!in && in_len > 0 ) return -1 ; // NULL only allowed if in_len==0
22+ if (!in ) in = (const uint8_t * )"" ; // empty input is valid
4923
50- /* create padded plaintext in buf */
51- size_t pad_len = pkcs7_pad (in , in_len , buf );
52- (void )pad_len ; /* padded_len equals in_len + pad_len */
24+ uint8_t * padded = NULL ;
25+ size_t padded_len = 0 ;
26+ if (pkcs7_pad (in , in_len , AES_BLOCK , & padded , & padded_len ) != PADDING_OK )
27+ return -2 ;
5328
5429 uint8_t prev [AES_BLOCK ];
5530 memcpy (prev , iv , AES_BLOCK );
5631
5732 for (size_t off = 0 ; off < padded_len ; off += AES_BLOCK ) {
5833 uint8_t block [AES_BLOCK ];
59- xor_block (block , buf + off , prev );
34+ xor_block (block , padded + off , prev );
6035 encrypt (block , out + off , ctx );
61- /* new prev = ciphertext block */
6236 memcpy (prev , out + off , AES_BLOCK );
6337 }
6438
6539 * out_len = padded_len ;
66- /* wipe sensitive buffer */
67- memset (buf , 0 , padded_len );
68- free (buf );
6940 memset (prev , 0 , AES_BLOCK );
41+ memset (padded , 0 , padded_len );
42+ free (padded );
7043 return 0 ;
7144}
7245
@@ -83,22 +56,24 @@ int aes_cbc_decrypt(const uint8_t *in, size_t in_len,
8356
8457 for (size_t off = 0 ; off < in_len ; off += AES_BLOCK ) {
8558 uint8_t tmp [AES_BLOCK ];
86- decrypt (in + off , tmp , ctx ); /* tmp = AES_DEC(Ci) */
87- xor_block (out + off , tmp , prev ); /* plaintext block = tmp XOR prev */
59+ decrypt (in + off , tmp , ctx );
60+ xor_block (out + off , tmp , prev );
8861 memcpy (prev , in + off , AES_BLOCK );
8962 }
9063
91- /* unpad in-place on out */
92- size_t unpadded_len = 0 ;
93- int r = pkcs7_unpad (out , in_len , & unpadded_len );
94- if (r != 0 ) {
95- /* wipe and return error */
64+ /* Unpad using reusable PKCS#7 function */
65+ uint8_t * unpadded = NULL ;
66+ size_t plain_len = 0 ;
67+ if (pkcs7_unpad (out , in_len , AES_BLOCK , & unpadded , & plain_len ) != PADDING_OK ) {
9668 memset (out , 0 , in_len );
9769 memset (prev , 0 , AES_BLOCK );
9870 return -3 ;
9971 }
10072
101- * out_len = unpadded_len ;
73+ memcpy (out , unpadded , plain_len );
74+ * out_len = plain_len ;
75+ memset (unpadded , 0 , plain_len );
76+ free (unpadded );
10277 memset (prev , 0 , AES_BLOCK );
10378 return 0 ;
10479}
0 commit comments