diff --git a/arbo/verifier.go b/arbo/verifier.go index 83ea4a7..3ffe3f8 100644 --- a/arbo/verifier.go +++ b/arbo/verifier.go @@ -8,13 +8,16 @@ import ( // prevLevel function calculates the previous level of the merkle tree given the // current leaf, the current path bit of the leaf, the validity of the sibling // and the sibling itself. -func prevLevel(api frontend.API, leaf, ipath, valid, sibling frontend.Variable) frontend.Variable { +func prevLevel(api frontend.API, leaf, ipath, valid, sibling frontend.Variable) (frontend.Variable, error) { // l, r = path == 1 ? sibling, current : current, sibling l, r := api.Select(ipath, sibling, leaf), api.Select(ipath, leaf, sibling) // intermediateLeafKey = H(l | r) - intermediateLeafKey := poseidon.Hash(api, l, r) + intermediateLeafKey, err := poseidon.Hash(api, l, r) + if err != nil { + return 0, err + } // newCurrent = valid == 1 ? current : intermediateLeafKey - return api.Select(valid, intermediateLeafKey, leaf) + return api.Select(valid, intermediateLeafKey, leaf), nil } // strictCmp function compares a and b and returns: @@ -41,7 +44,10 @@ func CheckProof(api frontend.API, key, value, root frontend.Variable, siblings [ path := api.ToBinary(key, api.Compiler().FieldBitLen()) // calculate the value leaf to start with it to rebuild the tree // leafValue = H(key | value | 1) - leafValue := poseidon.Hash(api, key, value, 1) + leafValue, err := poseidon.Hash(api, key, value, 1) + if err != nil { + return err + } // calculate the root and compare it with the provided one prevLeaf := leafValue currentLeaf := leafValue @@ -52,7 +58,10 @@ func CheckProof(api frontend.API, key, value, root frontend.Variable, siblings [ prevLeaf = currentLeaf prevSibling = siblings[i] // compute the next leaf value - currentLeaf = prevLevel(api, currentLeaf, path[i], valid, siblings[i]) + currentLeaf, err = prevLevel(api, currentLeaf, path[i], valid, siblings[i]) + if err != nil { + return err + } } api.AssertIsEqual(currentLeaf, root) return nil diff --git a/arbo/verifier_test.go b/arbo/verifier_test.go index ace9896..eca26c9 100644 --- a/arbo/verifier_test.go +++ b/arbo/verifier_test.go @@ -20,10 +20,10 @@ import ( ) type testVerifierCircuit struct { - Root frontend.Variable - Key frontend.Variable - Value frontend.Variable - Siblings [160]frontend.Variable + Root frontend.Variable + Key frontend.Variable + Value frontend.Variable + Siblings [160]frontend.Variable } func (circuit *testVerifierCircuit) Define(api frontend.API) error { @@ -74,10 +74,10 @@ func successInputs(t *testing.T, n int) testVerifierCircuit { root, err := tree.Root() c.Assert(err, qt.IsNil) return testVerifierCircuit{ - Root: arbo.BytesLEToBigInt(root), - Key: arbo.BytesLEToBigInt(key), - Value: value, - Siblings: siblings, + Root: arbo.BytesLEToBigInt(root), + Key: arbo.BytesLEToBigInt(key), + Value: value, + Siblings: siblings, } } diff --git a/go.mod b/go.mod index faed557..95efa4a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/consensys/gnark-crypto v0.14.0 github.com/frankban/quicktest v1.14.6 github.com/iden3/go-iden3-crypto v0.0.17 - github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241111130906-b8e8592696c6 + github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241113074257-1a711ad38a6b go.vocdoni.io/dvote v1.10.2-0.20241024102542-c1ce6d744bc5 ) diff --git a/go.sum b/go.sum index 7054df7..49cd4e6 100644 --- a/go.sum +++ b/go.sum @@ -127,6 +127,8 @@ github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a h1:1ur3QoCqvE5f github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a/go.mod h1:RRCYJbIwD5jmqPI9XoAFR0OcDxqUctll6zUj/+B4S48= github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241111130906-b8e8592696c6 h1:Lnikgc2rZsnxZDwGbPhlsmq0yiLRotKDOGnuOOYU37o= github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241111130906-b8e8592696c6/go.mod h1:B43i83saYhSReG+jNAj0igxWcZYHGjF2AeXunaXnCQE= +github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241113074257-1a711ad38a6b h1:Hf7CZnNm8XhGBb0XBYVDUir7iNhIryJSmSUz+wlV7vw= +github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241113074257-1a711ad38a6b/go.mod h1:B43i83saYhSReG+jNAj0igxWcZYHGjF2AeXunaXnCQE= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= diff --git a/poseidon/poseidon.go b/poseidon/poseidon.go index 7b24b3d..51577a3 100644 --- a/poseidon/poseidon.go +++ b/poseidon/poseidon.go @@ -1,22 +1,89 @@ package poseidon import ( + "fmt" "math/big" "github.com/consensys/gnark/frontend" ) +// Poseidon struct represents a Poseidon hash function object that can be used +// to hash inputs. The Poseidon hash function is a cryptographic hash function +// that is designed to be efficient in terms of both time and space. It is +// based on the Merkle-Damgård construction and uses a sponge construction +// with a permutation that is based on the Poseidon permutation. The Poseidon +// permutation is a round-based permutation that uses a series of operations +// to mix the input data and produce the output hash. The Poseidon hash +// function is designed to be resistant to various cryptographic attacks, +// including differential and linear cryptanalysis, and is suitable for use +// in a wide range of applications, including blockchain and cryptocurrency +// systems. type Poseidon struct { api frontend.API data []frontend.Variable } -func Hash(api frontend.API, inputs ...frontend.Variable) frontend.Variable { +// Hash returns the hash of the provided inputs using the Poseidon hash +// function. This function supports up to 16 inputs. If more than 16 inputs +// are provided, it will return an error. If the number of inputs is 16 or less, +// it will return the hash of the inputs. This function is equivalent to calling +// NewPoseidon and then calling Write and Sum on the returned Poseidon object. +func Hash(api frontend.API, inputs ...frontend.Variable) (frontend.Variable, error) { h := NewPoseidon(api) - h.Write(inputs...) - return h.Sum() + if err := h.Write(inputs...); err != nil { + return 0, err + } + return h.Sum(), nil } +// MultiHash returns the hash of the provided inputs using the Poseidon hash +// function. This function supports up to 256 inputs. If more than 256 inputs +// are provided, it will return an error. If the number of inputs is 16 or +// less, it will return the result of Hash function. If the number of inputs +// is greater than 16, it will calculate the hash of the inputs by dividing +// them into chunks of 16 inputs each, hashing each chunk, and then hashing +// the resulting chunk hashes. +func MultiHash(api frontend.API, inputs ...frontend.Variable) (frontend.Variable, error) { + if l := len(inputs); l < 16 { + return Hash(api, inputs...) + } else if l > 256 { + return 0, fmt.Errorf("the maximum number of inputs supported is 256") + } + // calculate chunk hashes + hashed := []frontend.Variable{} + chunk := []frontend.Variable{} + hasher := NewPoseidon(api) + for _, input := range inputs { + if len(chunk) == 16 { + if err := hasher.Write(chunk...); err != nil { + return 0, err + } + hashed = append(hashed, hasher.Sum()) + chunk = []frontend.Variable{} + hasher.Reset() + } + chunk = append(chunk, input) + } + // if the final chunk is not empty, hash it to get the last chunk hash + if len(chunk) > 0 { + if err := hasher.Write(chunk...); err != nil { + return 0, err + } + hashed = append(hashed, hasher.Sum()) + hasher.Reset() + } + // if there is only one chunk, return its hash + if len(hashed) == 1 { + return hashed[0], nil + } + // return the hash of all chunk hashes + if err := hasher.Write(hashed...); err != nil { + return 0, err + } + return hasher.Sum(), nil +} + +// NewPoseidon returns a new Poseidon object that can be used to hash inputs. func NewPoseidon(api frontend.API) Poseidon { return Poseidon{ api: api, @@ -24,17 +91,25 @@ func NewPoseidon(api frontend.API) Poseidon { } } -func (h *Poseidon) Write(data ...frontend.Variable) { +// Write adds the provided inputs to the Poseidon object. If the number of +// inputs is greater than 16, it will return an error. +func (h *Poseidon) Write(data ...frontend.Variable) error { + if len(h.data)+len(data) > 16 { + return fmt.Errorf("poseidon hash only supports up to 16 inputs, use MultiHash instead") + } h.data = append(h.data, data...) + return nil } +// Reset resets the Poseidon object, removing all written inputs. func (h *Poseidon) Reset() { h.data = []frontend.Variable{} } +// Sum returns the hash of the inputs written to the Poseidon object. func (h *Poseidon) Sum() frontend.Variable { nInputs := len(h.data) - // And rounded up to nearest integer that divides by t + // and rounded up to nearest integer that divides by t nRoundsPC := [16]int{56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68} t := nInputs + 1 nRoundsF := 8 @@ -69,9 +144,7 @@ func (h *Poseidon) Sum() frontend.Variable { state = h.mix(state, p) for r := 0; r < nRoundsP; r++ { - state[0] = h.sigma(state[0]) - state[0] = h.api.Add(state[0], c[(nRoundsF/2+1)*t+r]) newState0 := frontend.Variable(0) for j := 0; j < len(state); j++ { diff --git a/poseidon/poseidon_test.go b/poseidon/poseidon_test.go index 2eb3612..c884734 100644 --- a/poseidon/poseidon_test.go +++ b/poseidon/poseidon_test.go @@ -1,14 +1,20 @@ package poseidon import ( + "crypto/rand" + "fmt" "math/big" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/profile" "github.com/consensys/gnark/test" hash "github.com/iden3/go-iden3-crypto/poseidon" + "github.com/vocdoni/vocdoni-z-sandbox/hash/poseidon" ) type testPoseidonCiruit struct { @@ -17,7 +23,10 @@ type testPoseidonCiruit struct { } func (circuit *testPoseidonCiruit) Define(api frontend.API) error { - h := Hash(api, circuit.Data) + h, err := Hash(api, circuit.Data) + if err != nil { + return err + } api.AssertIsEqual(h, circuit.Hash) return nil } @@ -31,5 +40,52 @@ func TestPoseidon(t *testing.T) { assignment.Data = input assignment.Hash, _ = hash.Hash([]*big.Int{input}) - assert.SolvingSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254), test.WithBackends(backend.PLONK)) + assert.SolvingSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16)) +} + +type testMultiPoseidonCircuit struct { + Data [32]frontend.Variable + Hash frontend.Variable `gnark:",public"` +} + +func (circuit *testMultiPoseidonCircuit) Define(api frontend.API) error { + h, err := MultiHash(api, circuit.Data[:]...) + if err != nil { + return err + } + api.AssertIsEqual(h, circuit.Hash) + return nil +} + +func TestMultiPoseidon(t *testing.T) { + p := profile.Start() + now := time.Now() + _, _ = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &testMultiPoseidonCircuit{}) + fmt.Println("elapsed", time.Since(now)) + p.Stop() + fmt.Println("constrains", p.NbConstraints()) + + var ( + inputs [32]*big.Int + data [32]frontend.Variable + ) + for i := 0; i < 32; i++ { + // generate random input + r, err := rand.Int(rand.Reader, ecc.BN254.ScalarField()) + if err != nil { + t.Fatal(err) + } + inputs[i] = r + data[i] = r + } + hash, err := poseidon.MultiPoseidon(inputs[:]...) + if err != nil { + t.Fatal(err) + } + witness := &testMultiPoseidonCircuit{ + Data: data, + Hash: hash, + } + assert := test.NewAssert(t) + assert.SolvingSucceeded(&testMultiPoseidonCircuit{}, witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16)) } diff --git a/smt/utils.go b/smt/utils.go index d4da178..46e3054 100644 --- a/smt/utils.go +++ b/smt/utils.go @@ -10,7 +10,7 @@ import ( // provided, hashing it with the predefined hashing function 'H': // // newLeafValue = H(key | value | 1) -func endLeafValue(api frontend.API, key, value frontend.Variable) frontend.Variable { +func endLeafValue(api frontend.API, key, value frontend.Variable) (frontend.Variable, error) { return poseidon.Hash(api, key, value, 1) } @@ -18,7 +18,7 @@ func endLeafValue(api frontend.API, key, value frontend.Variable) frontend.Varia // key-value pair provided, hashing it with the predefined hashing function 'H': // // intermediateLeafValue = H(l | r) -func intermediateLeafValue(api frontend.API, l, r frontend.Variable) frontend.Variable { +func intermediateLeafValue(api frontend.API, l, r frontend.Variable) (frontend.Variable, error) { return poseidon.Hash(api, l, r) } diff --git a/smt/verifier.go b/smt/verifier.go index e280b69..4ef173f 100644 --- a/smt/verifier.go +++ b/smt/verifier.go @@ -25,9 +25,15 @@ func smtverifier(api frontend.API, // [STEP 1] // hash1Old = H(oldKey | oldValue | 1) - hash1Old := endLeafValue(api, oldKey, oldValue) + hash1Old, err := endLeafValue(api, oldKey, oldValue) + if err != nil { + return err + } // hash1New = H(key | value | 1) - hash1New := endLeafValue(api, key, value) + hash1New, err := endLeafValue(api, key, value) + if err != nil { + return err + } // [STEP 2] // component n2bNew = Num2Bits_strict(); @@ -79,7 +85,7 @@ func smtverifier(api frontend.API, child = levels[i+1] } - levels[i] = smtVerifierLevel(api, + levels[i], err = smtVerifierLevel(api, sm[i][0], // st_top sm[i][1], // st_iold sm[i][3], // st_new @@ -89,6 +95,9 @@ func smtverifier(api frontend.API, n2bNew[i], // lrbit child, // child ) + if err != nil { + return err + } } // component areKeyEquals = IsEqual(); @@ -170,7 +179,7 @@ func smtVerifierSM(api frontend.API, } func smtVerifierLevel(api frontend.API, stTop, stIold, stInew, sibling, - old1leaf, new1leaf, lrbit, child frontend.Variable) frontend.Variable { + old1leaf, new1leaf, lrbit, child frontend.Variable) (frontend.Variable, error) { // component switcher = Switcher(); // switcher.sel <== lrbit; // switcher.L <== child; @@ -179,11 +188,14 @@ func smtVerifierLevel(api frontend.API, stTop, stIold, stInew, sibling, // component proofHash = SMTHash2(); // proofHash.L <== switcher.outL; // proofHash.R <== switcher.outR; - proofHash := intermediateLeafValue(api, l, r) + proofHash, err := intermediateLeafValue(api, l, r) + if err != nil { + return 0, err + } // aux[0] <== proofHash.out * st_top; aux0 := api.Mul(proofHash, stTop) // aux[1] <== old1leaf * st_iold; aux1 := api.Mul(old1leaf, stIold) // root <== aux[0] + aux[1] + new1leaf * st_inew; - return api.Add(aux0, aux1, api.Mul(new1leaf, stInew)) + return api.Add(aux0, aux1, api.Mul(new1leaf, stInew)), nil }