diff --git a/tree/arbo/verifier.go b/tree/arbo/verifier.go index 8dc3afc..4f52b7a 100644 --- a/tree/arbo/verifier.go +++ b/tree/arbo/verifier.go @@ -5,14 +5,12 @@ import ( "github.com/vocdoni/gnark-crypto-primitives/utils" ) -type Hash func(frontend.API, ...frontend.Variable) (frontend.Variable, error) - // intermediateLeafKey function calculates the intermediate leaf key of the // path position provided. The leaf key is calculated by hashing the sibling // and the key provided. The position of the sibling and the key is decided by // the path position. If the current sibling is not valid, the method will // return the key provided. -func intermediateLeafKey(api frontend.API, hFn Hash, ipath, valid, key, sibling frontend.Variable) (frontend.Variable, error) { +func intermediateLeafKey(api frontend.API, hFn utils.Hasher, ipath, valid, key, sibling frontend.Variable) (frontend.Variable, error) { // l, r = path == 1 ? sibling, key : key, sibling l, r := api.Select(ipath, sibling, key), api.Select(ipath, key, sibling) // intermediateLeafKey = H(l | r) @@ -35,7 +33,7 @@ func isValid(api frontend.API, sibling, prevSibling, leaf, prevLeaf frontend.Var // CheckInclusionProof receives the parameters of an inclusion proof of Arbo to // recalculate the root with them and compare it with the provided one, // verifiying the proof. -func CheckInclusionProof(api frontend.API, hFn Hash, key, value, root frontend.Variable, +func CheckInclusionProof(api frontend.API, hFn utils.Hasher, key, value, root frontend.Variable, siblings []frontend.Variable, ) error { // calculate the path from the provided key to decide which leaf is the diff --git a/tree/smt/hash.go b/tree/smt/hash.go index 2fc896b..a189f0f 100644 --- a/tree/smt/hash.go +++ b/tree/smt/hash.go @@ -2,18 +2,26 @@ package smt import ( "github.com/consensys/gnark/frontend" - - "github.com/mdehoog/poseidon/circuits/poseidon" + "github.com/vocdoni/gnark-crypto-primitives/utils" ) // based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smthash_poseidon.circom -func Hash1(api frontend.API, key, value frontend.Variable) frontend.Variable { - inputs := []frontend.Variable{key, value, 1} - return poseidon.Hash(api, inputs) +func Hash1(api frontend.API, hFn utils.Hasher, key frontend.Variable, values ...frontend.Variable) frontend.Variable { + inputs := []frontend.Variable{key} + inputs = append(inputs, values...) + inputs = append(inputs, 1) + hash, err := hFn(api, inputs...) + if err != nil { + panic(err) + } + return hash } -func Hash2(api frontend.API, l, r frontend.Variable) frontend.Variable { - inputs := []frontend.Variable{l, r} - return poseidon.Hash(api, inputs) +func Hash2(api frontend.API, hFn utils.Hasher, l, r frontend.Variable) frontend.Variable { + hash, err := hFn(api, l, r) + if err != nil { + panic(err) + } + return hash } diff --git a/tree/smt/processor.go b/tree/smt/processor.go index e27c864..308cd75 100644 --- a/tree/smt/processor.go +++ b/tree/smt/processor.go @@ -2,15 +2,20 @@ package smt import ( "github.com/consensys/gnark/frontend" + "github.com/vocdoni/gnark-crypto-primitives/utils" ) // based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessor.circom -func Processor(api frontend.API, oldRoot frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, newKey, newValue, fnc0, fnc1 frontend.Variable) (newRoot frontend.Variable) { +func Processor(api frontend.API, hFn utils.Hasher, oldRoot frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, newKey, newValue, fnc0, fnc1 frontend.Variable) (newRoot frontend.Variable) { + hash1Old := Hash1(api, hFn, oldKey, oldValue) + hash1New := Hash1(api, hFn, newKey, newValue) + return ProcessorWithLeafHash(api, hFn, oldRoot, siblings, oldKey, hash1Old, isOld0, newKey, hash1New, fnc0, fnc1) +} + +func ProcessorWithLeafHash(api frontend.API, hFn utils.Hasher, oldRoot frontend.Variable, siblings []frontend.Variable, oldKey, hash1Old, isOld0, newKey, hash1New, fnc0, fnc1 frontend.Variable) (newRoot frontend.Variable) { levels := len(siblings) enabled := api.Sub(api.Add(fnc0, fnc1), api.Mul(fnc0, fnc1)) - hash1Old := Hash1(api, oldKey, oldValue) - hash1New := Hash1(api, newKey, newValue) n2bOld := api.ToBinary(oldKey, api.Compiler().FieldBitLen()) n2bNew := api.ToBinary(newKey, api.Compiler().FieldBitLen()) smtLevIns := LevIns(api, enabled, siblings) @@ -40,9 +45,9 @@ func Processor(api frontend.API, oldRoot frontend.Variable, siblings []frontend. levelsNewRoot := make([]frontend.Variable, levels) for i := levels - 1; i >= 0; i-- { if i == levels-1 { - levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0, 0) + levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, hFn, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0, 0) } else { - levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], levelsOldRoot[i+1], levelsNewRoot[i+1]) + levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, hFn, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], levelsOldRoot[i+1], levelsNewRoot[i+1]) } } diff --git a/tree/smt/processor_level.go b/tree/smt/processor_level.go index 46f8d0a..0723d5d 100644 --- a/tree/smt/processor_level.go +++ b/tree/smt/processor_level.go @@ -2,18 +2,19 @@ package smt import ( "github.com/consensys/gnark/frontend" + "github.com/vocdoni/gnark-crypto-primitives/utils" ) // based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessorlevel.circom -func ProcessorLevel(api frontend.API, stTop, stOld0, stBot, stNew1, stUpd, sibling, old1leaf, new1leaf, newlrbit, oldChild, newChild frontend.Variable) (oldRoot, newRoot frontend.Variable) { +func ProcessorLevel(api frontend.API, hFn utils.Hasher, stTop, stOld0, stBot, stNew1, stUpd, sibling, old1leaf, new1leaf, newlrbit, oldChild, newChild frontend.Variable) (oldRoot, newRoot frontend.Variable) { oldProofHashL, oldProofHashR := Switcher(api, newlrbit, oldChild, sibling) - oldProofHash := Hash2(api, oldProofHashL, oldProofHashR) + oldProofHash := Hash2(api, hFn, oldProofHashL, oldProofHashR) oldRoot = api.Add(api.Mul(old1leaf, api.Add(api.Add(stBot, stNew1), stUpd)), api.Mul(oldProofHash, stTop)) newProofHashL, newProofHashR := Switcher(api, newlrbit, api.Add(api.Mul(newChild, api.Add(stTop, stBot)), api.Mul(new1leaf, stNew1)), api.Add(api.Mul(sibling, stTop), api.Mul(old1leaf, stNew1))) - newProofHash := Hash2(api, newProofHashL, newProofHashR) + newProofHash := Hash2(api, hFn, newProofHashL, newProofHashR) newRoot = api.Add(api.Mul(newProofHash, api.Add(api.Add(stTop, stBot), stNew1)), api.Mul(new1leaf, api.Add(stOld0, stUpd))) return diff --git a/tree/smt/verifier.go b/tree/smt/verifier.go index ce88232..180b5b9 100644 --- a/tree/smt/verifier.go +++ b/tree/smt/verifier.go @@ -2,20 +2,21 @@ package smt import ( "github.com/consensys/gnark/frontend" + "github.com/vocdoni/gnark-crypto-primitives/utils" ) -func InclusionVerifier(api frontend.API, root frontend.Variable, siblings []frontend.Variable, key, value frontend.Variable) { - Verifier(api, 1, root, siblings, key, value, 0, key, value, 0) +func InclusionVerifier(api frontend.API, hFn utils.Hasher, root frontend.Variable, siblings []frontend.Variable, key, value frontend.Variable) { + Verifier(api, hFn, 1, root, siblings, key, value, 0, key, value, 0) } -func ExclusionVerifier(api frontend.API, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key frontend.Variable) { - Verifier(api, 1, root, siblings, oldKey, oldValue, isOld0, key, 0, 1) +func ExclusionVerifier(api frontend.API, hFn utils.Hasher, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key frontend.Variable) { + Verifier(api, hFn, 1, root, siblings, oldKey, oldValue, isOld0, key, 0, 1) } -func Verifier(api frontend.API, enabled, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key, value, fnc frontend.Variable) { +func Verifier(api frontend.API, hFn utils.Hasher, enabled, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key, value, fnc frontend.Variable) { nLevels := len(siblings) - hash1Old := Hash1(api, oldKey, oldValue) - hash1New := Hash1(api, key, value) + hash1Old := Hash1(api, hFn, oldKey, oldValue) + hash1New := Hash1(api, hFn, key, value) n2bNew := api.ToBinary(key, api.Compiler().FieldBitLen()) smtLevIns := LevIns(api, enabled, siblings) @@ -36,9 +37,9 @@ func Verifier(api frontend.API, enabled, root frontend.Variable, siblings []fron levels := make([]frontend.Variable, nLevels) for i := nLevels - 1; i >= 0; i-- { if i == nLevels-1 { - levels[i] = VerifierLevel(api, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0) + levels[i] = VerifierLevel(api, hFn, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0) } else { - levels[i] = VerifierLevel(api, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], levels[i+1]) + levels[i] = VerifierLevel(api, hFn, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], levels[i+1]) } } diff --git a/tree/smt/verifier_level.go b/tree/smt/verifier_level.go index b68cc8c..bbd8b1a 100644 --- a/tree/smt/verifier_level.go +++ b/tree/smt/verifier_level.go @@ -2,11 +2,12 @@ package smt import ( "github.com/consensys/gnark/frontend" + "github.com/vocdoni/gnark-crypto-primitives/utils" ) -func VerifierLevel(api frontend.API, stTop, stIOld, stINew, sibling, old1leaf, new1leaf, lrbit, child frontend.Variable) (root frontend.Variable) { +func VerifierLevel(api frontend.API, hFn utils.Hasher, stTop, stIOld, stINew, sibling, old1leaf, new1leaf, lrbit, child frontend.Variable) (root frontend.Variable) { proofHashL, proofHashR := Switcher(api, lrbit, child, sibling) - proofHash := Hash2(api, proofHashL, proofHashR) + proofHash := Hash2(api, hFn, proofHashL, proofHashR) root = api.Add(api.Add(api.Mul(proofHash, stTop), api.Mul(old1leaf, stIOld)), api.Mul(new1leaf, stINew)) return } diff --git a/utils/hashers.go b/utils/hashers.go new file mode 100644 index 0000000..12a0919 --- /dev/null +++ b/utils/hashers.go @@ -0,0 +1,20 @@ +package utils + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/mimc" +) + +type Hasher func(frontend.API, ...frontend.Variable) (frontend.Variable, error) + +// MiMCHasher is a hash function that hashes the data provided using the +// mimc hash function and the current compiler field. It is used to hash the +// leaves of the census tree during the proof verification. +func MiMCHasher(api frontend.API, data ...frontend.Variable) (frontend.Variable, error) { + h, err := mimc.NewMiMC(api) + if err != nil { + return 0, err + } + h.Write(data...) + return h.Sum(), nil +}