Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

homomorphic: big refactor, renamed homomorphic -> elgamal #5

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions elgamal/ciphertext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package elgamal

import (
ecc_tweds "github.com/consensys/gnark-crypto/ecc/twistededwards"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/algebra/native/twistededwards"
"github.com/iden3/go-iden3-crypto/babyjub"
)

type Ciphertext struct {
C1, C2 twistededwards.Point
}

func NewCiphertext() *Ciphertext {
zero := babyjub.NewPoint()
return &Ciphertext{C1: twistededwards.Point{X: zero.X, Y: zero.Y}, C2: twistededwards.Point{X: zero.X, Y: zero.Y}}
}

// Add sets z to the sum x+y and returns z.
//
// Panics if twistededwards curve init fails.
func (z *Ciphertext) Add(api frontend.API, x, y *Ciphertext) *Ciphertext {
curve, err := twistededwards.NewEdCurve(api, ecc_tweds.BN254)
if err != nil {
panic(err)
}
for _, p := range []twistededwards.Point{x.C1, x.C2, y.C1, y.C2} {
curve.AssertIsOnCurve(p)
}
z.C1 = curve.Add(x.C1, y.C1)
z.C2 = curve.Add(x.C2, y.C2)
return z
}

// AssertIsEqual fails if any of the fields differ between z and x
func (z *Ciphertext) AssertIsEqual(api frontend.API, x *Ciphertext) {
api.AssertIsEqual(z.C1.X, z.C1.X)
api.AssertIsEqual(z.C1.Y, x.C1.Y)
api.AssertIsEqual(z.C2.X, x.C2.X)
api.AssertIsEqual(z.C2.Y, x.C2.Y)
}

// Select if b is true, sets z = i1, else z = i2, and returns z
func (z *Ciphertext) Select(api frontend.API, b frontend.Variable, i1 *Ciphertext, i2 *Ciphertext) *Ciphertext {
z.C1.X = api.Select(b, i1.C1.X, i2.C1.X)
z.C1.Y = api.Select(b, i1.C1.Y, i2.C1.Y)
z.C2.X = api.Select(b, i1.C2.X, i2.C2.X)
z.C2.Y = api.Select(b, i1.C2.Y, i2.C2.Y)
return z
}

// Serialize returns a slice with the C1.X, C1.Y, C2.X, C2.Y in order
func (z *Ciphertext) Serialize() []frontend.Variable {
return []frontend.Variable{
z.C1.X,
z.C1.Y,
z.C2.X,
z.C2.Y,
}
}
89 changes: 41 additions & 48 deletions homomorphic/add_test.go → elgamal/ciphertext_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package hadd
package elgamal

import (
"crypto/rand"
Expand All @@ -18,34 +18,21 @@ import (
"github.com/vocdoni/vocdoni-z-sandbox/ecc/format"
)

type testHomomorphicAddCircuit struct {
A1 twistededwards.Point `gnark:"a1,public"`
A2 twistededwards.Point `gnark:"a2,public"`
B1 twistededwards.Point `gnark:"b1,public"`
B2 twistededwards.Point `gnark:"b2,public"`
C1 twistededwards.Point `gnark:"c1,public"`
C2 twistededwards.Point `gnark:"c2,public"`
type testElGamalAddCircuit struct {
A Ciphertext `gnark:",public"`
B Ciphertext `gnark:",public"`
Sum Ciphertext `gnark:",public"`
}

func (c *testHomomorphicAddCircuit) Define(api frontend.API) error {
// calculate and check c1
c1, err := Add(api, c.A1, c.B1)
if err != nil {
return err
}
api.AssertIsEqual(c.C1.X, c1.X)
api.AssertIsEqual(c.C1.Y, c1.Y)
// calculate and check c2
c2, err := Add(api, c.A2, c.B2)
if err != nil {
return err
}
api.AssertIsEqual(c.C2.X, c2.X)
api.AssertIsEqual(c.C2.Y, c2.Y)
func (c *testElGamalAddCircuit) Define(api frontend.API) error {
// calculate and check sum
sum := &Ciphertext{}
sum.Add(api, &c.A, &c.B)
sum.AssertIsEqual(api, &c.Sum)
return nil
}

func TestHomomorphicAdd(t *testing.T) {
func TestElGamalAdd(t *testing.T) {
// generate a public mocked key and a random k to encrypt first message
_, pubKey := generateKeyPair()
k1, err := randomK()
Expand Down Expand Up @@ -80,40 +67,46 @@ func TestHomomorphicAdd(t *testing.T) {
// profiling the circuit compilation
p := profile.Start()
now := time.Now()
_, _ = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &testHomomorphicAddCircuit{})
_, _ = frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &testElGamalAddCircuit{})
fmt.Println("elapsed", time.Since(now))
p.Stop()
fmt.Println("constrains", p.NbConstraints())
// run the test to prove the homomorphic property
assert := test.NewAssert(t)
inputs := &testHomomorphicAddCircuit{
A1: twistededwards.Point{
X: xA1RTE,
Y: yA1RTE,
},
A2: twistededwards.Point{
X: xA2RTE,
Y: yA2RTE,
},
B1: twistededwards.Point{
X: xB1RTE,
Y: yB1RTE,
},
B2: twistededwards.Point{
X: xB2RTE,
Y: yB2RTE,
inputs := &testElGamalAddCircuit{
A: Ciphertext{
C1: twistededwards.Point{
X: xA1RTE,
Y: yA1RTE,
},
C2: twistededwards.Point{
X: xA2RTE,
Y: yA2RTE,
},
},
C1: twistededwards.Point{
X: xC1RTE,
Y: yC1RTE,
B: Ciphertext{
C1: twistededwards.Point{
X: xB1RTE,
Y: yB1RTE,
},
C2: twistededwards.Point{
X: xB2RTE,
Y: yB2RTE,
},
},
C2: twistededwards.Point{
X: xC2RTE,
Y: yC2RTE,
Sum: Ciphertext{
C1: twistededwards.Point{
X: xC1RTE,
Y: yC1RTE,
},
C2: twistededwards.Point{
X: xC2RTE,
Y: yC2RTE,
},
},
}
now = time.Now()
assert.SolvingSucceeded(&testHomomorphicAddCircuit{}, inputs, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16))
assert.SolvingSucceeded(&testElGamalAddCircuit{}, inputs, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16))
fmt.Println("elapsed", time.Since(now))
}

Expand Down
17 changes: 0 additions & 17 deletions homomorphic/add.go

This file was deleted.

6 changes: 2 additions & 4 deletions tree/arbo/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import (
"github.com/vocdoni/gnark-crypto-primitives/utils"
)

type Hash func(frontend.API, ...frontend.Variable) (frontend.Variable, error)

// intermediateLeafKey function calculates the intermediate leaf key of the
// path position provided. The leaf key is calculated by hashing the sibling
// and the key provided. The position of the sibling and the key is decided by
// the path position. If the current sibling is not valid, the method will
// return the key provided.
func intermediateLeafKey(api frontend.API, hFn Hash, ipath, valid, key, sibling frontend.Variable) (frontend.Variable, error) {
func intermediateLeafKey(api frontend.API, hFn utils.Hasher, ipath, valid, key, sibling frontend.Variable) (frontend.Variable, error) {
// l, r = path == 1 ? sibling, key : key, sibling
l, r := api.Select(ipath, sibling, key), api.Select(ipath, key, sibling)
// intermediateLeafKey = H(l | r)
Expand All @@ -35,7 +33,7 @@ func isValid(api frontend.API, sibling, prevSibling, leaf, prevLeaf frontend.Var
// CheckInclusionProof receives the parameters of an inclusion proof of Arbo to
// recalculate the root with them and compare it with the provided one,
// verifiying the proof.
func CheckInclusionProof(api frontend.API, hFn Hash, key, value, root frontend.Variable,
func CheckInclusionProof(api frontend.API, hFn utils.Hasher, key, value, root frontend.Variable,
siblings []frontend.Variable,
) error {
// calculate the path from the provided key to decide which leaf is the
Expand Down
24 changes: 16 additions & 8 deletions tree/smt/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@ package smt

import (
"github.com/consensys/gnark/frontend"

"github.com/mdehoog/poseidon/circuits/poseidon"
"github.com/vocdoni/gnark-crypto-primitives/utils"
)

// based on https://github.com/iden3/circomlib/blob/master/circuits/smt/smthash_poseidon.circom

func Hash1(api frontend.API, key, value frontend.Variable) frontend.Variable {
inputs := []frontend.Variable{key, value, 1}
return poseidon.Hash(api, inputs)
func Hash1(api frontend.API, hFn utils.Hasher, key frontend.Variable, values ...frontend.Variable) frontend.Variable {
inputs := []frontend.Variable{key}
inputs = append(inputs, values...)
inputs = append(inputs, 1)
hash, err := hFn(api, inputs...)
if err != nil {
panic(err)
}
return hash
}

func Hash2(api frontend.API, l, r frontend.Variable) frontend.Variable {
inputs := []frontend.Variable{l, r}
return poseidon.Hash(api, inputs)
func Hash2(api frontend.API, hFn utils.Hasher, l, r frontend.Variable) frontend.Variable {
hash, err := hFn(api, l, r)
if err != nil {
panic(err)
}
return hash
}
15 changes: 10 additions & 5 deletions tree/smt/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@ package smt

import (
"github.com/consensys/gnark/frontend"
"github.com/vocdoni/gnark-crypto-primitives/utils"
)

// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessor.circom

func Processor(api frontend.API, oldRoot frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, newKey, newValue, fnc0, fnc1 frontend.Variable) (newRoot frontend.Variable) {
func Processor(api frontend.API, hFn utils.Hasher, oldRoot frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, newKey, newValue, fnc0, fnc1 frontend.Variable) (newRoot frontend.Variable) {
hash1Old := Hash1(api, hFn, oldKey, oldValue)
hash1New := Hash1(api, hFn, newKey, newValue)
return ProcessorWithLeafHash(api, hFn, oldRoot, siblings, oldKey, hash1Old, isOld0, newKey, hash1New, fnc0, fnc1)
}

func ProcessorWithLeafHash(api frontend.API, hFn utils.Hasher, oldRoot frontend.Variable, siblings []frontend.Variable, oldKey, hash1Old, isOld0, newKey, hash1New, fnc0, fnc1 frontend.Variable) (newRoot frontend.Variable) {
levels := len(siblings)
enabled := api.Sub(api.Add(fnc0, fnc1), api.Mul(fnc0, fnc1))
hash1Old := Hash1(api, oldKey, oldValue)
hash1New := Hash1(api, newKey, newValue)
n2bOld := api.ToBinary(oldKey, api.Compiler().FieldBitLen())
n2bNew := api.ToBinary(newKey, api.Compiler().FieldBitLen())
smtLevIns := LevIns(api, enabled, siblings)
Expand Down Expand Up @@ -40,9 +45,9 @@ func Processor(api frontend.API, oldRoot frontend.Variable, siblings []frontend.
levelsNewRoot := make([]frontend.Variable, levels)
for i := levels - 1; i >= 0; i-- {
if i == levels-1 {
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0, 0)
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, hFn, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0, 0)
} else {
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], levelsOldRoot[i+1], levelsNewRoot[i+1])
levelsOldRoot[i], levelsNewRoot[i] = ProcessorLevel(api, hFn, stTop[i], stOld0[i], stBot[i], stNew1[i], stUpd[i], siblings[i], hash1Old, hash1New, n2bNew[i], levelsOldRoot[i+1], levelsNewRoot[i+1])
}
}

Expand Down
7 changes: 4 additions & 3 deletions tree/smt/processor_level.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@ package smt

import (
"github.com/consensys/gnark/frontend"
"github.com/vocdoni/gnark-crypto-primitives/utils"
)

// based on https://github.com/iden3/circomlib/blob/cff5ab6288b55ef23602221694a6a38a0239dcc0/circuits/smt/smtprocessorlevel.circom

func ProcessorLevel(api frontend.API, stTop, stOld0, stBot, stNew1, stUpd, sibling, old1leaf, new1leaf, newlrbit, oldChild, newChild frontend.Variable) (oldRoot, newRoot frontend.Variable) {
func ProcessorLevel(api frontend.API, hFn utils.Hasher, stTop, stOld0, stBot, stNew1, stUpd, sibling, old1leaf, new1leaf, newlrbit, oldChild, newChild frontend.Variable) (oldRoot, newRoot frontend.Variable) {
oldProofHashL, oldProofHashR := Switcher(api, newlrbit, oldChild, sibling)
oldProofHash := Hash2(api, oldProofHashL, oldProofHashR)
oldProofHash := Hash2(api, hFn, oldProofHashL, oldProofHashR)

oldRoot = api.Add(api.Mul(old1leaf, api.Add(api.Add(stBot, stNew1), stUpd)), api.Mul(oldProofHash, stTop))

newProofHashL, newProofHashR := Switcher(api, newlrbit, api.Add(api.Mul(newChild, api.Add(stTop, stBot)), api.Mul(new1leaf, stNew1)), api.Add(api.Mul(sibling, stTop), api.Mul(old1leaf, stNew1)))
newProofHash := Hash2(api, newProofHashL, newProofHashR)
newProofHash := Hash2(api, hFn, newProofHashL, newProofHashR)

newRoot = api.Add(api.Mul(newProofHash, api.Add(api.Add(stTop, stBot), stNew1)), api.Mul(new1leaf, api.Add(stOld0, stUpd)))
return
Expand Down
19 changes: 10 additions & 9 deletions tree/smt/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@ package smt

import (
"github.com/consensys/gnark/frontend"
"github.com/vocdoni/gnark-crypto-primitives/utils"
)

func InclusionVerifier(api frontend.API, root frontend.Variable, siblings []frontend.Variable, key, value frontend.Variable) {
Verifier(api, 1, root, siblings, key, value, 0, key, value, 0)
func InclusionVerifier(api frontend.API, hFn utils.Hasher, root frontend.Variable, siblings []frontend.Variable, key, value frontend.Variable) {
Verifier(api, hFn, 1, root, siblings, key, value, 0, key, value, 0)
}

func ExclusionVerifier(api frontend.API, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key frontend.Variable) {
Verifier(api, 1, root, siblings, oldKey, oldValue, isOld0, key, 0, 1)
func ExclusionVerifier(api frontend.API, hFn utils.Hasher, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key frontend.Variable) {
Verifier(api, hFn, 1, root, siblings, oldKey, oldValue, isOld0, key, 0, 1)
}

func Verifier(api frontend.API, enabled, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key, value, fnc frontend.Variable) {
func Verifier(api frontend.API, hFn utils.Hasher, enabled, root frontend.Variable, siblings []frontend.Variable, oldKey, oldValue, isOld0, key, value, fnc frontend.Variable) {
nLevels := len(siblings)
hash1Old := Hash1(api, oldKey, oldValue)
hash1New := Hash1(api, key, value)
hash1Old := Hash1(api, hFn, oldKey, oldValue)
hash1New := Hash1(api, hFn, key, value)
n2bNew := api.ToBinary(key, api.Compiler().FieldBitLen())
smtLevIns := LevIns(api, enabled, siblings)

Expand All @@ -36,9 +37,9 @@ func Verifier(api frontend.API, enabled, root frontend.Variable, siblings []fron
levels := make([]frontend.Variable, nLevels)
for i := nLevels - 1; i >= 0; i-- {
if i == nLevels-1 {
levels[i] = VerifierLevel(api, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0)
levels[i] = VerifierLevel(api, hFn, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], 0)
} else {
levels[i] = VerifierLevel(api, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], levels[i+1])
levels[i] = VerifierLevel(api, hFn, stTop[i], stIOld[i], stINew[i], siblings[i], hash1Old, hash1New, n2bNew[i], levels[i+1])
}
}

Expand Down
5 changes: 3 additions & 2 deletions tree/smt/verifier_level.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package smt

import (
"github.com/consensys/gnark/frontend"
"github.com/vocdoni/gnark-crypto-primitives/utils"
)

func VerifierLevel(api frontend.API, stTop, stIOld, stINew, sibling, old1leaf, new1leaf, lrbit, child frontend.Variable) (root frontend.Variable) {
func VerifierLevel(api frontend.API, hFn utils.Hasher, stTop, stIOld, stINew, sibling, old1leaf, new1leaf, lrbit, child frontend.Variable) (root frontend.Variable) {
proofHashL, proofHashR := Switcher(api, lrbit, child, sibling)
proofHash := Hash2(api, proofHashL, proofHashR)
proofHash := Hash2(api, hFn, proofHashL, proofHashR)
root = api.Add(api.Add(api.Mul(proofHash, stTop), api.Mul(old1leaf, stIOld)), api.Mul(new1leaf, stINew))
return
}
20 changes: 20 additions & 0 deletions utils/hashers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package utils

import (
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash/mimc"
)

type Hasher func(frontend.API, ...frontend.Variable) (frontend.Variable, error)

// MiMCHasher is a hash function that hashes the data provided using the
// mimc hash function and the current compiler field. It is used to hash the
// leaves of the census tree during the proof verification.
func MiMCHasher(api frontend.API, data ...frontend.Variable) (frontend.Variable, error) {
h, err := mimc.NewMiMC(api)
if err != nil {
return 0, err
}
h.Write(data...)
return h.Sum(), nil
}
Loading