diff --git a/tree/arbo/circomproofs.go b/tree/arbo/circomproofs.go index 94a4aacbb..b43ea0059 100644 --- a/tree/arbo/circomproofs.go +++ b/tree/arbo/circomproofs.go @@ -1,7 +1,10 @@ package arbo import ( + "bytes" "encoding/json" + "fmt" + "slices" ) // CircomVerifierProof contains the needed data to check a Circom Verifier Proof @@ -89,3 +92,32 @@ func (t *Tree) GenerateCircomVerifierProof(k []byte) (*CircomVerifierProof, erro return &cp, nil } + +// CalculateProofNodes calculates the chain of hashes in the path of the proof. +// In the returned list, first item is the root, and last item is the hash of the leaf. +func (cvp CircomVerifierProof) CalculateProofNodes(hashFunc HashFunction) ([][]byte, error) { + paddedSiblings := slices.Clone(cvp.Siblings) + for k, v := range paddedSiblings { + if bytes.Equal(v, []byte{0}) { + paddedSiblings[k] = make([]byte, hashFunc.Len()) + } + } + packedSiblings, err := PackSiblings(hashFunc, paddedSiblings) + if err != nil { + return nil, err + } + return CalculateProofNodes(hashFunc, cvp.Key, cvp.Value, packedSiblings) +} + +// CheckProof verifies the given proof. The proof verification depends on the +// HashFunction passed as parameter. +func (cvp CircomVerifierProof) CheckProof(hashFunc HashFunction) (bool, error) { + hashes, err := cvp.CalculateProofNodes(hashFunc) + if err != nil { + return false, err + } + if !bytes.Equal(hashes[0], cvp.Root) { + return false, fmt.Errorf("calculated root doesn't match expected root") + } + return true, nil +} diff --git a/tree/arbo/proof.go b/tree/arbo/proof.go index ad19f671d..5fc3d116c 100644 --- a/tree/arbo/proof.go +++ b/tree/arbo/proof.go @@ -3,6 +3,7 @@ package arbo import ( "bytes" "encoding/binary" + "encoding/hex" "fmt" "math" "slices" @@ -160,19 +161,31 @@ func bytesToBitmap(b []byte) []bool { // CheckProof verifies the given proof. The proof verification depends on the // HashFunction passed as parameter. func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, error) { - siblings, err := UnpackSiblings(hashFunc, packedSiblings) + hashes, err := CalculateProofNodes(hashFunc, k, v, packedSiblings) if err != nil { return false, err } + return bytes.Equal(hashes[0], root), nil +} + +// CalculateProofNodes calculates the chain of hashes in the path of the given proof. +// In the returned list, first item is the root, and last item is the hash of the leaf. +func CalculateProofNodes(hashFunc HashFunction, k, v, packedSiblings []byte) ([][]byte, error) { + siblings, err := UnpackSiblings(hashFunc, packedSiblings) + if err != nil { + return nil, err + } keyPath := make([]byte, int(math.Ceil(float64(len(siblings))/float64(8)))) copy(keyPath, k) key, _, err := newLeafValue(hashFunc, k, v) if err != nil { - return false, err + return nil, err } + hashes := [][]byte{key} + path := getPath(len(siblings), keyPath) for i, sibling := range slices.Backward(siblings) { if path[i] { @@ -181,8 +194,58 @@ func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) (bool, key, _, err = newIntermediate(hashFunc, key, sibling) } if err != nil { - return false, err + return nil, err + } + hashes = append(hashes, key) + } + slices.Reverse(hashes) + return hashes, nil +} + +// CheckProofBatch verifies a batch of N proofs pairs (old and new). The proof verification depends on the +// HashFunction passed as parameter. +// Returns nil if the batch is valid, or an error otherwise. +func CheckProofBatch(hashFunc HashFunction, oldProofs, newProofs []*CircomVerifierProof) error { + newBranches := make(map[string]int) + newSiblings := make(map[string]int) + + if len(oldProofs) != len(newProofs) { + return fmt.Errorf("batch of proofs incomplete") + } + + for i := range oldProofs { + // Check all old proofs are valid + if valid, err := oldProofs[i].CheckProof(hashFunc); !valid { + return fmt.Errorf("old proof invalid: %w", err) + } + + // Map all new branches + nodes, err := newProofs[i].CalculateProofNodes(hashFunc) + if err != nil { + return fmt.Errorf("new proof invalid: %w", err) + } + // and check they are valid + if !bytes.Equal(newProofs[i].Root, nodes[0]) { + return fmt.Errorf("new proof invalid: root doesn't match") + } + + for level, hash := range nodes { + newBranches[hex.EncodeToString(hash)] = level + } + + for level := range newProofs[i].Siblings { + if !slices.Equal(oldProofs[i].Siblings[level], newProofs[i].Siblings[level]) { + // since in newBranch the root is level 0, we shift siblings to level + 1 + newSiblings[hex.EncodeToString(newProofs[i].Siblings[level])] = level + 1 + } } } - return bytes.Equal(key, root), nil + + for hash, level := range newSiblings { + if newBranches[hash] != newSiblings[hash] { + return fmt.Errorf("sibling %s (at level %d) changed but there's no proof why", hash, level) + } + } + + return nil } diff --git a/tree/arbo/proof_test.go b/tree/arbo/proof_test.go new file mode 100644 index 000000000..8fede0656 --- /dev/null +++ b/tree/arbo/proof_test.go @@ -0,0 +1,111 @@ +package arbo + +import ( + "math/big" + "slices" + "testing" + + qt "github.com/frankban/quicktest" + "go.vocdoni.io/dvote/db/metadb" +) + +func TestCheckProofBatch(t *testing.T) { + database := metadb.NewTest(t) + c := qt.New(t) + + keyLen := 1 + maxLevels := keyLen * 8 + tree, err := NewTree(Config{ + Database: database, MaxLevels: maxLevels, + HashFunction: HashFunctionBlake3, + }) + c.Assert(err, qt.IsNil) + + processID := []byte("01234567890123456789012345678900") + censusRoot := []byte("01234567890123456789012345678901") + ballotMode := []byte("1234") + + err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x00)), processID) + c.Assert(err, qt.IsNil) + + err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot) + c.Assert(err, qt.IsNil) + + err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode) + c.Assert(err, qt.IsNil) + + var oldProofs, newProofs []*CircomVerifierProof + + for i := int64(0x00); i <= int64(0x02); i++ { + proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i))) + c.Assert(err, qt.IsNil) + oldProofs = append(oldProofs, proof) + } + + censusRoot[0] = byte(0x02) + ballotMode[0] = byte(0x02) + + err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot) + c.Assert(err, qt.IsNil) + + err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode) + c.Assert(err, qt.IsNil) + + for i := int64(0x00); i <= int64(0x02); i++ { + proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i))) + c.Assert(err, qt.IsNil) + newProofs = append(newProofs, proof) + } + + // this mix should pass: proof 0 is unchanged, proof 1 + 2 verify together + err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs) + c.Assert(err, qt.IsNil) + + // omitting proof 0 (unchanged) should also pass + err = CheckProofBatch(HashFunctionBlake3, oldProofs[1:], newProofs[1:]) + c.Assert(err, qt.IsNil) + + // providing just proof 0 (unchanged) should also pass + err = CheckProofBatch(HashFunctionBlake3, oldProofs[:0], newProofs[:0]) + c.Assert(err, qt.IsNil) + + // length mismatch + err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs[:1]) + c.Assert(err, qt.ErrorMatches, "batch of proofs incomplete") + + // omitting proof 2 should fail (since changed siblings in proof 1 can't be explained) + err = CheckProofBatch(HashFunctionBlake3, oldProofs[:1], newProofs[:1]) + c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*") + + // the rest is mangling proofs to simulate other unexplained changes in the tree, all of these should fail + badProofs := deepClone(oldProofs) + badProofs[0].Root = []byte("01234567890123456789012345678900") + err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs) + c.Assert(err, qt.ErrorMatches, "old proof invalid: calculated root doesn't match expected root") + + badProofs = deepClone(oldProofs) + badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900") + err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs) + c.Assert(err, qt.ErrorMatches, "old proof invalid: calculated root doesn't match expected root") + + badProofs = deepClone(newProofs) + badProofs[0].Root = []byte("01234567890123456789012345678900") + err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs) + c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match") + + badProofs = deepClone(newProofs) + badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900") + err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs) + c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match") +} + +func deepClone(src []*CircomVerifierProof) []*CircomVerifierProof { + dst := slices.Clone(src) + for i := range src { + proof := *src[i] + dst[i] = &proof + + dst[i].Siblings = slices.Clone(src[i].Siblings) + } + return dst +}