Skip to content

Commit

Permalink
multiposeidon circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasmenendez committed Nov 13, 2024
1 parent 32c548f commit 19e0c1a
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 31 deletions.
19 changes: 14 additions & 5 deletions arbo/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ import (
// prevLevel function calculates the previous level of the merkle tree given the
// current leaf, the current path bit of the leaf, the validity of the sibling
// and the sibling itself.
func prevLevel(api frontend.API, leaf, ipath, valid, sibling frontend.Variable) frontend.Variable {
func prevLevel(api frontend.API, leaf, ipath, valid, sibling frontend.Variable) (frontend.Variable, error) {
// l, r = path == 1 ? sibling, current : current, sibling
l, r := api.Select(ipath, sibling, leaf), api.Select(ipath, leaf, sibling)
// intermediateLeafKey = H(l | r)
intermediateLeafKey := poseidon.Hash(api, l, r)
intermediateLeafKey, err := poseidon.Hash(api, l, r)
if err != nil {
return 0, err
}
// newCurrent = valid == 1 ? current : intermediateLeafKey
return api.Select(valid, intermediateLeafKey, leaf)
return api.Select(valid, intermediateLeafKey, leaf), nil
}

// strictCmp function compares a and b and returns:
Expand All @@ -41,7 +44,10 @@ func CheckProof(api frontend.API, key, value, root frontend.Variable, siblings [
path := api.ToBinary(key, api.Compiler().FieldBitLen())
// calculate the value leaf to start with it to rebuild the tree
// leafValue = H(key | value | 1)
leafValue := poseidon.Hash(api, key, value, 1)
leafValue, err := poseidon.Hash(api, key, value, 1)
if err != nil {
return err
}
// calculate the root and compare it with the provided one
prevLeaf := leafValue
currentLeaf := leafValue
Expand All @@ -52,7 +58,10 @@ func CheckProof(api frontend.API, key, value, root frontend.Variable, siblings [
prevLeaf = currentLeaf
prevSibling = siblings[i]
// compute the next leaf value
currentLeaf = prevLevel(api, currentLeaf, path[i], valid, siblings[i])
currentLeaf, err = prevLevel(api, currentLeaf, path[i], valid, siblings[i])
if err != nil {
return err
}
}
api.AssertIsEqual(currentLeaf, root)
return nil
Expand Down
16 changes: 8 additions & 8 deletions arbo/verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import (
)

type testVerifierCircuit struct {
Root frontend.Variable
Key frontend.Variable
Value frontend.Variable
Siblings [160]frontend.Variable
Root frontend.Variable
Key frontend.Variable
Value frontend.Variable
Siblings [160]frontend.Variable
}

func (circuit *testVerifierCircuit) Define(api frontend.API) error {
Expand Down Expand Up @@ -74,10 +74,10 @@ func successInputs(t *testing.T, n int) testVerifierCircuit {
root, err := tree.Root()
c.Assert(err, qt.IsNil)
return testVerifierCircuit{
Root: arbo.BytesLEToBigInt(root),
Key: arbo.BytesLEToBigInt(key),
Value: value,
Siblings: siblings,
Root: arbo.BytesLEToBigInt(root),
Key: arbo.BytesLEToBigInt(key),
Value: value,
Siblings: siblings,
}
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/consensys/gnark-crypto v0.14.0
github.com/frankban/quicktest v1.14.6
github.com/iden3/go-iden3-crypto v0.0.17
github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241111130906-b8e8592696c6
github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241113074257-1a711ad38a6b
go.vocdoni.io/dvote v1.10.2-0.20241024102542-c1ce6d744bc5
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a h1:1ur3QoCqvE5f
github.com/syndtr/goleveldb v1.0.1-0.20220614013038-64ee5596c38a/go.mod h1:RRCYJbIwD5jmqPI9XoAFR0OcDxqUctll6zUj/+B4S48=
github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241111130906-b8e8592696c6 h1:Lnikgc2rZsnxZDwGbPhlsmq0yiLRotKDOGnuOOYU37o=
github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241111130906-b8e8592696c6/go.mod h1:B43i83saYhSReG+jNAj0igxWcZYHGjF2AeXunaXnCQE=
github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241113074257-1a711ad38a6b h1:Hf7CZnNm8XhGBb0XBYVDUir7iNhIryJSmSUz+wlV7vw=
github.com/vocdoni/vocdoni-z-sandbox v0.0.0-20241113074257-1a711ad38a6b/go.mod h1:B43i83saYhSReG+jNAj0igxWcZYHGjF2AeXunaXnCQE=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
Expand Down
87 changes: 80 additions & 7 deletions poseidon/poseidon.go
Original file line number Diff line number Diff line change
@@ -1,40 +1,115 @@
package poseidon

import (
"fmt"
"math/big"

"github.com/consensys/gnark/frontend"
)

// Poseidon struct represents a Poseidon hash function object that can be used
// to hash inputs. The Poseidon hash function is a cryptographic hash function
// that is designed to be efficient in terms of both time and space. It is
// based on the Merkle-Damgård construction and uses a sponge construction
// with a permutation that is based on the Poseidon permutation. The Poseidon
// permutation is a round-based permutation that uses a series of operations
// to mix the input data and produce the output hash. The Poseidon hash
// function is designed to be resistant to various cryptographic attacks,
// including differential and linear cryptanalysis, and is suitable for use
// in a wide range of applications, including blockchain and cryptocurrency
// systems.
type Poseidon struct {
api frontend.API
data []frontend.Variable
}

func Hash(api frontend.API, inputs ...frontend.Variable) frontend.Variable {
// Hash returns the hash of the provided inputs using the Poseidon hash
// function. This function supports up to 16 inputs. If more than 16 inputs
// are provided, it will return an error. If the number of inputs is 16 or less,
// it will return the hash of the inputs. This function is equivalent to calling
// NewPoseidon and then calling Write and Sum on the returned Poseidon object.
func Hash(api frontend.API, inputs ...frontend.Variable) (frontend.Variable, error) {
h := NewPoseidon(api)
h.Write(inputs...)
return h.Sum()
if err := h.Write(inputs...); err != nil {
return 0, err
}
return h.Sum(), nil
}

// MultiHash returns the hash of the provided inputs using the Poseidon hash
// function. This function supports up to 256 inputs. If more than 256 inputs
// are provided, it will return an error. If the number of inputs is 16 or
// less, it will return the result of Hash function. If the number of inputs
// is greater than 16, it will calculate the hash of the inputs by dividing
// them into chunks of 16 inputs each, hashing each chunk, and then hashing
// the resulting chunk hashes.
func MultiHash(api frontend.API, inputs ...frontend.Variable) (frontend.Variable, error) {
if l := len(inputs); l < 16 {
return Hash(api, inputs...)
} else if l > 256 {
return 0, fmt.Errorf("the maximum number of inputs supported is 256")
}
// calculate chunk hashes
hashed := []frontend.Variable{}
chunk := []frontend.Variable{}
hasher := NewPoseidon(api)
for _, input := range inputs {
if len(chunk) == 16 {
if err := hasher.Write(chunk...); err != nil {
return 0, err
}
hashed = append(hashed, hasher.Sum())
chunk = []frontend.Variable{}
hasher.Reset()
}
chunk = append(chunk, input)
}
// if the final chunk is not empty, hash it to get the last chunk hash
if len(chunk) > 0 {
if err := hasher.Write(chunk...); err != nil {
return 0, err
}
hashed = append(hashed, hasher.Sum())
hasher.Reset()
}
// if there is only one chunk, return its hash
if len(hashed) == 1 {
return hashed[0], nil
}
// return the hash of all chunk hashes
if err := hasher.Write(hashed...); err != nil {
return 0, err
}
return hasher.Sum(), nil
}

// NewPoseidon returns a new Poseidon object that can be used to hash inputs.
func NewPoseidon(api frontend.API) Poseidon {
return Poseidon{
api: api,
data: []frontend.Variable{},
}
}

func (h *Poseidon) Write(data ...frontend.Variable) {
// Write adds the provided inputs to the Poseidon object. If the number of
// inputs is greater than 16, it will return an error.
func (h *Poseidon) Write(data ...frontend.Variable) error {
if len(h.data)+len(data) > 16 {
return fmt.Errorf("poseidon hash only supports up to 16 inputs, use MultiHash instead")
}
h.data = append(h.data, data...)
return nil
}

// Reset resets the Poseidon object, removing all written inputs.
func (h *Poseidon) Reset() {
h.data = []frontend.Variable{}
}

// Sum returns the hash of the inputs written to the Poseidon object.
func (h *Poseidon) Sum() frontend.Variable {
nInputs := len(h.data)
// And rounded up to nearest integer that divides by t
// and rounded up to nearest integer that divides by t
nRoundsPC := [16]int{56, 57, 56, 60, 60, 63, 64, 63, 60, 66, 60, 65, 70, 60, 64, 68}
t := nInputs + 1
nRoundsF := 8
Expand Down Expand Up @@ -69,9 +144,7 @@ func (h *Poseidon) Sum() frontend.Variable {
state = h.mix(state, p)

for r := 0; r < nRoundsP; r++ {

state[0] = h.sigma(state[0])

state[0] = h.api.Add(state[0], c[(nRoundsF/2+1)*t+r])
newState0 := frontend.Variable(0)
for j := 0; j < len(state); j++ {
Expand Down
60 changes: 58 additions & 2 deletions poseidon/poseidon_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package poseidon

import (
"crypto/rand"
"fmt"
"math/big"
"testing"
"time"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/profile"
"github.com/consensys/gnark/test"
hash "github.com/iden3/go-iden3-crypto/poseidon"
"github.com/vocdoni/vocdoni-z-sandbox/hash/poseidon"
)

type testPoseidonCiruit struct {
Expand All @@ -17,7 +23,10 @@ type testPoseidonCiruit struct {
}

func (circuit *testPoseidonCiruit) Define(api frontend.API) error {
h := Hash(api, circuit.Data)
h, err := Hash(api, circuit.Data)
if err != nil {
return err
}
api.AssertIsEqual(h, circuit.Hash)
return nil
}
Expand All @@ -31,5 +40,52 @@ func TestPoseidon(t *testing.T) {
assignment.Data = input
assignment.Hash, _ = hash.Hash([]*big.Int{input})

assert.SolvingSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254), test.WithBackends(backend.PLONK))
assert.SolvingSucceeded(&circuit, &assignment, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16))
}

type testMultiPoseidonCircuit struct {
Data [32]frontend.Variable
Hash frontend.Variable `gnark:",public"`
}

func (circuit *testMultiPoseidonCircuit) Define(api frontend.API) error {
h, err := MultiHash(api, circuit.Data[:]...)
if err != nil {
return err
}
api.AssertIsEqual(h, circuit.Hash)
return nil
}

func TestMultiPoseidon(t *testing.T) {
p := profile.Start()
now := time.Now()
_, _ = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &testMultiPoseidonCircuit{})
fmt.Println("elapsed", time.Since(now))
p.Stop()
fmt.Println("constrains", p.NbConstraints())

var (
inputs [32]*big.Int
data [32]frontend.Variable
)
for i := 0; i < 32; i++ {
// generate random input
r, err := rand.Int(rand.Reader, ecc.BN254.ScalarField())
if err != nil {
t.Fatal(err)
}
inputs[i] = r
data[i] = r
}
hash, err := poseidon.MultiPoseidon(inputs[:]...)
if err != nil {
t.Fatal(err)
}
witness := &testMultiPoseidonCircuit{
Data: data,
Hash: hash,
}
assert := test.NewAssert(t)
assert.SolvingSucceeded(&testMultiPoseidonCircuit{}, witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16))
}
4 changes: 2 additions & 2 deletions smt/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ import (
// provided, hashing it with the predefined hashing function 'H':
//
// newLeafValue = H(key | value | 1)
func endLeafValue(api frontend.API, key, value frontend.Variable) frontend.Variable {
func endLeafValue(api frontend.API, key, value frontend.Variable) (frontend.Variable, error) {
return poseidon.Hash(api, key, value, 1)
}

// intermediateLeafValue returns the encoded intermediate leaf value for the
// key-value pair provided, hashing it with the predefined hashing function 'H':
//
// intermediateLeafValue = H(l | r)
func intermediateLeafValue(api frontend.API, l, r frontend.Variable) frontend.Variable {
func intermediateLeafValue(api frontend.API, l, r frontend.Variable) (frontend.Variable, error) {
return poseidon.Hash(api, l, r)
}

Expand Down
24 changes: 18 additions & 6 deletions smt/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,15 @@ func smtverifier(api frontend.API,

// [STEP 1]
// hash1Old = H(oldKey | oldValue | 1)
hash1Old := endLeafValue(api, oldKey, oldValue)
hash1Old, err := endLeafValue(api, oldKey, oldValue)
if err != nil {
return err
}
// hash1New = H(key | value | 1)
hash1New := endLeafValue(api, key, value)
hash1New, err := endLeafValue(api, key, value)
if err != nil {
return err
}

// [STEP 2]
// component n2bNew = Num2Bits_strict();
Expand Down Expand Up @@ -79,7 +85,7 @@ func smtverifier(api frontend.API,
child = levels[i+1]
}

levels[i] = smtVerifierLevel(api,
levels[i], err = smtVerifierLevel(api,
sm[i][0], // st_top
sm[i][1], // st_iold
sm[i][3], // st_new
Expand All @@ -89,6 +95,9 @@ func smtverifier(api frontend.API,
n2bNew[i], // lrbit
child, // child
)
if err != nil {
return err
}
}

// component areKeyEquals = IsEqual();
Expand Down Expand Up @@ -170,7 +179,7 @@ func smtVerifierSM(api frontend.API,
}

func smtVerifierLevel(api frontend.API, stTop, stIold, stInew, sibling,
old1leaf, new1leaf, lrbit, child frontend.Variable) frontend.Variable {
old1leaf, new1leaf, lrbit, child frontend.Variable) (frontend.Variable, error) {
// component switcher = Switcher();
// switcher.sel <== lrbit;
// switcher.L <== child;
Expand All @@ -179,11 +188,14 @@ func smtVerifierLevel(api frontend.API, stTop, stIold, stInew, sibling,
// component proofHash = SMTHash2();
// proofHash.L <== switcher.outL;
// proofHash.R <== switcher.outR;
proofHash := intermediateLeafValue(api, l, r)
proofHash, err := intermediateLeafValue(api, l, r)
if err != nil {
return 0, err
}
// aux[0] <== proofHash.out * st_top;
aux0 := api.Mul(proofHash, stTop)
// aux[1] <== old1leaf * st_iold;
aux1 := api.Mul(old1leaf, stIold)
// root <== aux[0] + aux[1] + new1leaf * st_inew;
return api.Add(aux0, aux1, api.Mul(new1leaf, stInew))
return api.Add(aux0, aux1, api.Mul(new1leaf, stInew)), nil
}

0 comments on commit 19e0c1a

Please sign in to comment.