diff --git a/src/dilithium_py/dilithium/dilithium.py b/src/dilithium_py/dilithium/dilithium.py index 59671b0..94fdfe9 100644 --- a/src/dilithium_py/dilithium/dilithium.py +++ b/src/dilithium_py/dilithium/dilithium.py @@ -162,6 +162,16 @@ def _unpack_sk(self, sk_bytes): def _unpack_h(self, h_bytes): offsets = [0] + list(h_bytes[-self.k :]) + # check offsets are monotonic increasing + if any(offsets[i] > offsets[i + 1] for i in range(len(offsets) - 1)): + raise ValueError("Offsets in h_bytes are not monotonic increasing") + # check offset[-1] is smaller than the length of h_bytes + if offsets[-1] > self.omega: + raise ValueError("Accumulate offset of hints exceeds omega") + # check zero fields are all zeros + if any(b != 0 for b in h_bytes[offsets[-1] : self.omega]): + raise ValueError("Non-zero fields in h_bytes are not all zeros") + non_zero_positions = [ list(h_bytes[offsets[i] : offsets[i + 1]]) for i in range(self.k) ] @@ -169,6 +179,11 @@ def _unpack_h(self, h_bytes): matrix = [] for poly_non_zero in non_zero_positions: coeffs = [0 for _ in range(256)] + for i, non_zero in enumerate(poly_non_zero): + if i > 0 and non_zero < poly_non_zero[i - 1]: + raise ValueError( + "Non-zero positions in h_bytes are not monotonic increasing" + ) for non_zero in poly_non_zero: coeffs[non_zero] = 1 matrix.append([self.R(coeffs)]) @@ -282,7 +297,11 @@ def verify(self, pk_bytes, m, sig_bytes): signature """ rho, t1 = self._unpack_pk(pk_bytes) - c_tilde, z, h = self._unpack_sig(sig_bytes) + try: + c_tilde, z, h = self._unpack_sig(sig_bytes) + except ValueError: + # verify failed if malformed input signature + return False if h.sum_hint() > self.omega: return False