diff --git a/src/dilithium_py/ml_dsa/ml_dsa.py b/src/dilithium_py/ml_dsa/ml_dsa.py index 3070a48..f4d5188 100644 --- a/src/dilithium_py/ml_dsa/ml_dsa.py +++ b/src/dilithium_py/ml_dsa/ml_dsa.py @@ -152,6 +152,14 @@ def _sk_size(self) -> int: t0_len = 416 * self.k return 2 * 32 + 64 + s1_len + s2_len + t0_len + def _sig_size(self) -> int: + if self.gamma_1 == 131072: + s_bytes = 18 + else: + assert self.gamma_1 == 524288 + s_bytes = 20 + return self.c_tilde_bytes + self.l * 32 * s_bytes + self.omega + self.k + def _unpack_sk( self, sk: bytes ) -> tuple[bytes, bytes, bytes, Vector, Vector, Vector]: @@ -220,6 +228,8 @@ def _unpack_h(self, h_bytes: bytes) -> Vector: return self.M.vector(vector_coeffs) def _unpack_sig(self, sig: bytes) -> tuple[bytes, Vector, Vector]: + if len(sig) != self._sig_size(): + raise ValueError("Incorrect signature size") c_tilde = sig[: self.c_tilde_bytes] z_bytes = sig[self.c_tilde_bytes : -(self.k + self.omega)] h_bytes = sig[-(self.k + self.omega) :] diff --git a/tests/test_ml_dsa.py b/tests/test_ml_dsa.py index dd5ab08..11b434e 100644 --- a/tests/test_ml_dsa.py +++ b/tests/test_ml_dsa.py @@ -20,6 +20,10 @@ def generic_test_ml_dsa(self, ML_DSA, count=5): sig = ML_DSA.sign(sk, msg, ctx=ctx) check_verify = ML_DSA.verify(pk, msg, sig, ctx=ctx) + check_short_verify = ML_DSA.verify(pk, msg, sig[:-1], ctx=ctx) + check_long_verify = ML_DSA.verify(pk, msg, sig + b"\x00", ctx=ctx) + check_empty_verify = ML_DSA.verify(pk, msg, b"", ctx=ctx) + # Sign with external_mu instead external_mu = ML_DSA.prehash_external_mu(pk, msg, ctx=ctx) sig_external_mu = ML_DSA.sign_external_mu(sk, external_mu) @@ -40,6 +44,15 @@ def generic_test_ml_dsa(self, ML_DSA, count=5): # Check that signature works self.assertTrue(check_verify) + # Check that too short signature is rejected + self.assertFalse(check_short_verify) + + # Check that too short signature is rejected + self.assertFalse(check_long_verify) + + # Check that empty signature is rejected + self.assertFalse(check_empty_verify) + # Check that external_mu also works self.assertTrue(check_external_mu)