diff --git a/lms/common/types.go b/lms/common/types.go index ae13a7f..9d98ee9 100644 --- a/lms/common/types.go +++ b/lms/common/types.go @@ -45,30 +45,53 @@ func (w window) Mask() uint8 { } } -type lms_type_code uint32 +// lmsTypecode represents a typecode for LMS. +// See https://www.iana.org/assignments/leighton-micali-signatures/leighton-micali-signatures.xhtml#leighton-micali-signatures-1 +type lmsTypecode uint32 const ( - LMS_RESERVED lms_type_code = iota - LMOTS_SHA256_N32_W1 - LMOTS_SHA256_N32_W2 - LMOTS_SHA256_N32_W4 - LMOTS_SHA256_N32_W8 - LMS_SHA256_M32_H5 - LMS_SHA256_M32_H10 - LMS_SHA256_M32_H15 - LMS_SHA256_M32_H20 - LMS_SHA256_M32_H25 + LMS_RESERVED lmsTypecode = 0x00000000 + lmsTypecodeFirst = LMS_SHA256_M32_H5 + LMS_SHA256_M32_H5 lmsTypecode = 0x00000005 + LMS_SHA256_M32_H10 lmsTypecode = 0x00000006 + LMS_SHA256_M32_H15 lmsTypecode = 0x00000007 + LMS_SHA256_M32_H20 lmsTypecode = 0x00000008 + LMS_SHA256_M32_H25 lmsTypecode = 0x00000009 + LMS_SHA256_M24_H5 lmsTypecode = 0x0000000A + LMS_SHA256_M24_H10 lmsTypecode = 0x0000000B + LMS_SHA256_M24_H15 lmsTypecode = 0x0000000C + LMS_SHA256_M24_H20 lmsTypecode = 0x0000000D + LMS_SHA256_M24_H25 lmsTypecode = 0x0000000E + lmsTypecodeLast = LMS_SHA256_M24_H25 +) + +// lmotsTypecode represents a typecode for LM-OTS. +// See https://www.iana.org/assignments/leighton-micali-signatures/leighton-micali-signatures.xhtml#lm-ots-signatures +type lmotsTypecode uint32 + +const ( + LMOTS_RESERVED lmotsTypecode = 0x00000000 + lmotsTypecodeFirst = LMOTS_SHA256_N32_W1 + LMOTS_SHA256_N32_W1 lmotsTypecode = 0x00000001 + LMOTS_SHA256_N32_W2 lmotsTypecode = 0x00000002 + LMOTS_SHA256_N32_W4 lmotsTypecode = 0x00000003 + LMOTS_SHA256_N32_W8 lmotsTypecode = 0x00000004 + LMOTS_SHA256_N24_W1 lmotsTypecode = 0x00000005 + LMOTS_SHA256_N24_W2 lmotsTypecode = 0x00000006 + LMOTS_SHA256_N24_W4 lmotsTypecode = 0x00000007 + LMOTS_SHA256_N24_W8 lmotsTypecode = 0x00000008 + lmotsTypecodeLast = LMOTS_SHA256_N24_W8 ) // LmsAlgorithmType represents a specific instance of LMS type LmsAlgorithmType interface { - LmsType() (lms_type_code, error) + LmsType() (lmsTypecode, error) LmsParams() (LmsParam, error) } // LmsOtsAlgorithmType represents a specific instance of LM-OTS type LmsOtsAlgorithmType interface { - LmsOtsType() (lms_type_code, error) + LmsOtsType() (lmotsTypecode, error) Params() (LmsOtsParam, error) } @@ -99,19 +122,19 @@ type LmsOtsParam struct { SIG_LEN uint64 // total byte length for a valid signature } -// Returns a lms_type_code, given a uint32 of the same value -func Uint32ToLmsType(x uint32) lms_type_code { - return lms_type_code(x) +// Returns a lmsTypecode, given a uint32 of the same value +func Uint32ToLmsType(x uint32) lmsTypecode { + return lmsTypecode(x) } -// Returns a uint32 of the same value as the lms_type_code -func (x lms_type_code) ToUint32() uint32 { +// Returns a uint32 of the same value as the lmsTypecode +func (x lmsTypecode) ToUint32() uint32 { return uint32(x) } -// Returns a lms_type_code if within a valid range for LMS; otherwise, an error -func (x lms_type_code) LmsType() (lms_type_code, error) { - if x >= LMS_SHA256_M32_H5 && x <= LMS_SHA256_M32_H25 { +// Returns a lmsTypecode if within a valid range for LMS; otherwise, an error +func (x lmsTypecode) LmsType() (lmsTypecode, error) { + if x >= lmsTypecodeFirst && x <= lmsTypecodeLast { return x, nil } else { return x, errors.New("LmsType(): invalid type code") @@ -119,8 +142,8 @@ func (x lms_type_code) LmsType() (lms_type_code, error) { } // Returns the expected signature length for an LMS type, given an associated LM-OTS type -func (x lms_type_code) LmsSigLength(otstc lms_type_code) (uint64, error) { - if x >= LMS_SHA256_M32_H5 && x <= LMS_SHA256_M32_H25 { +func (x lmsTypecode) LmsSigLength(otstc lmotsTypecode) (uint64, error) { + if x >= lmsTypecodeFirst && x <= lmsTypecodeLast { params, err := x.LmsParams() if err != nil { return 0, err @@ -135,9 +158,19 @@ func (x lms_type_code) LmsSigLength(otstc lms_type_code) (uint64, error) { } } -// Returns a lms_type_code if within a valid range for LM-OTS; otherwise, an error -func (x lms_type_code) LmsOtsType() (lms_type_code, error) { - if x >= LMOTS_SHA256_N32_W1 && x <= LMOTS_SHA256_N32_W8 { +// Returns a lmotsTypecode, given a uint32 of the same value +func Uint32ToLmotsType(x uint32) lmotsTypecode { + return lmotsTypecode(x) +} + +// Returns a uint32 of the same value as the lmotsTypecode +func (x lmotsTypecode) ToUint32() uint32 { + return uint32(x) +} + +// Returns a lmotsTypecode if within a valid range for LM-OTS; otherwise, an error +func (x lmotsTypecode) LmsOtsType() (lmotsTypecode, error) { + if x >= lmotsTypecodeFirst && x <= lmotsTypecodeLast { return x, nil } else { return x, errors.New("LmsOtsType(): invalid type code") @@ -145,8 +178,8 @@ func (x lms_type_code) LmsOtsType() (lms_type_code, error) { } // Returns the expected byte length of a given LM-OTS signature algorithm -func (x lms_type_code) LmsOtsSigLength() (uint64, error) { - if x >= LMOTS_SHA256_N32_W1 && x <= LMOTS_SHA256_N32_W8 { +func (x lmotsTypecode) LmsOtsSigLength() (uint64, error) { + if x >= lmotsTypecodeFirst && x <= lmotsTypecodeLast { params, err := x.Params() if err != nil { return 0, err @@ -157,8 +190,8 @@ func (x lms_type_code) LmsOtsSigLength() (uint64, error) { } } -// Returns a LmsParam corresponding to the lms_type_code, x -func (x lms_type_code) LmsParams() (LmsParam, error) { +// Returns a LmsParam corresponding to the lmsTypecode, x +func (x lmsTypecode) LmsParams() (LmsParam, error) { switch x { case LMS_SHA256_M32_H5: return LmsParam{ @@ -190,13 +223,43 @@ func (x lms_type_code) LmsParams() (LmsParam, error) { M: 32, H: 25, }, nil + case LMS_SHA256_M24_H5: + return LmsParam{ + Hash: Sha256Hasher{}, + M: 24, + H: 5, + }, nil + case LMS_SHA256_M24_H10: + return LmsParam{ + Hash: Sha256Hasher{}, + M: 24, + H: 10, + }, nil + case LMS_SHA256_M24_H15: + return LmsParam{ + Hash: Sha256Hasher{}, + M: 24, + H: 15, + }, nil + case LMS_SHA256_M24_H20: + return LmsParam{ + Hash: Sha256Hasher{}, + M: 24, + H: 20, + }, nil + case LMS_SHA256_M24_H25: + return LmsParam{ + Hash: Sha256Hasher{}, + M: 24, + H: 25, + }, nil default: return LmsParam{}, errors.New("LmsParams(): invalid type code") } } -// Returns a LmsOtsParam corresponding to the lms_type_code, x -func (x lms_type_code) Params() (LmsOtsParam, error) { +// Returns a LmsOtsParam corresponding to the lmsTypecode, x +func (x lmotsTypecode) Params() (LmsOtsParam, error) { switch x { case LMOTS_SHA256_N32_W1: return LmsOtsParam{ @@ -234,6 +297,42 @@ func (x lms_type_code) Params() (LmsOtsParam, error) { LS: 0, SIG_LEN: 1124, }, nil + case LMOTS_SHA256_N24_W1: + return LmsOtsParam{ + H: Sha256Hasher{}, + N: 24, + W: WINDOW_W1, + P: 200, + LS: 8, + SIG_LEN: 4828, + }, nil + case LMOTS_SHA256_N24_W2: + return LmsOtsParam{ + H: Sha256Hasher{}, + N: 24, + W: WINDOW_W2, + P: 101, + LS: 6, + SIG_LEN: 2452, + }, nil + case LMOTS_SHA256_N24_W4: + return LmsOtsParam{ + H: Sha256Hasher{}, + N: 24, + W: WINDOW_W4, + P: 51, + LS: 4, + SIG_LEN: 1252, + }, nil + case LMOTS_SHA256_N24_W8: + return LmsOtsParam{ + H: Sha256Hasher{}, + N: 24, + W: WINDOW_W8, + P: 26, + LS: 0, + SIG_LEN: 652, + }, nil default: return LmsOtsParam{}, errors.New("Params(): invalid type code") } diff --git a/lms/common/util.go b/lms/common/util.go index 8de1d99..82c1bfa 100644 --- a/lms/common/util.go +++ b/lms/common/util.go @@ -6,6 +6,7 @@ package common import ( "encoding/binary" + "hash" ) // Returns a []byte representing the Winternitz coefficients of x for a given window, w @@ -55,3 +56,17 @@ func Expand(msg []byte, mode LmsOtsAlgorithmType) ([]uint8, error) { return res[:params.P], nil } + +// HashWrite wraps h.Write with a panic. +func HashWrite(h hash.Hash, x []byte) { + _, err := h.Write(x) + if err != nil { + panic("hash.Hash.Write never errors") + } +} + +// HashSum wraps Hash.Sum, returning the n leftmost bytes. +// Panics if n is larger than the length of the hash output. +func HashSum(h hash.Hash, n uint64) []byte { + return h.Sum(nil)[:n] +} diff --git a/lms/lms/private.go b/lms/lms/private.go index 57240bf..ef14a88 100644 --- a/lms/lms/private.go +++ b/lms/lms/private.go @@ -11,17 +11,9 @@ import ( "github.com/trailofbits/lms-go/lms/ots" "crypto/rand" - "hash" "io" ) -func hash_write(h hash.Hash, x []byte) { - _, err := h.Write(x) - if err != nil { - panic("hash.Hash.Write never errors") - } -} - // NewPrivateKey returns a LmsPrivateKey, seeded by a cryptographically secure // random number generator. func NewPrivateKey(tc common.LmsAlgorithmType, otstc common.LmsOtsAlgorithmType) (LmsPrivateKey, error) { @@ -186,7 +178,7 @@ func LmsPrivateKeyFromBytes(b []byte) (LmsPrivateKey, error) { return LmsPrivateKey{}, err } // The OTS typecode is bytes 4-7 (4 bytes) - otstype, err := common.Uint32ToLmsType(binary.BigEndian.Uint32(b[4:8])).LmsOtsType() + otstype, err := common.Uint32ToLmotsType(binary.BigEndian.Uint32(b[4:8])).LmsOtsType() if err != nil { return LmsPrivateKey{}, err } @@ -249,11 +241,11 @@ func GeneratePKTree(tc common.LmsAlgorithmType, otstc common.LmsOtsAlgorithmType binary.BigEndian.PutUint32(r_be[:], r) hasher := ots_params.H.New() - hash_write(hasher, id[:]) - hash_write(hasher, r_be[:]) - hash_write(hasher, common.D_LEAF[:]) - hash_write(hasher, ots_pub.Key()) - authtree[r-1] = hasher.Sum(nil) + common.HashWrite(hasher, id[:]) + common.HashWrite(hasher, r_be[:]) + common.HashWrite(hasher, common.D_LEAF[:]) + common.HashWrite(hasher, ots_pub.Key()) + authtree[r-1] = common.HashSum(hasher, ots_params.N) j = i for j%2 == 1 { @@ -263,12 +255,12 @@ func GeneratePKTree(tc common.LmsAlgorithmType, otstc common.LmsOtsAlgorithmType binary.BigEndian.PutUint32(r_be[:], r) - hash_write(hasher, id[:]) - hash_write(hasher, r_be[:]) - hash_write(hasher, common.D_INTR[:]) - hash_write(hasher, authtree[2*r-1]) - hash_write(hasher, authtree[2*r]) - authtree[r-1] = hasher.Sum(nil) + common.HashWrite(hasher, id[:]) + common.HashWrite(hasher, r_be[:]) + common.HashWrite(hasher, common.D_INTR[:]) + common.HashWrite(hasher, authtree[2*r-1]) + common.HashWrite(hasher, authtree[2*r]) + authtree[r-1] = common.HashSum(hasher, ots_params.N) } } return authtree, nil diff --git a/lms/lms/public.go b/lms/lms/public.go index 55d6a6e..bbef1aa 100644 --- a/lms/lms/public.go +++ b/lms/lms/public.go @@ -67,27 +67,27 @@ func (pub *LmsPublicKey) Verify(msg []byte, sig LmsSignature) bool { binary.BigEndian.PutUint32(node_num_bytes[:], node_num) hasher := ots_params.H.New() - hash_write(hasher, pub.id[:]) - hash_write(hasher, node_num_bytes[:]) - hash_write(hasher, common.D_LEAF[:]) - hash_write(hasher, key_candidate.Key()) - tmp := hasher.Sum(nil) + common.HashWrite(hasher, pub.id[:]) + common.HashWrite(hasher, node_num_bytes[:]) + common.HashWrite(hasher, common.D_LEAF[:]) + common.HashWrite(hasher, key_candidate.Key()) + tmp := common.HashSum(hasher, ots_params.N) for i := 0; i < height; i++ { binary.BigEndian.PutUint32(tmp_be[:], node_num>>1) hasher := ots_params.H.New() - hash_write(hasher, pub.id[:]) - hash_write(hasher, tmp_be[:]) - hash_write(hasher, common.D_INTR[:]) + common.HashWrite(hasher, pub.id[:]) + common.HashWrite(hasher, tmp_be[:]) + common.HashWrite(hasher, common.D_INTR[:]) if node_num%2 == 1 { - hash_write(hasher, sig.path[i]) - hash_write(hasher, tmp) + common.HashWrite(hasher, sig.path[i]) + common.HashWrite(hasher, tmp) } else { - hash_write(hasher, tmp) - hash_write(hasher, sig.path[i]) + common.HashWrite(hasher, tmp) + common.HashWrite(hasher, sig.path[i]) } - tmp = hasher.Sum(nil) + tmp = common.HashSum(hasher, ots_params.N) node_num >>= 1 } return subtle.ConstantTimeCompare(tmp, pub.k) == 1 @@ -143,7 +143,7 @@ func LmsPublicKeyFromBytes(b []byte) (LmsPublicKey, error) { return LmsPublicKey{}, err } // The OTS typecode is bytes 4-7 (4 bytes) - otstype, err := common.Uint32ToLmsType(binary.BigEndian.Uint32(b[4:8])).LmsOtsType() + otstype, err := common.Uint32ToLmotsType(binary.BigEndian.Uint32(b[4:8])).LmsOtsType() if err != nil { return LmsPublicKey{}, err } diff --git a/lms/lms/public_test.go b/lms/lms/public_test.go index 70cd4a3..2e1025e 100644 --- a/lms/lms/public_test.go +++ b/lms/lms/public_test.go @@ -649,6 +649,147 @@ func TestRfc8554KAT2(t *testing.T) { assert.True(t, result) } +// From the SHA256/192 test vector in https://datatracker.ietf.org/doc/draft-fluhrer-lms-more-parm-sets/19/ +func TestSha256192KAT1(t *testing.T) { + var publicKeyBytes = []byte{ + // LMS type: LMS_SHA256_M32_H20 + 0x00, 0x00, 0x00, 0x0a, + // LM-OTS type: LMOTS_SHA256_N24_W8 + 0x00, 0x00, 0x00, 0x08, + // I + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + // K + 0x2c, 0x57, 0x14, 0x50, 0xae, 0xd9, 0x9c, 0xfb, + 0x4f, 0x4a, 0xc2, 0x85, 0xda, 0x14, 0x88, 0x27, + 0x96, 0x61, 0x83, 0x14, 0x50, 0x8b, 0x12, 0xd2, + } + + var signatureBytes = []byte{ + // q + 0x00, 0x00, 0x00, 0x05, + // LM-OTS type: LMOTS_SHA256_N24_W8 + 0x00, 0x00, 0x00, 0x08, + // C + 0x0b, 0x50, 0x40, 0xa1, 0x8c, 0x1b, 0x5c, 0xab, + 0xcb, 0xc8, 0x5b, 0x04, 0x74, 0x02, 0xec, 0x62, + 0x94, 0xa3, 0x0d, 0xd8, 0xda, 0x8f, 0xc3, 0xda, + // y + 0xe1, 0x3b, 0x9f, 0x08, 0x75, 0xf0, 0x93, 0x61, + 0xdc, 0x77, 0xfc, 0xc4, 0x48, 0x1e, 0xa4, 0x63, + 0xc0, 0x73, 0x71, 0x62, 0x49, 0x71, 0x91, 0x93, + 0x61, 0x4b, 0x83, 0x5b, 0x46, 0x94, 0xc0, 0x59, + 0xf1, 0x2d, 0x3a, 0xed, 0xd3, 0x4f, 0x3d, 0xb9, + 0x3f, 0x35, 0x80, 0xfb, 0x88, 0x74, 0x3b, 0x8b, + 0x3d, 0x06, 0x48, 0xc0, 0x53, 0x7b, 0x7a, 0x50, + 0xe4, 0x33, 0xd7, 0xea, 0x9d, 0x66, 0x72, 0xff, + 0xfc, 0x5f, 0x42, 0x77, 0x0f, 0xea, 0xb4, 0xf9, + 0x8e, 0xb3, 0xf3, 0xb2, 0x3f, 0xd2, 0x06, 0x1e, + 0x4d, 0x0b, 0x38, 0xf8, 0x32, 0x86, 0x0a, 0xe7, + 0x66, 0x73, 0xad, 0x1a, 0x1a, 0x52, 0xa9, 0x00, + 0x5d, 0xcf, 0x1b, 0xfb, 0x56, 0xfe, 0x16, 0xff, + 0x72, 0x36, 0x27, 0x61, 0x2f, 0x9a, 0x48, 0xf7, + 0x90, 0xf3, 0xc4, 0x7a, 0x67, 0xf8, 0x70, 0xb8, + 0x1e, 0x91, 0x9d, 0x99, 0x91, 0x9c, 0x8d, 0xb4, + 0x81, 0x68, 0x83, 0x8c, 0xec, 0xe0, 0xab, 0xfb, + 0x68, 0x3d, 0xa4, 0x8b, 0x92, 0x09, 0x86, 0x8b, + 0xe8, 0xec, 0x10, 0xc6, 0x3d, 0x8b, 0xf8, 0x0d, + 0x36, 0x49, 0x8d, 0xfc, 0x20, 0x5d, 0xc4, 0x5d, + 0x0d, 0xd8, 0x70, 0x57, 0x2d, 0x6d, 0x8f, 0x1d, + 0x90, 0x17, 0x7c, 0xf5, 0x13, 0x7b, 0x8b, 0xbf, + 0x7b, 0xcb, 0x67, 0xa4, 0x6f, 0x86, 0xf2, 0x6c, + 0xfa, 0x5a, 0x44, 0xcb, 0xca, 0xa4, 0xe1, 0x8d, + 0xa0, 0x99, 0xa9, 0x8b, 0x0b, 0x3f, 0x96, 0xd5, + 0xac, 0x8a, 0xc3, 0x75, 0xd8, 0xda, 0x2a, 0x7c, + 0x24, 0x80, 0x04, 0xba, 0x11, 0xd7, 0xac, 0x77, + 0x5b, 0x92, 0x18, 0x35, 0x9c, 0xdd, 0xab, 0x4c, + 0xf8, 0xcc, 0xc6, 0xd5, 0x4c, 0xb7, 0xe1, 0xb3, + 0x5a, 0x36, 0xdd, 0xc9, 0x26, 0x5c, 0x08, 0x70, + 0x63, 0xd2, 0xfc, 0x67, 0x42, 0xa7, 0x17, 0x78, + 0x76, 0x47, 0x6a, 0x32, 0x4b, 0x03, 0x29, 0x5b, + 0xfe, 0xd9, 0x9f, 0x2e, 0xaf, 0x1f, 0x38, 0x97, + 0x05, 0x83, 0xc1, 0xb2, 0xb6, 0x16, 0xaa, 0xd0, + 0xf3, 0x1c, 0xd7, 0xa4, 0xb1, 0xbb, 0x0a, 0x51, + 0xe4, 0x77, 0xe9, 0x4a, 0x01, 0xbb, 0xb4, 0xd6, + 0xf8, 0x86, 0x6e, 0x25, 0x28, 0xa1, 0x59, 0xdf, + 0x3d, 0x6c, 0xe2, 0x44, 0xd2, 0xb6, 0x51, 0x8d, + 0x1f, 0x02, 0x12, 0x28, 0x5a, 0x3c, 0x2d, 0x4a, + 0x92, 0x70, 0x54, 0xa1, 0xe1, 0x62, 0x0b, 0x5b, + 0x02, 0xaa, 0xb0, 0xc8, 0xc1, 0x0e, 0xd4, 0x8a, + 0xe5, 0x18, 0xea, 0x73, 0xcb, 0xa8, 0x1f, 0xcf, + 0xff, 0x88, 0xbf, 0xf4, 0x61, 0xda, 0xc5, 0x1e, + 0x7a, 0xb4, 0xca, 0x75, 0xf4, 0x7a, 0x62, 0x59, + 0xd2, 0x48, 0x20, 0xb9, 0x99, 0x57, 0x92, 0xd1, + 0x39, 0xf6, 0x1a, 0xe2, 0xa8, 0x18, 0x6a, 0xe4, + 0xe3, 0xc9, 0xbf, 0xe0, 0xaf, 0x2c, 0xc7, 0x17, + 0xf4, 0x24, 0xf4, 0x1a, 0xa6, 0x7f, 0x03, 0xfa, + 0xed, 0xb0, 0x66, 0x51, 0x15, 0xf2, 0x06, 0x7a, + 0x46, 0x84, 0x3a, 0x4c, 0xbb, 0xd2, 0x97, 0xd5, + 0xe8, 0x3b, 0xc1, 0xaa, 0xfc, 0x18, 0xd1, 0xd0, + 0x3b, 0x3d, 0x89, 0x4e, 0x85, 0x95, 0xa6, 0x52, + 0x60, 0x73, 0xf0, 0x2a, 0xb0, 0xf0, 0x8b, 0x99, + 0xfd, 0x9e, 0xb2, 0x08, 0xb5, 0x9f, 0xf6, 0x31, + 0x7e, 0x55, 0x45, 0xe6, 0xf9, 0xad, 0x5f, 0x9c, + 0x18, 0x3a, 0xbd, 0x04, 0x3d, 0x5a, 0xcd, 0x6e, + 0xb2, 0xdd, 0x4d, 0xa3, 0xf0, 0x2d, 0xbc, 0x31, + 0x67, 0xb4, 0x68, 0x72, 0x0a, 0x4b, 0x8b, 0x92, + 0xdd, 0xfe, 0x79, 0x60, 0x99, 0x8b, 0xb7, 0xa0, + 0xec, 0xf2, 0xa2, 0x6a, 0x37, 0x59, 0x82, 0x99, + 0x41, 0x3f, 0x7b, 0x2a, 0xec, 0xd3, 0x9a, 0x30, + 0xce, 0xc5, 0x27, 0xb4, 0xd9, 0x71, 0x0c, 0x44, + 0x73, 0x63, 0x90, 0x22, 0x45, 0x1f, 0x50, 0xd0, + 0x1c, 0x04, 0x57, 0x12, 0x5d, 0xa0, 0xfa, 0x44, + 0x29, 0xc0, 0x7d, 0xad, 0x85, 0x9c, 0x84, 0x6c, + 0xbb, 0xd9, 0x3a, 0xb5, 0xb9, 0x1b, 0x01, 0xbc, + 0x77, 0x0b, 0x08, 0x9c, 0xfe, 0xde, 0x6f, 0x65, + 0x1e, 0x86, 0xdd, 0x7c, 0x15, 0x98, 0x9c, 0x8b, + 0x53, 0x21, 0xde, 0xa9, 0xca, 0x60, 0x8c, 0x71, + 0xfd, 0x86, 0x23, 0x23, 0x07, 0x2b, 0x82, 0x7c, + 0xee, 0x7a, 0x7e, 0x28, 0xe4, 0xe2, 0xb9, 0x99, + 0x64, 0x72, 0x33, 0xc3, 0x45, 0x69, 0x44, 0xbb, + 0x7a, 0xef, 0x91, 0x87, 0xc9, 0x6b, 0x3f, 0x5b, + 0x79, 0xfb, 0x98, 0xbc, 0x76, 0xc3, 0x57, 0x4d, + 0xd0, 0x6f, 0x0e, 0x95, 0x68, 0x5e, 0x5b, 0x3a, + 0xef, 0x3a, 0x54, 0xc4, 0x15, 0x5f, 0xe3, 0xad, + 0x81, 0x77, 0x49, 0x62, 0x9c, 0x30, 0xad, 0xbe, + 0x89, 0x7c, 0x4f, 0x44, 0x54, 0xc8, 0x6c, 0x49, + // LMS type: LMS_SHA256_M24_H5 + 0x00, 0x00, 0x00, 0x0a, + // path + 0xe9, 0xca, 0x10, 0xea, 0xa8, 0x11, 0xb2, 0x2a, + 0xe0, 0x7f, 0xb1, 0x95, 0xe3, 0x59, 0x0a, 0x33, + 0x4e, 0xa6, 0x42, 0x09, 0x94, 0x2f, 0xba, 0xe3, + 0x38, 0xd1, 0x9f, 0x15, 0x21, 0x82, 0xc8, 0x07, + 0xd3, 0xc4, 0x0b, 0x18, 0x9d, 0x3f, 0xcb, 0xea, + 0x94, 0x2f, 0x44, 0x68, 0x24, 0x39, 0xb1, 0x91, + 0x33, 0x2d, 0x33, 0xae, 0x0b, 0x76, 0x1a, 0x2a, + 0x8f, 0x98, 0x4b, 0x56, 0xb2, 0xac, 0x2f, 0xd4, + 0xab, 0x08, 0x22, 0x3a, 0x69, 0xed, 0x1f, 0x77, + 0x19, 0xc7, 0xaa, 0x7e, 0x9e, 0xee, 0x96, 0x50, + 0x4b, 0x0e, 0x60, 0xc6, 0xbb, 0x5c, 0x94, 0x2d, + 0x69, 0x5f, 0x04, 0x93, 0xeb, 0x25, 0xf8, 0x0a, + 0x58, 0x71, 0xcf, 0xfd, 0x13, 0x1d, 0x0e, 0x04, + 0xff, 0xe5, 0x06, 0x5b, 0xc7, 0x87, 0x5e, 0x82, + 0xd3, 0x4b, 0x40, 0xb6, 0x9d, 0xd9, 0xf3, 0xc1, + } + var messageBytes = []byte{ + 0x54, 0x65, 0x73, 0x74, 0x20, 0x6d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x20, 0x66, 0x6f, 0x72, + 0x20, 0x53, 0x48, 0x41, 0x32, 0x35, 0x36, 0x2d, + 0x31, 0x39, 0x32, 0x0a, + } + sig, err := lms.LmsSignatureFromBytes(signatureBytes) + if err != nil { + t.Fatalf("LmsSignatureFromBytes() = %v", err) + } + pk, err := lms.LmsPublicKeyFromBytes(publicKeyBytes) + if err != nil { + t.Fatalf("LmsPublicKeyFromBytes() = %v", err) + } + result := pk.Verify(messageBytes, sig) + assert.True(t, result) +} + // Fuzzers func FuzzPubkeyFromBytes(f *testing.F) { pubkeybytes, _ := hex.DecodeString(PUBKEY_HEX) diff --git a/lms/lms/signature.go b/lms/lms/signature.go index d7f8fcb..a9e463c 100644 --- a/lms/lms/signature.go +++ b/lms/lms/signature.go @@ -55,7 +55,7 @@ func LmsSignatureFromBytes(b []byte) (LmsSignature, error) { q := binary.BigEndian.Uint32(b[0:4]) // The OTS signature starts at byte 4, with the typecode first - otstc := common.Uint32ToLmsType(binary.BigEndian.Uint32(b[4:8])) + otstc := common.Uint32ToLmotsType(binary.BigEndian.Uint32(b[4:8])) // Return error if not a valid LM-OTS algorithm: _, err = otstc.LmsOtsType() if err != nil { diff --git a/lms/ots/ots_test.go b/lms/ots/ots_test.go index 6fe60cb..21ce7f8 100644 --- a/lms/ots/ots_test.go +++ b/lms/ots/ots_test.go @@ -9,95 +9,105 @@ import ( "github.com/trailofbits/lms-go/lms/ots" ) -func testOtsSignVerify(t *testing.T, otstc common.LmsOtsAlgorithmType) { - var err error - - id, err := hex.DecodeString("d08fabd4a2091ff0a8cb4ed834e74534") - if err != nil { - panic(err) - } - - ots_priv, err := ots.NewPrivateKey(otstc, 0, common.ID(id)) - if err != nil { - panic(err) - } - - ots_pub, err := ots_priv.Public() - if err != nil { - panic(err) - } - ots_sig, err := ots_priv.Sign([]byte("example"), nil) - if err != nil { - panic(err) - } - - result := ots_pub.Verify([]byte("example"), ots_sig) - assert.True(t, result) -} - -func testOtsSignVerifyFail(t *testing.T, otstc common.LmsOtsAlgorithmType) { - var err error - - id, err := hex.DecodeString("d08fabd4a2091ff0a8cb4ed834e74534") - if err != nil { - panic(err) - } - - ots_priv, err := ots.NewPrivateKey(otstc, 0, common.ID(id)) - if err != nil { - panic(err) - } - - ots_pub, err := ots_priv.Public() - if err != nil { - panic(err) +func TestOtsSignVerify(t *testing.T) { + for _, tc := range []struct { + name string + typecode uint32 + }{ + { + name: "LMOTS_SHA256_N32_W1", + typecode: common.LMOTS_SHA256_N32_W1.ToUint32(), + }, + { + name: "LMOTS_SHA256_N32_W2", + typecode: common.LMOTS_SHA256_N32_W2.ToUint32(), + }, + { + name: "LMOTS_SHA256_N32_W4", + typecode: common.LMOTS_SHA256_N32_W4.ToUint32(), + }, + { + name: "LMOTS_SHA256_N32_W8", + typecode: common.LMOTS_SHA256_N32_W8.ToUint32(), + }, + { + name: "LMOTS_SHA256_N24_W1", + typecode: common.LMOTS_SHA256_N24_W1.ToUint32(), + }, + { + name: "LMOTS_SHA256_N24_W2", + typecode: common.LMOTS_SHA256_N24_W2.ToUint32(), + }, + { + name: "LMOTS_SHA256_N24_W4", + typecode: common.LMOTS_SHA256_N24_W4.ToUint32(), + }, + { + name: "LMOTS_SHA256_N24_W8", + typecode: common.LMOTS_SHA256_N24_W8.ToUint32(), + }, + } { + t.Run(tc.name, func(t *testing.T) { + var err error + + id, err := hex.DecodeString("d08fabd4a2091ff0a8cb4ed834e74534") + if err != nil { + t.Fatalf("hex.DecodeString() = %v", err) + } + + otsPriv, err := ots.NewPrivateKey(common.Uint32ToLmotsType(tc.typecode), 0, common.ID(id)) + if err != nil { + t.Fatalf("ots.NewPrivateKey() = %v", err) + } + + otsPub, err := otsPriv.Public() + if err != nil { + t.Fatalf("otsPriv.Public() = %v", err) + } + otsSig, err := otsPriv.Sign([]byte("example"), nil) + if err != nil { + t.Fatalf("otsPriv.Sign() = %v", err) + } + + t.Run("VerifyOK", func(t *testing.T) { + result := otsPub.Verify([]byte("example"), otsSig) + assert.True(t, result) + }) + + t.Run("VerifyBadPubFail", func(t *testing.T) { + // modify q so that the verification fails + otsPubBytes := otsPub.ToBytes() + otsPubBytes[23] ^= 1 + otsPub2, err := ots.LmsOtsPublicKeyFromBytes(otsPubBytes) + if err != nil { + t.Fatalf("LmsOtsPublicKeyFromBytes() = %v", err) + } + result := otsPub2.Verify([]byte("example"), otsSig) + assert.False(t, result) + }) + + t.Run("VerifyBadSigFail", func(t *testing.T) { + // modify sig so that the verification fails + otsSigBytes, err := otsSig.ToBytes() + if err != nil { + t.Fatalf("otsSig.ToBytes() = %v", err) + } + otsSigBytes[23] ^= 1 + otsSig2, err := ots.LmsOtsSignatureFromBytes(otsSigBytes) + if err != nil { + t.Fatalf("LmsOtsPublicKeyFromBytes() = %v", err) + } + result := otsPub.Verify([]byte("example"), otsSig2) + assert.False(t, result) + }) + + t.Run("VerifyBadMsgFail", func(t *testing.T) { + // try to verify a different message + result := otsPub.Verify([]byte("example2"), otsSig) + assert.False(t, result) + }) + }) } - ots_sig, err := ots_priv.Sign([]byte("example"), nil) - if err != nil { - panic(err) - } - - // modify q so that the verification fails - ots_pub_bytes := ots_pub.ToBytes() - ots_pub_bytes[23] = 1 - ots_pub, err = ots.LmsOtsPublicKeyFromBytes(ots_pub_bytes) - if err != nil { - panic(err) - } - result := ots_pub.Verify([]byte("example"), ots_sig) - assert.False(t, result) -} - -func TestOtsSignVerifyW1(t *testing.T) { - testOtsSignVerify(t, common.LMOTS_SHA256_N32_W1) -} - -func TestOtsSignVerifyW2(t *testing.T) { - testOtsSignVerify(t, common.LMOTS_SHA256_N32_W2) -} - -func TestOtsSignVerifyW4(t *testing.T) { - testOtsSignVerify(t, common.LMOTS_SHA256_N32_W4) -} - -func TestOtsSignVerifyW8(t *testing.T) { - testOtsSignVerify(t, common.LMOTS_SHA256_N32_W8) -} - -func TestOtsSignVerifyW1Fail(t *testing.T) { - testOtsSignVerifyFail(t, common.LMOTS_SHA256_N32_W1) -} - -func TestOtsSignVerifyW2Fail(t *testing.T) { - testOtsSignVerifyFail(t, common.LMOTS_SHA256_N32_W2) -} - -func TestOtsSignVerifyW4Fail(t *testing.T) { - testOtsSignVerifyFail(t, common.LMOTS_SHA256_N32_W4) -} - -func TestOtsSignVerifyW8Fail(t *testing.T) { - testOtsSignVerifyFail(t, common.LMOTS_SHA256_N32_W8) } func TestDoubleSign(t *testing.T) { diff --git a/lms/ots/private.go b/lms/ots/private.go index 9bcb700..452f037 100644 --- a/lms/ots/private.go +++ b/lms/ots/private.go @@ -9,17 +9,9 @@ import ( "crypto/rand" "encoding/binary" "errors" - "hash" "io" ) -func hash_write(h hash.Hash, x []byte) { - _, err := h.Write(x) - if err != nil { - panic("hash.Hash.Write never errors") - } -} - // NewPrivateKey returns a LmsOtsPrivateKey, seeded by a cryptographically secure // random number generator. func NewPrivateKey(tc common.LmsOtsAlgorithmType, q uint32, id common.ID) (LmsOtsPrivateKey, error) { @@ -54,13 +46,13 @@ func NewPrivateKeyFromSeed(tc common.LmsOtsAlgorithmType, q uint32, id common.ID binary.BigEndian.PutUint32(q_be[:], q) binary.BigEndian.PutUint16(i_be[:], uint16(i)) - hash_write(hasher, id[:]) - hash_write(hasher, q_be[:]) - hash_write(hasher, i_be[:]) - hash_write(hasher, []byte{0xff}) - hash_write(hasher, seed) + common.HashWrite(hasher, id[:]) + common.HashWrite(hasher, q_be[:]) + common.HashWrite(hasher, i_be[:]) + common.HashWrite(hasher, []byte{0xff}) + common.HashWrite(hasher, seed) - x[i] = hasher.Sum(nil) + x[i] = common.HashSum(hasher, params.N) } return LmsOtsPrivateKey{ @@ -84,9 +76,9 @@ func (x *LmsOtsPrivateKey) Public() (LmsOtsPublicKey, error) { hasher := params.H.New() binary.BigEndian.PutUint32(be32[:], x.q) - hash_write(hasher, x.id[:]) - hash_write(hasher, be32[:]) - hash_write(hasher, common.D_PBLC[:]) + common.HashWrite(hasher, x.id[:]) + common.HashWrite(hasher, be32[:]) + common.HashWrite(hasher, common.D_PBLC[:]) for i := uint64(0); i < params.P; i++ { tmp = make([]byte, len(x.x[i])) @@ -98,23 +90,23 @@ func (x *LmsOtsPrivateKey) Public() (LmsOtsPublicKey, error) { binary.BigEndian.PutUint32(be32[:], x.q) binary.BigEndian.PutUint16(be16[:], uint16(i)) - hash_write(inner, x.id[:]) - hash_write(inner, be32[:]) - hash_write(inner, be16[:]) - hash_write(inner, []byte{byte(j)}) - hash_write(inner, tmp) + common.HashWrite(inner, x.id[:]) + common.HashWrite(inner, be32[:]) + common.HashWrite(inner, be16[:]) + common.HashWrite(inner, []byte{byte(j)}) + common.HashWrite(inner, tmp) - tmp = inner.Sum(nil) + tmp = common.HashSum(inner, params.N) } - hash_write(hasher, tmp) + common.HashWrite(hasher, tmp) } return LmsOtsPublicKey{ typecode: x.typecode, q: x.q, id: x.id, - k: hasher.Sum(nil), + k: common.HashSum(hasher, params.N), }, nil } @@ -145,13 +137,13 @@ func (x *LmsOtsPrivateKey) Sign(msg []byte, rng io.Reader) (LmsOtsSignature, err binary.BigEndian.PutUint32(be32[:], x.q) - hash_write(hasher, x.id[:]) - hash_write(hasher, be32[:]) - hash_write(hasher, common.D_MESG[:]) - hash_write(hasher, c) - hash_write(hasher, msg) + common.HashWrite(hasher, x.id[:]) + common.HashWrite(hasher, be32[:]) + common.HashWrite(hasher, common.D_MESG[:]) + common.HashWrite(hasher, c) + common.HashWrite(hasher, msg) - q := hasher.Sum(nil) + q := common.HashSum(hasher, params.N) expanded, err := common.Expand(q, x.typecode) if err != nil { return LmsOtsSignature{}, err @@ -170,13 +162,13 @@ func (x *LmsOtsPrivateKey) Sign(msg []byte, rng io.Reader) (LmsOtsSignature, err binary.BigEndian.PutUint32(be32[:], x.q) binary.BigEndian.PutUint16(be16[:], uint16(i)) - hash_write(inner, x.id[:]) - hash_write(inner, be32[:]) - hash_write(inner, be16[:]) - hash_write(inner, []byte{byte(j)}) - hash_write(inner, y[i]) + common.HashWrite(inner, x.id[:]) + common.HashWrite(inner, be32[:]) + common.HashWrite(inner, be16[:]) + common.HashWrite(inner, []byte{byte(j)}) + common.HashWrite(inner, y[i]) - y[i] = inner.Sum(nil) + y[i] = common.HashSum(inner, params.N) } // y[i] is now the correct value diff --git a/lms/ots/public.go b/lms/ots/public.go index 86618f5..1080fe4 100644 --- a/lms/ots/public.go +++ b/lms/ots/public.go @@ -38,7 +38,7 @@ func (sig *LmsOtsSignature) RecoverPublicKey(msg []byte, pubtype common.LmsOtsAl return LmsOtsPublicKey{}, false } hasher := params.H.New() - hash_len := hasher.Size() + hash_len := int(params.N) // verify length of nonce if len(sig.c) != hash_len { @@ -57,22 +57,22 @@ func (sig *LmsOtsSignature) RecoverPublicKey(msg []byte, pubtype common.LmsOtsAl binary.BigEndian.PutUint32(be32[:], q) - hash_write(hasher, id[:]) - hash_write(hasher, be32[:]) - hash_write(hasher, common.D_MESG[:]) - hash_write(hasher, sig.c) - hash_write(hasher, msg) + common.HashWrite(hasher, id[:]) + common.HashWrite(hasher, be32[:]) + common.HashWrite(hasher, common.D_MESG[:]) + common.HashWrite(hasher, sig.c) + common.HashWrite(hasher, msg) - Q := hasher.Sum(nil) + Q := common.HashSum(hasher, params.N) expanded, err := common.Expand(Q, sig.typecode) if err != nil { return LmsOtsPublicKey{}, false } hasher.Reset() - hash_write(hasher, id[:]) - hash_write(hasher, be32[:]) - hash_write(hasher, common.D_PBLC[:]) + common.HashWrite(hasher, id[:]) + common.HashWrite(hasher, be32[:]) + common.HashWrite(hasher, common.D_PBLC[:]) for i := uint64(0); i < params.P; i++ { a := uint64(expanded[i]) @@ -85,23 +85,23 @@ func (sig *LmsOtsSignature) RecoverPublicKey(msg []byte, pubtype common.LmsOtsAl binary.BigEndian.PutUint32(be32[:], q) binary.BigEndian.PutUint16(be16[:], uint16(i)) - hash_write(inner, id[:]) - hash_write(inner, be32[:]) - hash_write(inner, be16[:]) - hash_write(inner, []byte{byte(j)}) - hash_write(inner, tmp) + common.HashWrite(inner, id[:]) + common.HashWrite(inner, be32[:]) + common.HashWrite(inner, be16[:]) + common.HashWrite(inner, []byte{byte(j)}) + common.HashWrite(inner, tmp) - tmp = inner.Sum(nil) + tmp = common.HashSum(inner, params.N) } - hash_write(hasher, tmp) + common.HashWrite(hasher, tmp) } return LmsOtsPublicKey{ typecode: sig.typecode, q: q, id: id, - k: hasher.Sum(nil), + k: common.HashSum(hasher, params.N), }, true } @@ -118,7 +118,7 @@ func LmsOtsPublicKeyFromBytes(b []byte) (LmsOtsPublicKey, error) { return LmsOtsPublicKey{}, errors.New("LmsOtsPublicKeyFromBytes(): OTS public key too short") } // The typecode is bytes 0-3 (4 bytes) - typecode, err := common.Uint32ToLmsType(binary.BigEndian.Uint32(b[0:4])).LmsOtsType() + typecode, err := common.Uint32ToLmotsType(binary.BigEndian.Uint32(b[0:4])).LmsOtsType() if err != nil { return LmsOtsPublicKey{}, err } diff --git a/lms/ots/signature.go b/lms/ots/signature.go index dc090e3..5774a30 100644 --- a/lms/ots/signature.go +++ b/lms/ots/signature.go @@ -17,7 +17,7 @@ func LmsOtsSignatureFromBytes(b []byte) (LmsOtsSignature, error) { } // Typecode is the first 4 bytes - typecode := common.Uint32ToLmsType(binary.BigEndian.Uint32(b[0:4])) + typecode := common.Uint32ToLmotsType(binary.BigEndian.Uint32(b[0:4])) // Panic if not a valid LM-OTS algorithm: params, err := typecode.Params() if err != nil {