From 56b51eb9ac917adcfb249f36c865551c78fbe15d Mon Sep 17 00:00:00 2001 From: Gui Iribarren Date: Tue, 29 Oct 2024 18:43:22 +0100 Subject: [PATCH] arbo: add CheckProofBatch and CalculateProofNodes --- tree/arbo/circomproofs.go | 33 +++++++++ tree/arbo/proof.go | 102 +++++++++++++++++++++++--- tree/arbo/proof_test.go | 150 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 275 insertions(+), 10 deletions(-) create mode 100644 tree/arbo/proof_test.go diff --git a/tree/arbo/circomproofs.go b/tree/arbo/circomproofs.go index 94a4aacbb..afb0797d3 100644 --- a/tree/arbo/circomproofs.go +++ b/tree/arbo/circomproofs.go @@ -1,7 +1,10 @@ package arbo import ( + "bytes" "encoding/json" + "fmt" + "slices" ) // CircomVerifierProof contains the needed data to check a Circom Verifier Proof @@ -89,3 +92,33 @@ func (t *Tree) GenerateCircomVerifierProof(k []byte) (*CircomVerifierProof, erro return &cp, nil } + +// CalculateProofNodes calculates the chain of hashes in the path of the proof. +// In the returned list, first item is the root, and last item is the hash of the leaf. +func (cvp CircomVerifierProof) CalculateProofNodes(hashFunc HashFunction) ([][]byte, error) { + paddedSiblings := slices.Clone(cvp.Siblings) + for k, v := range paddedSiblings { + if bytes.Equal(v, []byte{0}) { + paddedSiblings[k] = make([]byte, hashFunc.Len()) + } + } + packedSiblings, err := PackSiblings(hashFunc, paddedSiblings) + if err != nil { + return nil, err + } + return CalculateProofNodes(hashFunc, cvp.Key, cvp.Value, packedSiblings, cvp.OldKey, (cvp.Fnc == 1)) +} + +// CheckProof verifies the given proof. The proof verification depends on the +// HashFunction passed as parameter. +// Returns nil if the proof is valid, or an error otherwise. +func (cvp CircomVerifierProof) CheckProof(hashFunc HashFunction) error { + hashes, err := cvp.CalculateProofNodes(hashFunc) + if err != nil { + return err + } + if !bytes.Equal(hashes[0], cvp.Root) { + return fmt.Errorf("calculated vs expected root mismatch") + } + return nil +} diff --git a/tree/arbo/proof.go b/tree/arbo/proof.go index 701c4385f..4cdfce7c4 100644 --- a/tree/arbo/proof.go +++ b/tree/arbo/proof.go @@ -3,6 +3,7 @@ package arbo import ( "bytes" "encoding/binary" + "encoding/hex" "fmt" "math" "slices" @@ -161,32 +162,113 @@ func bytesToBitmap(b []byte) []bool { // HashFunction passed as parameter. // Returns nil if the proof is valid, or an error otherwise. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) error { - siblings, err := UnpackSiblings(hashFunc, packedSiblings) + hashes, err := CalculateProofNodes(hashFunc, k, v, packedSiblings, nil, false) if err != nil { return err } + if !bytes.Equal(hashes[0], root) { + return fmt.Errorf("calculated vs expected root mismatch") + } + return nil +} + +// CalculateProofNodes calculates the chain of hashes in the path of the given proof. +// In the returned list, first item is the root, and last item is the hash of the leaf. +func CalculateProofNodes(hashFunc HashFunction, k, v, packedSiblings, oldKey []byte, exclusion bool) ([][]byte, error) { + siblings, err := UnpackSiblings(hashFunc, packedSiblings) + if err != nil { + return nil, err + } keyPath := make([]byte, int(math.Ceil(float64(len(siblings))/float64(8)))) copy(keyPath, k) + path := getPath(len(siblings), keyPath) - key, _, err := newLeafValue(hashFunc, k, v) - if err != nil { - return err + key := slices.Clone(k) + + if exclusion { + if slices.Equal(k, oldKey) { + return nil, fmt.Errorf("exclusion proof invalid, key and oldKey are equal") + } + // we'll prove the path to the existing key (passed as oldKey) + key = slices.Clone(oldKey) } - path := getPath(len(siblings), keyPath) + hash, _, err := newLeafValue(hashFunc, key, v) + if err != nil { + return nil, err + } + hashes := [][]byte{hash} for i, sibling := range slices.Backward(siblings) { if path[i] { - key, _, err = newIntermediate(hashFunc, sibling, key) + hash, _, err = newIntermediate(hashFunc, sibling, hash) } else { - key, _, err = newIntermediate(hashFunc, key, sibling) + hash, _, err = newIntermediate(hashFunc, hash, sibling) } if err != nil { - return err + return nil, err } + hashes = append(hashes, hash) } - if !bytes.Equal(key, root) { - return fmt.Errorf("calculated vs expected root mismatch") + slices.Reverse(hashes) + return hashes, nil +} + +// CheckProofBatch verifies a batch of N proofs pairs (old and new). The proof verification depends on the +// HashFunction passed as parameter. +// Returns nil if the batch is valid, or an error otherwise. +// +// TODO: doesn't support removing leaves (newProofs can only update or add new leaves) +func CheckProofBatch(hashFunc HashFunction, oldProofs, newProofs []*CircomVerifierProof) error { + newBranches := make(map[string]int) + newSiblings := make(map[string]int) + + if len(oldProofs) != len(newProofs) { + return fmt.Errorf("batch of proofs incomplete") + } + + if len(oldProofs) == 0 { + return fmt.Errorf("empty batch") + } + + for i := range oldProofs { + // Map all old branches + oldNodes, err := oldProofs[i].CalculateProofNodes(hashFunc) + if err != nil { + return fmt.Errorf("old proof invalid: %w", err) + } + // and check they are valid + if !bytes.Equal(oldProofs[i].Root, oldNodes[0]) { + return fmt.Errorf("old proof invalid: root doesn't match") + } + + // Map all new branches + newNodes, err := newProofs[i].CalculateProofNodes(hashFunc) + if err != nil { + return fmt.Errorf("new proof invalid: %w", err) + } + // and check they are valid + if !bytes.Equal(newProofs[i].Root, newNodes[0]) { + return fmt.Errorf("new proof invalid: root doesn't match") + } + + for level, hash := range newNodes { + newBranches[hex.EncodeToString(hash)] = level + } + + for level := range newProofs[i].Siblings { + if !slices.Equal(oldProofs[i].Siblings[level], newProofs[i].Siblings[level]) { + // since in newBranch the root is level 0, we shift siblings to level + 1 + newSiblings[hex.EncodeToString(newProofs[i].Siblings[level])] = level + 1 + } + } } + + for hash, level := range newSiblings { + if newBranches[hash] != newSiblings[hash] { + return fmt.Errorf("sibling %s (at level %d) changed but there's no proof why", hash, level) + } + } + return nil } diff --git a/tree/arbo/proof_test.go b/tree/arbo/proof_test.go new file mode 100644 index 000000000..d095b6dac --- /dev/null +++ b/tree/arbo/proof_test.go @@ -0,0 +1,150 @@ +package arbo + +import ( + "math/big" + "slices" + "testing" + + qt "github.com/frankban/quicktest" + "go.vocdoni.io/dvote/db/metadb" +) + +func TestCheckProofBatch(t *testing.T) { + database := metadb.NewTest(t) + c := qt.New(t) + + keyLen := 1 + maxLevels := keyLen * 8 + tree, err := NewTree(Config{ + Database: database, MaxLevels: maxLevels, + HashFunction: HashFunctionBlake3, + }) + c.Assert(err, qt.IsNil) + + censusRoot := []byte("01234567890123456789012345678901") + ballotMode := []byte("1234") + + err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot) + c.Assert(err, qt.IsNil) + + err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode) + c.Assert(err, qt.IsNil) + + var oldProofs, newProofs []*CircomVerifierProof + + for i := int64(0x00); i <= int64(0x04); i++ { + proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i))) + c.Assert(err, qt.IsNil) + oldProofs = append(oldProofs, proof) + } + + censusRoot[0] = byte(0x02) + ballotMode[0] = byte(0x02) + + err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot) + c.Assert(err, qt.IsNil) + + err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode) + c.Assert(err, qt.IsNil) + + err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x03)), ballotMode) + c.Assert(err, qt.IsNil) + + for i := int64(0x00); i <= int64(0x04); i++ { + proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i))) + c.Assert(err, qt.IsNil) + newProofs = append(newProofs, proof) + } + + // passing all proofs should be OK: + // proof 1 + 2 + 3 are required + // proof 0 and 4 are of unchanged keys, but the new siblings are explained by the other proofs + err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs) + c.Assert(err, qt.IsNil) + + // omitting proof 0 and 4 (unchanged keys) should also be OK + err = CheckProofBatch(HashFunctionBlake3, oldProofs[1:4], newProofs[1:4]) + c.Assert(err, qt.IsNil) + + // providing an empty batch should not pass + err = CheckProofBatch(HashFunctionBlake3, []*CircomVerifierProof{}, []*CircomVerifierProof{}) + c.Assert(err, qt.ErrorMatches, "empty batch") + + // length mismatch + err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs[:1]) + c.Assert(err, qt.ErrorMatches, "batch of proofs incomplete") + + // providing just proof 0 (unchanged key) should not pass since siblings can't be explained + err = CheckProofBatch(HashFunctionBlake3, oldProofs[:1], newProofs[:1]) + c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*") + + // providing just proof 0 (unchanged key) and an add, should fail + err = CheckProofBatch(HashFunctionBlake3, oldProofs[:1], newProofs[3:4]) + c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*") + + // omitting proof 3 should fail (since changed siblings in other proofs can't be explained) + err = CheckProofBatch(HashFunctionBlake3, oldProofs[:3], newProofs[:3]) + c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*") + + // the next 4 are mangling proofs to simulate other unexplained changes in the tree, all of these should fail + badProofs := deepClone(oldProofs) + badProofs[0].Root = []byte("01234567890123456789012345678900") + err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs) + c.Assert(err, qt.ErrorMatches, "old proof invalid: root doesn't match") + + badProofs = deepClone(oldProofs) + badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900") + err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs) + c.Assert(err, qt.ErrorMatches, "old proof invalid: root doesn't match") + + badProofs = deepClone(newProofs) + badProofs[0].Root = []byte("01234567890123456789012345678900") + err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs) + c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match") + + badProofs = deepClone(newProofs) + badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900") + err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs) + c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match") + + // also test exclusion proofs: + // exclusion proof of key 0x04 can't be used to prove exclusion of 0x01, 0x03 or 0x05 obviously + badProofs = deepClone(oldProofs) + badProofs[4].Key = []byte{0x01} + err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:]) + c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match") + badProofs[4].Key = []byte{0x03} + err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:]) + c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match") + badProofs[4].Key = []byte{0x05} + err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:]) + c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match") + // also can't prove key 0x02 exclusion (since that leaf exists and is indeed the starting point of the proof) + badProofs[4].Key = []byte{0x02} + err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:]) + c.Assert(err, qt.ErrorMatches, "new proof invalid: exclusion proof invalid, key and oldKey are equal") + // but exclusion proof of key 0x04 can also prove exclusion of the whole prefix (0x00, 0x08, 0x0c, 0x10, etc) + badProofs[4].Key = []byte{0x00} + err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:]) + c.Assert(err, qt.IsNil) + badProofs[4].Key = []byte{0x08} + err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:]) + c.Assert(err, qt.IsNil) + badProofs[4].Key = []byte{0x0c} + err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:]) + c.Assert(err, qt.IsNil) + badProofs[4].Key = []byte{0x10} + err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:]) + c.Assert(err, qt.IsNil) +} + +func deepClone(src []*CircomVerifierProof) []*CircomVerifierProof { + dst := slices.Clone(src) + for i := range src { + proof := *src[i] + dst[i] = &proof + + dst[i].Siblings = slices.Clone(src[i].Siblings) + } + return dst +}