Skip to content

Commit

Permalink
arbo: add CheckProofBatch and CalculateProofNodes
Browse files Browse the repository at this point in the history
  • Loading branch information
altergui committed Nov 5, 2024
1 parent 98fff29 commit 56b51eb
Show file tree
Hide file tree
Showing 3 changed files with 275 additions and 10 deletions.
33 changes: 33 additions & 0 deletions tree/arbo/circomproofs.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package arbo

import (
"bytes"
"encoding/json"
"fmt"
"slices"
)

// CircomVerifierProof contains the needed data to check a Circom Verifier Proof
Expand Down Expand Up @@ -89,3 +92,33 @@ func (t *Tree) GenerateCircomVerifierProof(k []byte) (*CircomVerifierProof, erro

return &cp, nil
}

// CalculateProofNodes calculates the chain of hashes in the path of the proof.
// In the returned list, first item is the root, and last item is the hash of the leaf.
func (cvp CircomVerifierProof) CalculateProofNodes(hashFunc HashFunction) ([][]byte, error) {
paddedSiblings := slices.Clone(cvp.Siblings)
for k, v := range paddedSiblings {
if bytes.Equal(v, []byte{0}) {
paddedSiblings[k] = make([]byte, hashFunc.Len())
}
}
packedSiblings, err := PackSiblings(hashFunc, paddedSiblings)
if err != nil {
return nil, err
}
return CalculateProofNodes(hashFunc, cvp.Key, cvp.Value, packedSiblings, cvp.OldKey, (cvp.Fnc == 1))
}

// CheckProof verifies the given proof. The proof verification depends on the
// HashFunction passed as parameter.
// Returns nil if the proof is valid, or an error otherwise.
func (cvp CircomVerifierProof) CheckProof(hashFunc HashFunction) error {
hashes, err := cvp.CalculateProofNodes(hashFunc)
if err != nil {
return err
}
if !bytes.Equal(hashes[0], cvp.Root) {
return fmt.Errorf("calculated vs expected root mismatch")
}
return nil
}
102 changes: 92 additions & 10 deletions tree/arbo/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package arbo
import (
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
"slices"
Expand Down Expand Up @@ -161,32 +162,113 @@ func bytesToBitmap(b []byte) []bool {
// HashFunction passed as parameter.
// Returns nil if the proof is valid, or an error otherwise.
func CheckProof(hashFunc HashFunction, k, v, root, packedSiblings []byte) error {
siblings, err := UnpackSiblings(hashFunc, packedSiblings)
hashes, err := CalculateProofNodes(hashFunc, k, v, packedSiblings, nil, false)
if err != nil {
return err
}
if !bytes.Equal(hashes[0], root) {
return fmt.Errorf("calculated vs expected root mismatch")
}
return nil
}

// CalculateProofNodes calculates the chain of hashes in the path of the given proof.
// In the returned list, first item is the root, and last item is the hash of the leaf.
func CalculateProofNodes(hashFunc HashFunction, k, v, packedSiblings, oldKey []byte, exclusion bool) ([][]byte, error) {
siblings, err := UnpackSiblings(hashFunc, packedSiblings)
if err != nil {
return nil, err
}

keyPath := make([]byte, int(math.Ceil(float64(len(siblings))/float64(8))))
copy(keyPath, k)
path := getPath(len(siblings), keyPath)

key, _, err := newLeafValue(hashFunc, k, v)
if err != nil {
return err
key := slices.Clone(k)

if exclusion {
if slices.Equal(k, oldKey) {
return nil, fmt.Errorf("exclusion proof invalid, key and oldKey are equal")
}
// we'll prove the path to the existing key (passed as oldKey)
key = slices.Clone(oldKey)
}

path := getPath(len(siblings), keyPath)
hash, _, err := newLeafValue(hashFunc, key, v)
if err != nil {
return nil, err
}
hashes := [][]byte{hash}
for i, sibling := range slices.Backward(siblings) {
if path[i] {
key, _, err = newIntermediate(hashFunc, sibling, key)
hash, _, err = newIntermediate(hashFunc, sibling, hash)
} else {
key, _, err = newIntermediate(hashFunc, key, sibling)
hash, _, err = newIntermediate(hashFunc, hash, sibling)
}
if err != nil {
return err
return nil, err
}
hashes = append(hashes, hash)
}
if !bytes.Equal(key, root) {
return fmt.Errorf("calculated vs expected root mismatch")
slices.Reverse(hashes)
return hashes, nil
}

// CheckProofBatch verifies a batch of N proofs pairs (old and new). The proof verification depends on the
// HashFunction passed as parameter.
// Returns nil if the batch is valid, or an error otherwise.
//
// TODO: doesn't support removing leaves (newProofs can only update or add new leaves)
func CheckProofBatch(hashFunc HashFunction, oldProofs, newProofs []*CircomVerifierProof) error {
newBranches := make(map[string]int)
newSiblings := make(map[string]int)

if len(oldProofs) != len(newProofs) {
return fmt.Errorf("batch of proofs incomplete")
}

if len(oldProofs) == 0 {
return fmt.Errorf("empty batch")
}

for i := range oldProofs {
// Map all old branches
oldNodes, err := oldProofs[i].CalculateProofNodes(hashFunc)
if err != nil {
return fmt.Errorf("old proof invalid: %w", err)
}
// and check they are valid
if !bytes.Equal(oldProofs[i].Root, oldNodes[0]) {
return fmt.Errorf("old proof invalid: root doesn't match")
}

// Map all new branches
newNodes, err := newProofs[i].CalculateProofNodes(hashFunc)
if err != nil {
return fmt.Errorf("new proof invalid: %w", err)
}
// and check they are valid
if !bytes.Equal(newProofs[i].Root, newNodes[0]) {
return fmt.Errorf("new proof invalid: root doesn't match")
}

for level, hash := range newNodes {
newBranches[hex.EncodeToString(hash)] = level
}

for level := range newProofs[i].Siblings {
if !slices.Equal(oldProofs[i].Siblings[level], newProofs[i].Siblings[level]) {
// since in newBranch the root is level 0, we shift siblings to level + 1
newSiblings[hex.EncodeToString(newProofs[i].Siblings[level])] = level + 1
}
}
}

for hash, level := range newSiblings {
if newBranches[hash] != newSiblings[hash] {
return fmt.Errorf("sibling %s (at level %d) changed but there's no proof why", hash, level)
}
}

return nil
}
150 changes: 150 additions & 0 deletions tree/arbo/proof_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package arbo

import (
"math/big"
"slices"
"testing"

qt "github.com/frankban/quicktest"
"go.vocdoni.io/dvote/db/metadb"
)

func TestCheckProofBatch(t *testing.T) {
database := metadb.NewTest(t)
c := qt.New(t)

keyLen := 1
maxLevels := keyLen * 8
tree, err := NewTree(Config{
Database: database, MaxLevels: maxLevels,
HashFunction: HashFunctionBlake3,
})
c.Assert(err, qt.IsNil)

censusRoot := []byte("01234567890123456789012345678901")
ballotMode := []byte("1234")

err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot)
c.Assert(err, qt.IsNil)

err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode)
c.Assert(err, qt.IsNil)

var oldProofs, newProofs []*CircomVerifierProof

for i := int64(0x00); i <= int64(0x04); i++ {
proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i)))
c.Assert(err, qt.IsNil)
oldProofs = append(oldProofs, proof)
}

censusRoot[0] = byte(0x02)
ballotMode[0] = byte(0x02)

err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x01)), censusRoot)
c.Assert(err, qt.IsNil)

err = tree.Update(BigIntToBytesLE(keyLen, big.NewInt(0x02)), ballotMode)
c.Assert(err, qt.IsNil)

err = tree.Add(BigIntToBytesLE(keyLen, big.NewInt(0x03)), ballotMode)
c.Assert(err, qt.IsNil)

for i := int64(0x00); i <= int64(0x04); i++ {
proof, err := tree.GenerateCircomVerifierProof(BigIntToBytesLE(keyLen, big.NewInt(i)))
c.Assert(err, qt.IsNil)
newProofs = append(newProofs, proof)
}

// passing all proofs should be OK:
// proof 1 + 2 + 3 are required
// proof 0 and 4 are of unchanged keys, but the new siblings are explained by the other proofs
err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs)
c.Assert(err, qt.IsNil)

// omitting proof 0 and 4 (unchanged keys) should also be OK
err = CheckProofBatch(HashFunctionBlake3, oldProofs[1:4], newProofs[1:4])
c.Assert(err, qt.IsNil)

// providing an empty batch should not pass
err = CheckProofBatch(HashFunctionBlake3, []*CircomVerifierProof{}, []*CircomVerifierProof{})
c.Assert(err, qt.ErrorMatches, "empty batch")

// length mismatch
err = CheckProofBatch(HashFunctionBlake3, oldProofs, newProofs[:1])
c.Assert(err, qt.ErrorMatches, "batch of proofs incomplete")

// providing just proof 0 (unchanged key) should not pass since siblings can't be explained
err = CheckProofBatch(HashFunctionBlake3, oldProofs[:1], newProofs[:1])
c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*")

// providing just proof 0 (unchanged key) and an add, should fail
err = CheckProofBatch(HashFunctionBlake3, oldProofs[:1], newProofs[3:4])
c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*")

// omitting proof 3 should fail (since changed siblings in other proofs can't be explained)
err = CheckProofBatch(HashFunctionBlake3, oldProofs[:3], newProofs[:3])
c.Assert(err, qt.ErrorMatches, ".*changed but there's no proof why.*")

// the next 4 are mangling proofs to simulate other unexplained changes in the tree, all of these should fail
badProofs := deepClone(oldProofs)
badProofs[0].Root = []byte("01234567890123456789012345678900")
err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs)
c.Assert(err, qt.ErrorMatches, "old proof invalid: root doesn't match")

badProofs = deepClone(oldProofs)
badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900")
err = CheckProofBatch(HashFunctionBlake3, badProofs, newProofs)
c.Assert(err, qt.ErrorMatches, "old proof invalid: root doesn't match")

badProofs = deepClone(newProofs)
badProofs[0].Root = []byte("01234567890123456789012345678900")
err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs)
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")

badProofs = deepClone(newProofs)
badProofs[0].Siblings[0] = []byte("01234567890123456789012345678900")
err = CheckProofBatch(HashFunctionBlake3, oldProofs, badProofs)
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")

// also test exclusion proofs:
// exclusion proof of key 0x04 can't be used to prove exclusion of 0x01, 0x03 or 0x05 obviously
badProofs = deepClone(oldProofs)
badProofs[4].Key = []byte{0x01}
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")
badProofs[4].Key = []byte{0x03}
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")
badProofs[4].Key = []byte{0x05}
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
c.Assert(err, qt.ErrorMatches, "new proof invalid: root doesn't match")
// also can't prove key 0x02 exclusion (since that leaf exists and is indeed the starting point of the proof)
badProofs[4].Key = []byte{0x02}
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
c.Assert(err, qt.ErrorMatches, "new proof invalid: exclusion proof invalid, key and oldKey are equal")
// but exclusion proof of key 0x04 can also prove exclusion of the whole prefix (0x00, 0x08, 0x0c, 0x10, etc)
badProofs[4].Key = []byte{0x00}
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
c.Assert(err, qt.IsNil)
badProofs[4].Key = []byte{0x08}
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
c.Assert(err, qt.IsNil)
badProofs[4].Key = []byte{0x0c}
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
c.Assert(err, qt.IsNil)
badProofs[4].Key = []byte{0x10}
err = CheckProofBatch(HashFunctionBlake3, oldProofs[4:], badProofs[4:])
c.Assert(err, qt.IsNil)
}

func deepClone(src []*CircomVerifierProof) []*CircomVerifierProof {
dst := slices.Clone(src)
for i := range src {
proof := *src[i]
dst[i] = &proof

dst[i].Siblings = slices.Clone(src[i].Siblings)
}
return dst
}

0 comments on commit 56b51eb

Please sign in to comment.