From 7a27f880727ee82a71e3e7a0e7c715a98432175e Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Tue, 25 Feb 2025 15:39:41 -0600 Subject: [PATCH 01/62] build: update gnark-crypto dep --- constraint/bls12-377/gkr.go | 3 +-- go.mod | 2 +- go.sum | 4 ++-- std/permutation/poseidon2/gkr.go | 4 +++- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/constraint/bls12-377/gkr.go b/constraint/bls12-377/gkr.go index 948ed1510d..8c93902065 100644 --- a/constraint/bls12-377/gkr.go +++ b/constraint/bls12-377/gkr.go @@ -29,9 +29,8 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) - var found bool for i := range noPtr { - if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/go.mod b/go.mod index 8296d75b3d..de44b4615b 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.29 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.16.1-0.20250217214835-5ed804970f85 + github.com/consensys/gnark-crypto v0.16.1-0.20250225213100-17b431f9839c github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index c31ddf141c..df92850436 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.1.29 h1:fobxIYksIQ+ZSrTJUuQgu+HIJwclrAPcdXqd7H2hh github.com/consensys/bavard v0.1.29/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.16.1-0.20250217214835-5ed804970f85 h1:3ht4gGH3smFGVLFhpFTKvDbEdagC6eSaPXnHjCQGh94= -github.com/consensys/gnark-crypto v0.16.1-0.20250217214835-5ed804970f85/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= +github.com/consensys/gnark-crypto v0.16.1-0.20250225213100-17b431f9839c h1:HJ5kUkkD8MKn6X8J3n4FZPVolGH4xfg44eZEmE2pSz0= +github.com/consensys/gnark-crypto v0.16.1-0.20250225213100-17b431f9839c/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/std/permutation/poseidon2/gkr.go b/std/permutation/poseidon2/gkr.go index f52d750bf0..eb23a3653e 100644 --- a/std/permutation/poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr.go @@ -390,7 +390,9 @@ func RegisterGkrSolverOptions(curves ...ecc.ID) { csBls12377.RegisterHashBuilder("mimc", func() hash.Hash { return mimcBls12377.NewMiMC() }) - gkrPoseidon2Bls12377.RegisterGkrGates() + if err := gkrPoseidon2Bls12377.RegisterGkrGates(); err != nil { + panic(err) + } default: panic(fmt.Sprintf("curve %s not currently supported", curve)) } From be478af8db6895f4608d8c8d6b23054c626b2ae1 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Tue, 25 Feb 2025 16:09:04 -0600 Subject: [PATCH 02/62] fix: use registry in testing.go --- constraint/bls12-381/gkr.go | 3 +- constraint/bls24-315/gkr.go | 3 +- constraint/bls24-317/gkr.go | 3 +- constraint/bn254/gkr.go | 3 +- constraint/bw6-633/gkr.go | 3 +- constraint/bw6-761/gkr.go | 3 +- .../template/representations/gkr.go.tmpl | 3 +- std/gkr/testing.go | 130 +++++++++--------- 8 files changed, 69 insertions(+), 82 deletions(-) diff --git a/constraint/bls12-381/gkr.go b/constraint/bls12-381/gkr.go index acf57d9d88..fa81371379 100644 --- a/constraint/bls12-381/gkr.go +++ b/constraint/bls12-381/gkr.go @@ -29,9 +29,8 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) - var found bool for i := range noPtr { - if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bls24-315/gkr.go b/constraint/bls24-315/gkr.go index e39d7447c7..6a018868c1 100644 --- a/constraint/bls24-315/gkr.go +++ b/constraint/bls24-315/gkr.go @@ -29,9 +29,8 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) - var found bool for i := range noPtr { - if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bls24-317/gkr.go b/constraint/bls24-317/gkr.go index 76d080a489..346b397d48 100644 --- a/constraint/bls24-317/gkr.go +++ b/constraint/bls24-317/gkr.go @@ -29,9 +29,8 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) - var found bool for i := range noPtr { - if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bn254/gkr.go b/constraint/bn254/gkr.go index 88dd7905ab..fcf064b696 100644 --- a/constraint/bn254/gkr.go +++ b/constraint/bn254/gkr.go @@ -29,9 +29,8 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) - var found bool for i := range noPtr { - if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bw6-633/gkr.go b/constraint/bw6-633/gkr.go index 81fbe4c52c..2e7d58eff8 100644 --- a/constraint/bw6-633/gkr.go +++ b/constraint/bw6-633/gkr.go @@ -29,9 +29,8 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) - var found bool for i := range noPtr { - if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bw6-761/gkr.go b/constraint/bw6-761/gkr.go index 0066d302c7..35dafd570f 100644 --- a/constraint/bw6-761/gkr.go +++ b/constraint/bw6-761/gkr.go @@ -29,9 +29,8 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) - var found bool for i := range noPtr { - if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/internal/generator/backend/template/representations/gkr.go.tmpl b/internal/generator/backend/template/representations/gkr.go.tmpl index 9030cc43a7..5d788bd570 100644 --- a/internal/generator/backend/template/representations/gkr.go.tmpl +++ b/internal/generator/backend/template/representations/gkr.go.tmpl @@ -22,9 +22,8 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) - var found bool for i := range noPtr { - if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/std/gkr/testing.go b/std/gkr/testing.go index 50111a60e4..c5464c5041 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math/big" + "sync" "github.com/consensys/gnark-crypto/ecc" frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -29,35 +30,13 @@ import ( // This method only works under the test engine and should only be called to debug a GKR circuit, as the GKR prover's errors can be obscure. func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable { res := make([][]frontend.Variable, len(api.toStore.Circuit)) - degreeTestedGates := make(map[string]struct{}) + var degreeTestedGates sync.Map for i, w := range api.toStore.Circuit { res[i] = make([]frontend.Variable, api.nbInstances()) copy(res[i], api.assignments[i]) if len(w.Inputs) == 0 { continue } - degree := Gates[w.Gate].Degree() - var degreeFr int - if parentApi.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) == 0 { - degreeFr = gkrBls12377.Gates[w.Gate].Degree() - } else if parentApi.Compiler().Field().Cmp(ecc.BN254.ScalarField()) == 0 { - degreeFr = gkrBn254.Gates[w.Gate].Degree() - } else if parentApi.Compiler().Field().Cmp(ecc.BLS24_315.ScalarField()) == 0 { - degreeFr = gkrBls24315.Gates[w.Gate].Degree() - } else if parentApi.Compiler().Field().Cmp(ecc.BW6_761.ScalarField()) == 0 { - degreeFr = gkrBw6761.Gates[w.Gate].Degree() - } else if parentApi.Compiler().Field().Cmp(ecc.BLS12_381.ScalarField()) == 0 { - degreeFr = gkrBls12381.Gates[w.Gate].Degree() - } else if parentApi.Compiler().Field().Cmp(ecc.BLS24_317.ScalarField()) == 0 { - degreeFr = gkrBls24317.Gates[w.Gate].Degree() - } else if parentApi.Compiler().Field().Cmp(ecc.BW6_633.ScalarField()) == 0 { - degreeFr = gkrBw6633.Gates[w.Gate].Degree() - } else { - panic("field not yet supported") - } - if degree != degreeFr { - panic(fmt.Errorf("gate \"%s\" degree mismatch: SNARK %d, Raw %d", w.Gate, degree, degreeFr)) - } } for instanceI := range api.nbInstances() { for wireI, w := range api.toStore.Circuit { @@ -84,7 +63,7 @@ func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable for i, in := range w.Inputs { ins[i] = res[in][instanceI] } - expectedV, err := parentApi.Compiler().NewHint(frGateHint(w.Gate, degreeTestedGates), 1, ins...) + expectedV, err := parentApi.Compiler().NewHint(frGateHint(w.Gate, °reeTestedGates), 1, ins...) if err != nil { panic(err) } @@ -96,23 +75,27 @@ func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable return res } -func frGateHint(gateName string, degreeTestedGates map[string]struct{}) hint.Hint { +func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { return func(mod *big.Int, ins, outs []*big.Int) error { + const dummyGateName = "dummy-solve-in-test-engine-gate" + degreeFr := -1 + nbInFr := -1 if len(outs) != 1 { return errors.New("gate must have one output") } if ecc.BLS12_377.ScalarField().Cmp(mod) == 0 { - gate := gkrBls12377.Gates[gateName] + gate := gkrBls12377.GetGate(gateName) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } - if _, ok := degreeTestedGates[gateName]; !ok { - if err := gkrBls12377.TestGateDegree(gate, len(ins)); err != nil { - return fmt.Errorf("gate %s: %w", gateName, err) + degreeFr = gate.Degree() + nbInFr = gate.NbIn() + if _, ok := degreeTestedGates.Load(gateName); !ok { + // re-register the gate to make sure the degree is correct + if err := gkrBls12377.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBls12377.WithDegree(degreeFr)); err != nil { + return err } - degreeTestedGates[gateName] = struct{}{} } - x := make([]frBls12377.Element, len(ins)) for i := range ins { x[i].SetBigInt(ins[i]) @@ -120,17 +103,18 @@ func frGateHint(gateName string, degreeTestedGates map[string]struct{}) hint.Hin y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BN254.ScalarField().Cmp(mod) == 0 { - gate := gkrBn254.Gates[gateName] + gate := gkrBn254.GetGate(gateName) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } - if _, ok := degreeTestedGates[gateName]; !ok { - if err := gkrBn254.TestGateDegree(gate, len(ins)); err != nil { - return fmt.Errorf("gate %s: %w", gateName, err) + degreeFr = gate.Degree() + nbInFr = gate.NbIn() + if _, ok := degreeTestedGates.Load(gateName); !ok { + // re-register the gate to make sure the degree is correct + if err := gkrBn254.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBn254.WithDegree(degreeFr)); err != nil { + return err } - degreeTestedGates[gateName] = struct{}{} } - x := make([]frBn254.Element, len(ins)) for i := range ins { x[i].SetBigInt(ins[i]) @@ -138,17 +122,18 @@ func frGateHint(gateName string, degreeTestedGates map[string]struct{}) hint.Hin y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BLS24_315.ScalarField().Cmp(mod) == 0 { - gate := gkrBls24315.Gates[gateName] + gate := gkrBls24315.GetGate(gateName) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } - if _, ok := degreeTestedGates[gateName]; !ok { - if err := gkrBls24315.TestGateDegree(gate, len(ins)); err != nil { - return fmt.Errorf("gate %s: %w", gateName, err) + degreeFr = gate.Degree() + nbInFr = gate.NbIn() + if _, ok := degreeTestedGates.Load(gateName); !ok { + // re-register the gate to make sure the degree is correct + if err := gkrBls24315.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBls24315.WithDegree(degreeFr)); err != nil { + return err } - degreeTestedGates[gateName] = struct{}{} } - x := make([]frBls24315.Element, len(ins)) for i := range ins { x[i].SetBigInt(ins[i]) @@ -156,17 +141,11 @@ func frGateHint(gateName string, degreeTestedGates map[string]struct{}) hint.Hin y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BW6_761.ScalarField().Cmp(mod) == 0 { - gate := gkrBw6761.Gates[gateName] + gate := gkrBw6761.GetGate(gateName) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } - if _, ok := degreeTestedGates[gateName]; !ok { - if err := gkrBw6761.TestGateDegree(gate, len(ins)); err != nil { - return fmt.Errorf("gate %s: %w", gateName, err) - } - degreeTestedGates[gateName] = struct{}{} - } - + degreeFr = gate.Degree() x := make([]frBw6761.Element, len(ins)) for i := range ins { x[i].SetBigInt(ins[i]) @@ -174,17 +153,18 @@ func frGateHint(gateName string, degreeTestedGates map[string]struct{}) hint.Hin y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BLS12_381.ScalarField().Cmp(mod) == 0 { - gate := gkrBls12381.Gates[gateName] + gate := gkrBls12381.GetGate(gateName) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } - if _, ok := degreeTestedGates[gateName]; !ok { - if err := gkrBls12381.TestGateDegree(gate, len(ins)); err != nil { - return fmt.Errorf("gate %s: %w", gateName, err) + degreeFr = gate.Degree() + nbInFr = gate.NbIn() + if _, ok := degreeTestedGates.Load(gateName); !ok { + // re-register the gate to make sure the degree is correct + if err := gkrBls12381.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBls12381.WithDegree(degreeFr)); err != nil { + return err } - degreeTestedGates[gateName] = struct{}{} } - x := make([]frBls12381.Element, len(ins)) for i := range ins { x[i].SetBigInt(ins[i]) @@ -192,17 +172,18 @@ func frGateHint(gateName string, degreeTestedGates map[string]struct{}) hint.Hin y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BLS24_317.ScalarField().Cmp(mod) == 0 { - gate := gkrBls24317.Gates[gateName] + gate := gkrBls24317.GetGate(gateName) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } - if _, ok := degreeTestedGates[gateName]; !ok { - if err := gkrBls24317.TestGateDegree(gate, len(ins)); err != nil { - return fmt.Errorf("gate %s: %w", gateName, err) + degreeFr = gate.Degree() + nbInFr = gate.NbIn() + if _, ok := degreeTestedGates.Load(gateName); !ok { + // re-register the gate to make sure the degree is correct + if err := gkrBls24317.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBls24317.WithDegree(degreeFr)); err != nil { + return err } - degreeTestedGates[gateName] = struct{}{} } - x := make([]frBls24317.Element, len(ins)) for i := range ins { x[i].SetBigInt(ins[i]) @@ -210,15 +191,17 @@ func frGateHint(gateName string, degreeTestedGates map[string]struct{}) hint.Hin y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BW6_633.ScalarField().Cmp(mod) == 0 { - gate := gkrBw6633.Gates[gateName] + gate := gkrBw6633.GetGate(gateName) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } - if _, ok := degreeTestedGates[gateName]; !ok { - if err := gkrBw6633.TestGateDegree(gate, len(ins)); err != nil { - return fmt.Errorf("gate %s: %w", gateName, err) + degreeFr = gate.Degree() + nbInFr = gate.NbIn() + if _, ok := degreeTestedGates.Load(gateName); !ok { + // re-register the gate to make sure the degree is correct + if err := gkrBw6633.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBw6633.WithDegree(degreeFr)); err != nil { + return err } - degreeTestedGates[gateName] = struct{}{} } x := make([]frBw6633.Element, len(ins)) for i := range ins { @@ -229,6 +212,17 @@ func frGateHint(gateName string, degreeTestedGates map[string]struct{}) hint.Hin } else { return errors.New("field not supported") } + + degreeTestedGates.Store(gateName, struct{}{}) + + if degreeFr != Gates[gateName].Degree() { + return fmt.Errorf("gate \"%s\" degree mismatch: SNARK %d, Raw %d", gateName, Gates[gateName].Degree(), degreeFr) + } + + if nbInFr != len(ins) { // TODO @Tabaie also check against Gates[gateName].NbIn() + return fmt.Errorf("gate \"%s\" input count mismatch: SNARK %d, Raw %d", gateName, len(ins), nbInFr) + } + return nil } } From 54a8fcc49081c9e952628ce733a87cf3c2dc53d4 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Tue, 4 Mar 2025 12:56:21 -0600 Subject: [PATCH 03/62] refactor: adapt gkr package to gnark-crypto changes --- go.mod | 2 +- go.sum | 4 +- std/gkr/api.go | 12 +- std/gkr/api_test.go | 30 +++- std/gkr/compile.go | 2 +- std/gkr/gkr.go | 228 ++++++++++++++++++++------ std/gkr/gkr_test.go | 17 +- std/gkr/internal/bn254_wrapper_api.go | 201 +++++++++++++++++++++++ std/gkr/testing.go | 8 +- 9 files changed, 425 insertions(+), 79 deletions(-) create mode 100644 std/gkr/internal/bn254_wrapper_api.go diff --git a/go.mod b/go.mod index de44b4615b..9c60f9bdd6 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.29 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.16.1-0.20250225213100-17b431f9839c + github.com/consensys/gnark-crypto v0.16.1-0.20250304175949-a15b42865c78 github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index df92850436..2285d82464 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.1.29 h1:fobxIYksIQ+ZSrTJUuQgu+HIJwclrAPcdXqd7H2hh github.com/consensys/bavard v0.1.29/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.16.1-0.20250225213100-17b431f9839c h1:HJ5kUkkD8MKn6X8J3n4FZPVolGH4xfg44eZEmE2pSz0= -github.com/consensys/gnark-crypto v0.16.1-0.20250225213100-17b431f9839c/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= +github.com/consensys/gnark-crypto v0.16.1-0.20250304175949-a15b42865c78 h1:6CmnJn2aDi2g3NcJ7XpmETQiZVCasZmJNOeGtvgL1Wg= +github.com/consensys/gnark-crypto v0.16.1-0.20250304175949-a15b42865c78/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/std/gkr/api.go b/std/gkr/api.go index eb1acd2afe..2751f31d4c 100644 --- a/std/gkr/api.go +++ b/std/gkr/api.go @@ -28,18 +28,18 @@ func (api *API) namedGate2PlusIn(gate string, in1, in2 constraint.GkrVariable, i return api.NamedGate(gate, inCombined...) } -func (api *API) Add(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { - return api.namedGate2PlusIn("add", i1, i2, in...) +func (api *API) Add(i1, i2 constraint.GkrVariable) constraint.GkrVariable { + return api.namedGate2PlusIn("add2", i1, i2) } func (api *API) Neg(i1 constraint.GkrVariable) constraint.GkrVariable { return api.NamedGate("neg", i1) } -func (api *API) Sub(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { - return api.namedGate2PlusIn("sub", i1, i2, in...) +func (api *API) Sub(i1, i2 constraint.GkrVariable) constraint.GkrVariable { + return api.namedGate2PlusIn("sub2", i1, i2) } -func (api *API) Mul(i1, i2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { - return api.namedGate2PlusIn("mul", i1, i2, in...) +func (api *API) Mul(i1, i2 constraint.GkrVariable) constraint.GkrVariable { + return api.namedGate2PlusIn("mul2", i1, i2) } diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go index 10817fed7b..8371acaa2c 100644 --- a/std/gkr/api_test.go +++ b/std/gkr/api_test.go @@ -433,8 +433,32 @@ func init() { } func registerMiMCGate() { - Gates["mimc"] = MiMCCipherGate{Ark: 0} - gkr.Gates["mimc"] = mimcCipherGate{} + panicIfError(RegisterGate("mimc", func(api frontend.API, input ...frontend.Variable) frontend.Variable { + mimcSnarkTotalCalls++ + + if len(input) != 2 { + panic("mimc has fan-in 2") + } + sum := api.Add(input[0], input[1] /*, m.Ark*/) + + sumCubed := api.Mul(sum, sum, sum) // sum^3 + return api.Mul(sumCubed, sumCubed, sum) + }, 2, WithDegree(7))) + + panicIfError(gkr.RegisterGate("mimc", func(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + mimcFrTotalCalls++ + return res + }, 2, gkr.WithDegree(7))) } type constPseudoHash int @@ -597,7 +621,6 @@ func BenchmarkMiMCNoGkrFullDepthSolve(b *testing.B) { func TestMiMCFullDepthNoDepSolve(t *testing.T) { assert := test.NewAssert(t) - registerMiMC() for i := 0; i < 100; i++ { circuit, assignment := mimcNoDepCircuits(5, 1<<2, "-20") assert.Run(func(assert *test.Assert) { @@ -608,7 +631,6 @@ func TestMiMCFullDepthNoDepSolve(t *testing.T) { func TestMiMCFullDepthNoDepSolveWithMiMCHash(t *testing.T) { assert := test.NewAssert(t) - registerMiMC() circuit, assignment := mimcNoDepCircuits(5, 1<<2, "mimc") assert.CheckCircuit(circuit, test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254)) } diff --git a/std/gkr/compile.go b/std/gkr/compile.go index 4623bb0db0..b077063368 100644 --- a/std/gkr/compile.go +++ b/std/gkr/compile.go @@ -223,7 +223,7 @@ func newCircuitDataForSnark(info constraint.GkrInfo, assignment assignment) circ for i := range circuit { w := info.Circuit[i] circuit[i] = Wire{ - Gate: ite(w.IsInput(), Gates[w.Gate], Gate(IdentityGate{})), + Gate: GetGate(ite(w.IsInput(), w.Gate, "identity")), Inputs: utils.Map(w.Inputs, circuitAt), nbUniqueOutputs: w.NbUniqueOutputs, } diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index 0da52730a9..0751cbfa42 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -1,9 +1,13 @@ package gkr import ( + "crypto/rand" "errors" "fmt" + bn254Gkr "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" + "github.com/consensys/gnark/std/gkr/internal" "strconv" + "sync" "github.com/consensys/gnark/frontend" fiatshamir "github.com/consensys/gnark/std/fiat-shamir" @@ -15,14 +19,171 @@ import ( // The goal is to prove/verify evaluations of many instances of the same circuit -// Gate must be a low-degree polynomial -type Gate interface { - Evaluate(frontend.API, ...frontend.Variable) frontend.Variable - Degree() int +type GateFunction func(frontend.API, ...frontend.Variable) frontend.Variable + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + linearVar int // if there is a variable of degree 1, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// LinearVar returns the index of a variable of degree 1 in the gate's polynomial. If there is no such variable, it returns -1. +func (g *Gate) LinearVar() int { + return g.linearVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +var ( + gates = make(map[string]*Gate) + gatesLock sync.Mutex +) + +/*type registerGateSettings struct { + linearVar int + noLinearVarVerification bool + noDegreeVerification bool + degree int +}*/ + +// here options are not defined as functions on settings to make translation to their field counterpart easier +// TODO @Tabaie once GKR is moved to gnark, use the same options/settings type for all curves, obviating this + +type registerGateOptionType byte + +const ( + registerGateOptionTypeWithLinearVar registerGateOptionType = iota + registerGateOptionTypeWithUnverifiedLinearVar + registerGateOptionTypeWithNoLinearVar + registerGateOptionTypeWithUnverifiedDegree + registerGateOptionTypeWithDegree +) + +type registerGateOption struct { + tp registerGateOptionType + param int +} + +// WithLinearVar gives the index of a variable of degree 1 in the gate's polynomial. RegisterGate will return an error if the given index is not correct. +func WithLinearVar(linearVar int) *registerGateOption { + return ®isterGateOption{ + tp: registerGateOptionTypeWithLinearVar, + param: linearVar, + } +} + +// WithUnverifiedLinearVar sets the index of a variable of degree 1 in the gate's polynomial. RegisterGate will not verify that the given index is correct. +func WithUnverifiedLinearVar(linearVar int) *registerGateOption { + return ®isterGateOption{ + tp: registerGateOptionTypeWithUnverifiedLinearVar, + param: linearVar, + } +} + +// WithNoLinearVar sets the gate as having no variable of degree 1. RegisterGate will not check the correctness of this claim. +func WithNoLinearVar() *registerGateOption { + return ®isterGateOption{ + tp: registerGateOptionTypeWithNoLinearVar, + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) *registerGateOption { + return ®isterGateOption{ + tp: registerGateOptionTypeWithUnverifiedDegree, + param: degree, + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) *registerGateOption { + return ®isterGateOption{ + tp: registerGateOptionTypeWithDegree, + param: degree, + } +} + +// RegisterGate creates a gate object and stores it in the gates registry +// name is a human-readable name for the gate +// f is the polynomial function defining the gate +// nbIn is the number of inputs to the gate +// NB! This package generally expects certain properties of the gate to be invariant across all curves. +// In particular the degree is computed and verified over BN254. If the leading coefficient is divided by +// the curve's order, the degree will be computed incorrectly. +func RegisterGate(name string, f GateFunction, nbIn int, options ...*registerGateOption) error { + frF := internal.ToBn254GateFunction(f) // delegate tests to bn254 + var nameRand [4]byte + if _, err := rand.Read(nameRand[:]); err != nil { + return err + } + frName := fmt.Sprintf("%s-test-%x", name, nameRand) + frOptions := make([]bn254Gkr.RegisterGateOption, 0, len(options)) + + // translate options + for _, opt := range options { + switch opt.tp { + case registerGateOptionTypeWithLinearVar: + frOptions = append(frOptions, bn254Gkr.WithLinearVar(opt.param)) + case registerGateOptionTypeWithUnverifiedLinearVar: + frOptions = append(frOptions, bn254Gkr.WithUnverifiedLinearVar(opt.param)) + case registerGateOptionTypeWithNoLinearVar: + frOptions = append(frOptions, bn254Gkr.WithNoLinearVar()) + case registerGateOptionTypeWithUnverifiedDegree: + frOptions = append(frOptions, bn254Gkr.WithUnverifiedDegree(opt.param)) + case registerGateOptionTypeWithDegree: + frOptions = append(frOptions, bn254Gkr.WithDegree(opt.param)) + default: + return fmt.Errorf("unknown option type %d", opt.tp) + } + } + + if err := bn254Gkr.RegisterGate(frName, frF, nbIn, frOptions...); err != nil { + return err + } + bn254Gate := bn254Gkr.GetGate(frName) + bn254Gkr.RemoveGate(frName) + + gatesLock.Lock() + defer gatesLock.Unlock() + + gates[name] = &Gate{ + Evaluate: f, + nbIn: nbIn, + degree: bn254Gate.Degree(), + linearVar: bn254Gate.LinearVar(), + } + + return nil +} + +func GetGate(name string) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +func RemoveGate(name string) bool { + gatesLock.Lock() + defer gatesLock.Unlock() + _, found := gates[name] + if found { + delete(gates, name) + } + return found } type Wire struct { - Gate Gate + Gate *Gate Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) } @@ -349,16 +510,6 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, return nil } -type IdentityGate struct{} - -func (IdentityGate) Evaluate(_ frontend.API, input ...frontend.Variable) frontend.Variable { - return input[0] -} - -func (IdentityGate) Degree() int { - return 1 -} - // outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. func outputsList(c Circuit, indexes map[*Wire]int) [][]int { res := make([][]int, len(c)) @@ -366,7 +517,7 @@ func outputsList(c Circuit, indexes map[*Wire]int) [][]int { res[i] = make([]int, 0) c[i].nbUniqueOutputs = 0 if c[i].IsInput() { - c[i].Gate = IdentityGate{} + c[i].Gate = GetGate("identity") } } ins := make(map[int]struct{}, len(c)) @@ -533,39 +684,20 @@ func DeserializeProof(sorted []*Wire, serializedProof []frontend.Variable) (Proo return proof, nil } -type MulGate struct{} - -func (g MulGate) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Variable { - if len(x) != 2 { - panic("mul has fan-in 2") - } - return api.Mul(x[0], x[1]) -} - -// TODO: Degree must take nbInputs as an argument and return degree = nbInputs -func (g MulGate) Degree() int { - return 2 +func init() { + panicIfError(RegisterGate("mul2", func(api frontend.API, x ...frontend.Variable) frontend.Variable { + return api.Mul(x[0], x[1]) + }, 2, WithUnverifiedDegree(2), WithNoLinearVar())) + panicIfError(RegisterGate("add2", func(api frontend.API, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[1]) + }, 2, WithUnverifiedDegree(1), WithUnverifiedLinearVar(0))) + panicIfError(RegisterGate("identity", func(api frontend.API, x ...frontend.Variable) frontend.Variable { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedLinearVar(0))) } -type AddGate struct{} - -func (a AddGate) Evaluate(api frontend.API, v ...frontend.Variable) frontend.Variable { - switch len(v) { - case 0: - return 0 - case 1: - return v[0] +func panicIfError(err error) { + if err != nil { + panic(err) } - rest := v[2:] - return api.Add(v[0], v[1], rest...) -} - -func (a AddGate) Degree() int { - return 1 -} - -var Gates = map[string]Gate{ - "identity": IdentityGate{}, - "add": AddGate{}, - "mul": MulGate{}, } diff --git a/std/gkr/gkr_test.go b/std/gkr/gkr_test.go index d24b25a95c..324160a985 100644 --- a/std/gkr/gkr_test.go +++ b/std/gkr/gkr_test.go @@ -249,8 +249,7 @@ func (c CircuitInfo) toCircuit() (circuit Circuit, err error) { circuit[i].Inputs[iAsInput] = input } - var found bool - if circuit[i].Gate, found = Gates[wireInfo.Gate]; !found && wireInfo.Gate != "" { + if circuit[i].Gate = GetGate(wireInfo.Gate); circuit[i].Gate == nil && wireInfo.Gate != "" { err = fmt.Errorf("undefined gate \"%s\"", wireInfo.Gate) } } @@ -258,18 +257,10 @@ func (c CircuitInfo) toCircuit() (circuit Circuit, err error) { return } -type _select int - func init() { - Gates["select-input-3"] = _select(2) -} - -func (g _select) Evaluate(_ frontend.API, in ...frontend.Variable) frontend.Variable { - return in[g] -} - -func (g _select) Degree() int { - return 1 + panicIfError(RegisterGate("select-input-3", func(api frontend.API, in ...frontend.Variable) frontend.Variable { + return in[2] + }, 3, WithDegree(1))) } type PrintableProof []PrintableSumcheckProof diff --git a/std/gkr/internal/bn254_wrapper_api.go b/std/gkr/internal/bn254_wrapper_api.go new file mode 100644 index 0000000000..270c26201a --- /dev/null +++ b/std/gkr/internal/bn254_wrapper_api.go @@ -0,0 +1,201 @@ +package internal + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/utils" + "math/big" +) + +// wrap BN254 scalar field arithmetic in a frontend.API +// bn254WrapperApi uses *fr.Element as its variable type +type bn254WrapperApi struct { + err error +} + +func ToBn254GateFunction(f func(frontend.API, ...frontend.Variable) frontend.Variable) gkr.GateFunction { + var wrapper bn254WrapperApi + + return func(x ...fr.Element) fr.Element { + if wrapper.err != nil { + return fr.Element{} + } + res := f(&wrapper, utils.Map(x, func(x fr.Element) frontend.Variable { + return &x + })...).(*fr.Element) + + return *res + } +} + +func (w *bn254WrapperApi) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + var res fr.Element + res.Add(w.cast(i1), w.cast(i2)) + for i := range in { + res.Add(&res, w.cast(in[i])) + } + + return &res +} + +func (w *bn254WrapperApi) MulAcc(a, b, c frontend.Variable) frontend.Variable { + var res fr.Element + res.Mul(w.cast(b), w.cast(c)) + res.Add(&res, w.cast(a)) + return &res +} + +func (w *bn254WrapperApi) Neg(i1 frontend.Variable) frontend.Variable { + var res fr.Element + res.Neg(w.cast(i1)) + return &res +} + +func (w *bn254WrapperApi) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + var res fr.Element + res.Sub(w.cast(i1), w.cast(i2)) + for i := range in { + res.Sub(&res, w.cast(in[i])) + } + return &res +} + +func (w *bn254WrapperApi) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + var res fr.Element + res.Mul(w.cast(i1), w.cast(i2)) + for i := range in { + res.Mul(&res, w.cast(in[i])) + } + return &res +} + +func (w *bn254WrapperApi) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { + return w.Div(i1, i2) +} + +func (w *bn254WrapperApi) Div(i1, i2 frontend.Variable) frontend.Variable { + return w.Mul(i1, w.Inverse(i2)) +} + +func (w *bn254WrapperApi) Inverse(i1 frontend.Variable) frontend.Variable { + w.newError("only polynomial (ring) operations supported") + return nil +} + +func (w *bn254WrapperApi) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) FromBinary(b ...frontend.Variable) frontend.Variable { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) Xor(a, b frontend.Variable) frontend.Variable { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) Or(a, b frontend.Variable) frontend.Variable { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) And(a, b frontend.Variable) frontend.Variable { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) Select(frontend.Variable, frontend.Variable, frontend.Variable) frontend.Variable { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) Lookup2(frontend.Variable, frontend.Variable, frontend.Variable, frontend.Variable, frontend.Variable, frontend.Variable) frontend.Variable { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) IsZero(frontend.Variable) frontend.Variable { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) Cmp(frontend.Variable, frontend.Variable) frontend.Variable { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) AssertIsEqual(i1, i2 frontend.Variable) { + w.newError("only field operations supported") +} + +func (w *bn254WrapperApi) AssertIsDifferent(frontend.Variable, frontend.Variable) { + w.newError("only field operations supported") +} + +func (w *bn254WrapperApi) AssertIsBoolean(frontend.Variable) { + w.newError("only field operations supported") +} + +func (w *bn254WrapperApi) AssertIsCrumb(frontend.Variable) { + w.newError("only field operations supported") +} + +func (w *bn254WrapperApi) AssertIsLessOrEqual(frontend.Variable, frontend.Variable) { + w.newError("only field operations supported") +} + +func (w *bn254WrapperApi) Println(a ...frontend.Variable) { + toPrint := make([]any, len(a)) + for i, v := range a { + if x := w.cast(v); w.err == nil { + toPrint[i] = x[i] + } else { + return + } + } + fmt.Println(toPrint...) +} + +func (w *bn254WrapperApi) Compiler() frontend.Compiler { + w.newError("only field operations supported") + return nil +} + +func (w *bn254WrapperApi) NewHint(solver.Hint, int, ...frontend.Variable) ([]frontend.Variable, error) { + err := errors.New("only field operations supported") + w.emitError(err) + return nil, err +} + +func (w *bn254WrapperApi) ConstantValue(frontend.Variable) (*big.Int, bool) { + w.newError("only field operations supported") + return nil, false +} + +func (w *bn254WrapperApi) cast(v frontend.Variable) *fr.Element { + var res fr.Element + if w.err != nil { + return &res + } + if _, err := res.SetInterface(v); err != nil { + w.emitError(err) + } + return &res +} + +func (w *bn254WrapperApi) emitError(err error) { + if w.err == nil { + w.err = err + } +} + +func (w *bn254WrapperApi) newError(msg string) { + w.emitError(errors.New(msg)) +} diff --git a/std/gkr/testing.go b/std/gkr/testing.go index c5464c5041..03eff78698 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -67,7 +67,7 @@ func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable if err != nil { panic(err) } - res[wireI][instanceI] = Gates[w.Gate].Evaluate(parentApi, ins...) + res[wireI][instanceI] = GetGate(w.Gate).Evaluate(parentApi, ins...) parentApi.AssertIsEqual(expectedV[0], res[wireI][instanceI]) // snark and raw gate evaluations must agree } } @@ -215,11 +215,11 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { degreeTestedGates.Store(gateName, struct{}{}) - if degreeFr != Gates[gateName].Degree() { - return fmt.Errorf("gate \"%s\" degree mismatch: SNARK %d, Raw %d", gateName, Gates[gateName].Degree(), degreeFr) + if degreeFr != GetGate(gateName).Degree() { + return fmt.Errorf("gate \"%s\" degree mismatch: SNARK %d, Raw %d", gateName, GetGate(gateName).Degree(), degreeFr) } - if nbInFr != len(ins) { // TODO @Tabaie also check against Gates[gateName].NbIn() + if nbInFr != len(ins) { // TODO @Tabaie also check against GetGate(gateName].NbIn() return fmt.Errorf("gate \"%s\" input count mismatch: SNARK %d, Raw %d", gateName, len(ins), nbInFr) } From d66a1b9bfd64e5eb9dd477869a2fcb94bb775095 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Wed, 5 Mar 2025 16:07:47 -0600 Subject: [PATCH 04/62] refactor: linearVar -> solvableVar --- go.mod | 2 +- go.sum | 4 +- std/gkr/gkr.go | 68 +++++------ std/gkr/testing.go | 13 ++- std/permutation/poseidon2/gkr.go | 156 +++++++++++--------------- std/permutation/poseidon2/gkr_test.go | 8 +- 6 files changed, 119 insertions(+), 132 deletions(-) diff --git a/go.mod b/go.mod index 9c60f9bdd6..763403f54f 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.29 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.16.1-0.20250304175949-a15b42865c78 + github.com/consensys/gnark-crypto v0.16.1-0.20250305220457-01ad8f324f1c github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index 2285d82464..cb2bafcefa 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,8 @@ github.com/consensys/bavard v0.1.29 h1:fobxIYksIQ+ZSrTJUuQgu+HIJwclrAPcdXqd7H2hh github.com/consensys/bavard v0.1.29/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.16.1-0.20250304175949-a15b42865c78 h1:6CmnJn2aDi2g3NcJ7XpmETQiZVCasZmJNOeGtvgL1Wg= -github.com/consensys/gnark-crypto v0.16.1-0.20250304175949-a15b42865c78/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= +github.com/consensys/gnark-crypto v0.16.1-0.20250305220457-01ad8f324f1c h1:C6dAj70uBKJQe/x1b6c6InJYKcgR2SM4mvbXa3ZdLkI= +github.com/consensys/gnark-crypto v0.16.1-0.20250305220457-01ad8f324f1c/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index 0751cbfa42..60efa50421 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -23,10 +23,10 @@ type GateFunction func(frontend.API, ...frontend.Variable) frontend.Variable // A Gate is a low-degree multivariate polynomial type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - linearVar int // if there is a variable of degree 1, its index, -1 otherwise + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a variable whose value can be uniquely determined from the value of the gate and the other inputs, its index, -1 otherwise } // Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 @@ -34,9 +34,9 @@ func (g *Gate) Degree() int { return g.degree } -// LinearVar returns the index of a variable of degree 1 in the gate's polynomial. If there is no such variable, it returns -1. -func (g *Gate) LinearVar() int { - return g.linearVar +// SolvableVar returns the index of a variable of degree 1 in the gate's polynomial. If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar } // NbIn returns the number of inputs to the gate (its fan-in) @@ -50,8 +50,8 @@ var ( ) /*type registerGateSettings struct { - linearVar int - noLinearVarVerification bool + solvableVar int + noSolvableVarVerification bool noDegreeVerification bool degree int }*/ @@ -62,9 +62,9 @@ var ( type registerGateOptionType byte const ( - registerGateOptionTypeWithLinearVar registerGateOptionType = iota - registerGateOptionTypeWithUnverifiedLinearVar - registerGateOptionTypeWithNoLinearVar + registerGateOptionTypeWithSolvableVar registerGateOptionType = iota + registerGateOptionTypeWithUnverifiedSolvableVar + registerGateOptionTypeWithNoSolvableVar registerGateOptionTypeWithUnverifiedDegree registerGateOptionTypeWithDegree ) @@ -74,26 +74,26 @@ type registerGateOption struct { param int } -// WithLinearVar gives the index of a variable of degree 1 in the gate's polynomial. RegisterGate will return an error if the given index is not correct. -func WithLinearVar(linearVar int) *registerGateOption { +// WithSolvableVar gives the index of a variable of degree 1 in the gate's polynomial. RegisterGate will return an error if the given index is not correct. +func WithSolvableVar(linearVar int) *registerGateOption { return ®isterGateOption{ - tp: registerGateOptionTypeWithLinearVar, + tp: registerGateOptionTypeWithSolvableVar, param: linearVar, } } -// WithUnverifiedLinearVar sets the index of a variable of degree 1 in the gate's polynomial. RegisterGate will not verify that the given index is correct. -func WithUnverifiedLinearVar(linearVar int) *registerGateOption { +// WithUnverifiedSolvableVar sets the index of a variable of degree 1 in the gate's polynomial. RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(linearVar int) *registerGateOption { return ®isterGateOption{ - tp: registerGateOptionTypeWithUnverifiedLinearVar, + tp: registerGateOptionTypeWithUnverifiedSolvableVar, param: linearVar, } } -// WithNoLinearVar sets the gate as having no variable of degree 1. RegisterGate will not check the correctness of this claim. -func WithNoLinearVar() *registerGateOption { +// WithNoSolvableVar sets the gate as having no variable of degree 1. RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() *registerGateOption { return ®isterGateOption{ - tp: registerGateOptionTypeWithNoLinearVar, + tp: registerGateOptionTypeWithNoSolvableVar, } } @@ -132,12 +132,12 @@ func RegisterGate(name string, f GateFunction, nbIn int, options ...*registerGat // translate options for _, opt := range options { switch opt.tp { - case registerGateOptionTypeWithLinearVar: - frOptions = append(frOptions, bn254Gkr.WithLinearVar(opt.param)) - case registerGateOptionTypeWithUnverifiedLinearVar: - frOptions = append(frOptions, bn254Gkr.WithUnverifiedLinearVar(opt.param)) - case registerGateOptionTypeWithNoLinearVar: - frOptions = append(frOptions, bn254Gkr.WithNoLinearVar()) + case registerGateOptionTypeWithSolvableVar: + frOptions = append(frOptions, bn254Gkr.WithSolvableVar(opt.param)) + case registerGateOptionTypeWithUnverifiedSolvableVar: + frOptions = append(frOptions, bn254Gkr.WithUnverifiedSolvableVar(opt.param)) + case registerGateOptionTypeWithNoSolvableVar: + frOptions = append(frOptions, bn254Gkr.WithNoSolvableVar()) case registerGateOptionTypeWithUnverifiedDegree: frOptions = append(frOptions, bn254Gkr.WithUnverifiedDegree(opt.param)) case registerGateOptionTypeWithDegree: @@ -157,10 +157,10 @@ func RegisterGate(name string, f GateFunction, nbIn int, options ...*registerGat defer gatesLock.Unlock() gates[name] = &Gate{ - Evaluate: f, - nbIn: nbIn, - degree: bn254Gate.Degree(), - linearVar: bn254Gate.LinearVar(), + Evaluate: f, + nbIn: nbIn, + degree: bn254Gate.Degree(), + solvableVar: bn254Gate.SolvableVar(), } return nil @@ -687,13 +687,13 @@ func DeserializeProof(sorted []*Wire, serializedProof []frontend.Variable) (Proo func init() { panicIfError(RegisterGate("mul2", func(api frontend.API, x ...frontend.Variable) frontend.Variable { return api.Mul(x[0], x[1]) - }, 2, WithUnverifiedDegree(2), WithNoLinearVar())) + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar())) panicIfError(RegisterGate("add2", func(api frontend.API, x ...frontend.Variable) frontend.Variable { return api.Add(x[0], x[1]) - }, 2, WithUnverifiedDegree(1), WithUnverifiedLinearVar(0))) + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0))) panicIfError(RegisterGate("identity", func(api frontend.API, x ...frontend.Variable) frontend.Variable { return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedLinearVar(0))) + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0))) } func panicIfError(err error) { diff --git a/std/gkr/testing.go b/std/gkr/testing.go index 03eff78698..357c01943c 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -78,8 +78,9 @@ func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { return func(mod *big.Int, ins, outs []*big.Int) error { const dummyGateName = "dummy-solve-in-test-engine-gate" - degreeFr := -1 + var degreeFr int nbInFr := -1 + solvableVarFr := -1 if len(outs) != 1 { return errors.New("gate must have one output") } @@ -90,6 +91,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { } degreeFr = gate.Degree() nbInFr = gate.NbIn() + solvableVarFr = gate.SolvableVar() if _, ok := degreeTestedGates.Load(gateName); !ok { // re-register the gate to make sure the degree is correct if err := gkrBls12377.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBls12377.WithDegree(degreeFr)); err != nil { @@ -109,6 +111,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { } degreeFr = gate.Degree() nbInFr = gate.NbIn() + solvableVarFr = gate.SolvableVar() if _, ok := degreeTestedGates.Load(gateName); !ok { // re-register the gate to make sure the degree is correct if err := gkrBn254.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBn254.WithDegree(degreeFr)); err != nil { @@ -128,6 +131,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { } degreeFr = gate.Degree() nbInFr = gate.NbIn() + solvableVarFr = gate.SolvableVar() if _, ok := degreeTestedGates.Load(gateName); !ok { // re-register the gate to make sure the degree is correct if err := gkrBls24315.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBls24315.WithDegree(degreeFr)); err != nil { @@ -159,6 +163,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { } degreeFr = gate.Degree() nbInFr = gate.NbIn() + solvableVarFr = gate.SolvableVar() if _, ok := degreeTestedGates.Load(gateName); !ok { // re-register the gate to make sure the degree is correct if err := gkrBls12381.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBls12381.WithDegree(degreeFr)); err != nil { @@ -178,6 +183,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { } degreeFr = gate.Degree() nbInFr = gate.NbIn() + solvableVarFr = gate.SolvableVar() if _, ok := degreeTestedGates.Load(gateName); !ok { // re-register the gate to make sure the degree is correct if err := gkrBls24317.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBls24317.WithDegree(degreeFr)); err != nil { @@ -197,6 +203,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { } degreeFr = gate.Degree() nbInFr = gate.NbIn() + solvableVarFr = gate.SolvableVar() if _, ok := degreeTestedGates.Load(gateName); !ok { // re-register the gate to make sure the degree is correct if err := gkrBw6633.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBw6633.WithDegree(degreeFr)); err != nil { @@ -223,6 +230,10 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { return fmt.Errorf("gate \"%s\" input count mismatch: SNARK %d, Raw %d", gateName, len(ins), nbInFr) } + if solvableVarFr != GetGate(gateName).SolvableVar() { + return fmt.Errorf("gate \"%s\" designated solvable variable mismatch: SNARK %d, Raw %d", gateName, GetGate(gateName).SolvableVar(), solvableVarFr) + } + return nil } } diff --git a/std/permutation/poseidon2/gkr.go b/std/permutation/poseidon2/gkr.go index eb23a3653e..9fc42c4b5c 100644 --- a/std/permutation/poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr.go @@ -23,25 +23,17 @@ import ( // extKeyGate applies the external matrix mul, then adds the round key // because of its symmetry, we don't need to define distinct x1 and x2 versions of it -type extKeyGate struct { - roundKey *big.Int -} - -func (g *extKeyGate) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Variable { - if len(x) != 2 { - panic("expected 2 inputs") +func extKeyGate(roundKey *big.Int) gkr.GateFunction { + return func(api frontend.API, x ...frontend.Variable) frontend.Variable { + if len(x) != 2 { + panic("expected 2 inputs") + } + return api.Add(api.Mul(x[0], 2), x[1], roundKey) } - return api.Add(api.Mul(x[0], 2), x[1], g.roundKey) -} - -func (g *extKeyGate) Degree() int { - return 1 } // pow4Gate computes a -> a⁴ -type pow4Gate struct{} - -func (g pow4Gate) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Variable { +func pow4Gate(api frontend.API, x ...frontend.Variable) frontend.Variable { if len(x) != 1 { panic("expected 1 input") } @@ -51,14 +43,8 @@ func (g pow4Gate) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Va return y } -func (g pow4Gate) Degree() int { - return 4 -} - -// pow4Gate computes a, b -> a⁴ * b -type pow4TimesGate struct{} - -func (g pow4TimesGate) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Variable { +// pow4TimesGate computes a, b -> a⁴ * b +func pow4TimesGate(api frontend.API, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 1 input") } @@ -68,103 +54,80 @@ func (g pow4TimesGate) Evaluate(api frontend.API, x ...frontend.Variable) fronte return api.Mul(y, x[1]) } -func (g pow4TimesGate) Degree() int { - return 5 -} - -type pow2Gate struct{} - -func (g pow2Gate) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Variable { +// pow2Gate computes a -> a² +func pow2Gate(api frontend.API, x ...frontend.Variable) frontend.Variable { if len(x) != 1 { panic("expected 1 input") } return api.Mul(x[0], x[0]) } -func (g pow2Gate) Degree() int { - return 2 -} - -type pow2TimesGate struct{} - -func (g pow2TimesGate) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Variable { +// pow2TimesGate computes a, b -> a² * b +func pow2TimesGate(api frontend.API, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } return api.Mul(x[0], x[0], x[1]) } -func (g pow2TimesGate) Degree() int { - return 3 -} - // for x1, the partial round gates are identical to full round gates // for x2, the partial round gates are just a linear combination // TODO @Tabaie try eliminating the x2 partial round gates and have the x1 gates depend on i - rf/2 or so previous x1's // extGate2 applies the external matrix mul, outputting the second element of the result -type extGate2 struct { -} - -func (g *extGate2) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Variable { +func extGate2(api frontend.API, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } return api.Add(api.Mul(x[1], 2), x[0]) } -func (g *extGate2) Degree() int { - return 1 -} - // intKeyGate2 applies the internal matrix mul, then adds the round key -type intKeyGate2 struct { - roundKey *big.Int -} - -func (g *intKeyGate2) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Variable { - if len(x) != 2 { - panic("expected 2 inputs") +func intKeyGate2(roundKey *big.Int) gkr.GateFunction { + return func(api frontend.API, x ...frontend.Variable) frontend.Variable { + if len(x) != 2 { + panic("expected 2 inputs") + } + return api.Add(api.Mul(x[1], 3), x[0], roundKey) } - return api.Add(api.Mul(x[1], 3), x[0], g.roundKey) -} - -func (g *intKeyGate2) Degree() int { - return 1 } -type extGate struct{} - -func (g extGate) Evaluate(api frontend.API, x ...frontend.Variable) frontend.Variable { +// extGate applies the first row of the external matrix +func extGate(api frontend.API, x ...frontend.Variable) frontend.Variable { if len(x) != 2 { panic("expected 2 inputs") } return api.Add(api.Mul(x[0], 2), x[1]) } -func (g extGate) Degree() int { - return 1 +// extAddGate applies the first row of the external matrix to the first two elements and adds the third +func extAddGate(api frontend.API, x ...frontend.Variable) frontend.Variable { + if len(x) != 3 { + panic("expected 3 inputs") + } + return api.Add(api.Mul(x[0], 2), x[1], x[2]) } -type GkrPermutations struct { +type GkrCompressions struct { api frontend.API ins1 []frontend.Variable ins2 []frontend.Variable outs []frontend.Variable } -// NewGkrPermutations returns an object that can compute the Poseidon2 permutation (currently only for BLS12-377) -// The correctness of the permutations is proven using GKR +// NewGkrCompressions returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) +// which consists of a permutation along with the input fed forward. +// The correctness of the compression functions is proven using GKR. // Note that the solver will need the function RegisterGkrSolverOptions to be called with the desired curves -func NewGkrPermutations(api frontend.API) *GkrPermutations { - res := GkrPermutations{ +func NewGkrCompressions(api frontend.API) *GkrCompressions { + res := GkrCompressions{ api: api, } api.Compiler().Defer(res.finalize) return &res } -func (p *GkrPermutations) Permute(a, b frontend.Variable) frontend.Variable { +func (p *GkrCompressions) Compress(a, b frontend.Variable) frontend.Variable { s, err := p.api.Compiler().NewHint(permuteHint, 1, a, b) if err != nil { panic(err) @@ -205,18 +168,23 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. return nil, -1, err } y, err := gkrApi.Import(insRight) + y0 := y // save to feed forward at the end if err != nil { return nil, -1, err } // unique names for linear rounds - gateNameLinear := func(varI, round int) string { + gateNameSolvable := func(varI, round int) string { return fmt.Sprintf("x%d-l-op-round=%d;%s", varI, round, params) } // the s-Box gates: u¹⁷ = (u⁴)⁴ * u - gkr.Gates["pow4"] = pow4Gate{} - gkr.Gates["pow4Times"] = pow4TimesGate{} + if err = gkr.RegisterGate("pow4", pow4Gate, 1, gkr.WithUnverifiedDegree(4), gkr.WithNoSolvableVar()); err != nil { + return nil, -1, err + } + if err = gkr.RegisterGate("pow4Times", pow4TimesGate, 2, gkr.WithUnverifiedDegree(5), gkr.WithNoSolvableVar()); err != nil { + return nil, -1, err + } // *** helper functions to register and apply gates *** @@ -235,9 +203,9 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // register and apply external matrix multiplication and round key addition // round dependent due to the round key extKeySBox := func(round, varI int, a, b constraint.GkrVariable) constraint.GkrVariable { - gate := gateNameLinear(varI, round) - gkr.Gates[gate] = &extKeyGate{ - roundKey: frToInt(&roundKeysFr[round][varI]), + gate := gateNameSolvable(varI, round) + if err = gkr.RegisterGate(gate, extKeyGate(frToInt(&roundKeysFr[round][varI])), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return -1 } return sBox(gkrApi.NamedGate(gate, a, b)) } @@ -247,9 +215,9 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // for the second variable // round independent due to the round key intKeySBox2 := func(round int, a, b constraint.GkrVariable) constraint.GkrVariable { - gate := gateNameLinear(yI, round) - gkr.Gates[gate] = &intKeyGate2{ - roundKey: frToInt(&roundKeysFr[round][1]), + gate := gateNameSolvable(yI, round) + if err = gkr.RegisterGate(gate, intKeyGate2(frToInt(&roundKeysFr[round][1])), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return -1 } return sBox(gkrApi.NamedGate(gate, a, b)) } @@ -271,8 +239,10 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // still using the external matrix, since the linear operation still belongs to a full (canonical) round x1 := extKeySBox(halfRf, xI, x, y) - gate := gateNameLinear(yI, halfRf) - gkr.Gates[gate] = &extGate2{} + gate := gateNameSolvable(yI, halfRf) + if err = gkr.RegisterGate(gate, extGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return nil, -1, err + } x, y = x1, gkrApi.NamedGate(gate, x, y) } @@ -280,9 +250,9 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. for i := halfRf + 1; i < halfRf+rP; i++ { x1 := extKeySBox(i, xI, x, y) // the first row of the internal matrix is the same as that of the external matrix - gate := gateNameLinear(yI, i) - gkr.Gates[gate] = &intKeyGate2{ - roundKey: zero, + gate := gateNameSolvable(yI, i) + if err = gkr.RegisterGate(gate, intKeyGate2(zero), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return nil, -1, err } x, y = x1, gkrApi.NamedGate(gate, x, y) } @@ -300,14 +270,16 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. } // apply the external matrix one last time to obtain the final value of y - gate := gateNameLinear(yI, rP+rF) - gkr.Gates[gate] = extGate{} - y = gkrApi.NamedGate(gate, y, x) + gate := gateNameSolvable(yI, rP+rF) + if err = gkr.RegisterGate(gate, extAddGate, 3, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return nil, -1, err + } + y = gkrApi.NamedGate(gate, y, x, y0) return gkrApi, y, nil } -func (p *GkrPermutations) finalize(api frontend.API) error { +func (p *GkrCompressions) finalize(api frontend.API) error { if p.api != api { panic("unexpected API") } @@ -333,7 +305,7 @@ func (p *GkrPermutations) finalize(api frontend.API) error { if err != nil { return err } - + // connect to output // TODO can we save 1 constraint per instance by giving the desired outputs to the gkr api? solution, err := gkrApi.Solve(api) @@ -367,8 +339,10 @@ func permuteHint(m *big.Int, ins, outs []*big.Int) error { var x [2]frBls12377.Element x[0].SetBigInt(ins[0]) x[1].SetBigInt(ins[1]) + y0 := x[1] err := bls12377Permutation().Permutation(x[:]) + x[1].Add(&x[1], &y0) // feed forward x[1].BigInt(outs[0]) return err } diff --git a/std/permutation/poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr_test.go index 1cc39c4053..96a9c2494c 100644 --- a/std/permutation/poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr_test.go @@ -11,7 +11,7 @@ import ( "testing" ) -func TestGkrPermutation(t *testing.T) { +func TestGkrCompression(t *testing.T) { const n = 2 var k int64 ins := make([][2]frontend.Variable, n) @@ -22,8 +22,10 @@ func TestGkrPermutation(t *testing.T) { x[0].SetInt64(k) x[1].SetInt64(k + 1) + y0 := x[1] require.NoError(t, bls12377Permutation().Permutation(x[:])) + x[1].Add(&x[1], &y0) outs[i] = x[1] k += 2 @@ -46,10 +48,10 @@ type testGkrPermutationCircuit struct { func (c *testGkrPermutationCircuit) Define(api frontend.API) error { - pos2 := NewGkrPermutations(api) + pos2 := NewGkrCompressions(api) api.AssertIsEqual(len(c.Ins), len(c.Outs)) for i := range c.Ins { - api.AssertIsEqual(c.Outs[i], pos2.Permute(c.Ins[i][0], c.Ins[i][1])) + api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) } return nil From c19628b9f652dda10086baf2da3ed30f7dae7a84 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Wed, 5 Mar 2025 16:17:24 -0600 Subject: [PATCH 05/62] fix: poseidon2 Compress feed fwd --- std/permutation/poseidon2/poseidon2.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index f7a8d4068a..55afe73be5 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -328,5 +328,5 @@ func (h *Permutation) Compress(left, right frontend.Variable) frontend.Variable if err := h.Permutation(vars[:]); err != nil { panic(err) // this would never happen } - return vars[1] + return h.api.Add(vars[1], right) } From fc62f364f18bffc0307d9e33a7df8f9231c53b51 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Wed, 5 Mar 2025 16:17:43 -0600 Subject: [PATCH 06/62] style: gofmt --- std/permutation/poseidon2/gkr.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/permutation/poseidon2/gkr.go b/std/permutation/poseidon2/gkr.go index 9fc42c4b5c..63ab381878 100644 --- a/std/permutation/poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr.go @@ -305,7 +305,7 @@ func (p *GkrCompressions) finalize(api frontend.API) error { if err != nil { return err } - + // connect to output // TODO can we save 1 constraint per instance by giving the desired outputs to the gkr api? solution, err := gkrApi.Solve(api) From 4cd2932eacfc1cb0f6875e05d9be137beb2f4988 Mon Sep 17 00:00:00 2001 From: Arya Tabaie <15056835+Tabaie@users.noreply.github.com> Date: Thu, 6 Mar 2025 12:59:13 -0600 Subject: [PATCH 07/62] fix: GKR testing --- go.mod | 2 +- go.sum | 2 -- .../resources/single_input_two_outs.json | 2 +- std/gkr/test_vectors/resources/single_mul_gate.json | 2 +- std/gkr/testing.go | 12 +++++++++--- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 763403f54f..0c70d6e553 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.29 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.16.1-0.20250305220457-01ad8f324f1c + github.com/consensys/gnark-crypto v0.16.1-0.20250306184109-d5afd4fa04d3 github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index cb2bafcefa..3493afadc1 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,6 @@ github.com/consensys/bavard v0.1.29 h1:fobxIYksIQ+ZSrTJUuQgu+HIJwclrAPcdXqd7H2hh github.com/consensys/bavard v0.1.29/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.16.1-0.20250305220457-01ad8f324f1c h1:C6dAj70uBKJQe/x1b6c6InJYKcgR2SM4mvbXa3ZdLkI= -github.com/consensys/gnark-crypto v0.16.1-0.20250305220457-01ad8f324f1c/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/std/gkr/test_vectors/resources/single_input_two_outs.json b/std/gkr/test_vectors/resources/single_input_two_outs.json index c577c1cace..3a39e5625f 100644 --- a/std/gkr/test_vectors/resources/single_input_two_outs.json +++ b/std/gkr/test_vectors/resources/single_input_two_outs.json @@ -4,7 +4,7 @@ "inputs": [] }, { - "gate": "mul", + "gate": "mul2", "inputs": [0, 0] }, { diff --git a/std/gkr/test_vectors/resources/single_mul_gate.json b/std/gkr/test_vectors/resources/single_mul_gate.json index 0f65a07edf..d009ebe03d 100644 --- a/std/gkr/test_vectors/resources/single_mul_gate.json +++ b/std/gkr/test_vectors/resources/single_mul_gate.json @@ -8,7 +8,7 @@ "inputs": [] }, { - "gate": "mul", + "gate": "mul2", "inputs": [0, 1] } ] \ No newline at end of file diff --git a/std/gkr/testing.go b/std/gkr/testing.go index 357c01943c..7ee8c5760a 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -78,9 +78,7 @@ func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { return func(mod *big.Int, ins, outs []*big.Int) error { const dummyGateName = "dummy-solve-in-test-engine-gate" - var degreeFr int - nbInFr := -1 - solvableVarFr := -1 + var degreeFr, nbInFr, solvableVarFr int if len(outs) != 1 { return errors.New("gate must have one output") } @@ -150,6 +148,14 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { return fmt.Errorf("gate \"%s\" not found", gateName) } degreeFr = gate.Degree() + nbInFr = gate.NbIn() + solvableVarFr = gate.SolvableVar() + if _, ok := degreeTestedGates.Load(gateName); !ok { + // re-register the gate to make sure the degree is correct + if err := gkrBw6761.RegisterGate(dummyGateName, gate.Evaluate, nbInFr, gkrBw6761.WithDegree(degreeFr)); err != nil { + return err + } + } x := make([]frBw6761.Element, len(ins)) for i := range ins { x[i].SetBigInt(ins[i]) From 7f13b98b75f5076ce5f9567dc0b7637951d43127 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 6 Mar 2025 13:12:29 -0600 Subject: [PATCH 08/62] build: gnark-crypto checksum --- go.sum | 2 ++ 1 file changed, 2 insertions(+) diff --git a/go.sum b/go.sum index 3493afadc1..c133dfac89 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,8 @@ github.com/consensys/bavard v0.1.29 h1:fobxIYksIQ+ZSrTJUuQgu+HIJwclrAPcdXqd7H2hh github.com/consensys/bavard v0.1.29/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= +github.com/consensys/gnark-crypto v0.16.1-0.20250306184109-d5afd4fa04d3 h1:MfPrVLq28Vvg0Xp9g8UMzuEsRxTsXssotJ3w0uFPh3A= +github.com/consensys/gnark-crypto v0.16.1-0.20250306184109-d5afd4fa04d3/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= From 71b310175f00da62462217a8ca81f309839a8388 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 26 Mar 2025 19:20:52 -0500 Subject: [PATCH 09/62] build: update gnark-crypto dep --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 0c70d6e553..e3aa584208 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,9 @@ toolchain go1.22.6 require ( github.com/bits-and-blooms/bitset v1.20.0 github.com/blang/semver/v4 v4.0.0 - github.com/consensys/bavard v0.1.29 + github.com/consensys/bavard v0.1.31-0.20250314194434-b30d4344e6d4 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.16.1-0.20250306184109-d5afd4fa04d3 + github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1 github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index c133dfac89..7e436d0c47 100644 --- a/go.sum +++ b/go.sum @@ -57,12 +57,12 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/consensys/bavard v0.1.29 h1:fobxIYksIQ+ZSrTJUuQgu+HIJwclrAPcdXqd7H2hh1k= -github.com/consensys/bavard v0.1.29/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= +github.com/consensys/bavard v0.1.31-0.20250314194434-b30d4344e6d4 h1:0J+ppRic2ZXsQE+Y+Lr9miam+RQVcWqwqe3SeiggR6s= +github.com/consensys/bavard v0.1.31-0.20250314194434-b30d4344e6d4/go.mod h1:k/zVjHHC4B+PQy1Pg7fgvG3ALicQw540Crag8qx+dZs= github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19Ntk= github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= -github.com/consensys/gnark-crypto v0.16.1-0.20250306184109-d5afd4fa04d3 h1:MfPrVLq28Vvg0Xp9g8UMzuEsRxTsXssotJ3w0uFPh3A= -github.com/consensys/gnark-crypto v0.16.1-0.20250306184109-d5afd4fa04d3/go.mod h1:A2URlMHUT81ifJ0UlLzSlm7TmnE3t7VxEThApdMukJw= +github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1 h1:6cK71BoMAjWHNl+EpvBh2PDDa0PIeoz1KFJ/6R16DjQ= +github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1/go.mod h1:uV1HwfBwGRj50DGK3LbDLeCvq0RX/vFXST3CRSAu0Fs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= From 9d5b88ede1842066d018128950ea144d05773011 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 26 Mar 2025 19:25:28 -0500 Subject: [PATCH 10/62] build: go generate --- internal/tinyfield/element.go | 10 ++++++ internal/tinyfield/element_test.go | 50 +++++++++++++++--------------- internal/tinyfield/vector.go | 24 ++++++++++++++ internal/tinyfield/vector_test.go | 2 +- 4 files changed, 60 insertions(+), 26 deletions(-) diff --git a/internal/tinyfield/element.go b/internal/tinyfield/element.go index dcaa56b23f..cab95f0cbc 100644 --- a/internal/tinyfield/element.go +++ b/internal/tinyfield/element.go @@ -302,6 +302,16 @@ func (z *Element) SetRandom() (*Element, error) { } } +// MustSetRandom sets z to a uniform random value in [0, q). +// +// It panics if reading from crypto/rand.Reader errors. +func (z *Element) MustSetRandom() *Element { + if _, err := z.SetRandom(); err != nil { + panic(err) + } + return z +} + // smallerThanModulus returns true if z < q // This is not constant time func (z *Element) smallerThanModulus() bool { diff --git a/internal/tinyfield/element_test.go b/internal/tinyfield/element_test.go index 7383f4dac3..7d2286ca78 100644 --- a/internal/tinyfield/element_test.go +++ b/internal/tinyfield/element_test.go @@ -30,8 +30,8 @@ var benchResElement Element func BenchmarkElementSelect(b *testing.B) { var x, y Element - x.SetRandom() - y.SetRandom() + x.MustSetRandom() + y.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -41,17 +41,17 @@ func BenchmarkElementSelect(b *testing.B) { func BenchmarkElementSetRandom(b *testing.B) { var x Element - x.SetRandom() + x.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - _, _ = x.SetRandom() + x.MustSetRandom() } } func BenchmarkElementSetBytes(b *testing.B) { var x Element - x.SetRandom() + x.MustSetRandom() bb := x.Bytes() b.ResetTimer() @@ -63,21 +63,21 @@ func BenchmarkElementSetBytes(b *testing.B) { func BenchmarkElementMulByConstants(b *testing.B) { b.Run("mulBy3", func(b *testing.B) { - benchResElement.SetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { MulBy3(&benchResElement) } }) b.Run("mulBy5", func(b *testing.B) { - benchResElement.SetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { MulBy5(&benchResElement) } }) b.Run("mulBy13", func(b *testing.B) { - benchResElement.SetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { MulBy13(&benchResElement) @@ -87,8 +87,8 @@ func BenchmarkElementMulByConstants(b *testing.B) { func BenchmarkElementInverse(b *testing.B) { var x Element - x.SetRandom() - benchResElement.SetRandom() + x.MustSetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { @@ -99,8 +99,8 @@ func BenchmarkElementInverse(b *testing.B) { func BenchmarkElementButterfly(b *testing.B) { var x Element - x.SetRandom() - benchResElement.SetRandom() + x.MustSetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { Butterfly(&x, &benchResElement) @@ -109,8 +109,8 @@ func BenchmarkElementButterfly(b *testing.B) { func BenchmarkElementExp(b *testing.B) { var x Element - x.SetRandom() - benchResElement.SetRandom() + x.MustSetRandom() + benchResElement.MustSetRandom() b1, _ := rand.Int(rand.Reader, Modulus()) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -119,7 +119,7 @@ func BenchmarkElementExp(b *testing.B) { } func BenchmarkElementDouble(b *testing.B) { - benchResElement.SetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { benchResElement.Double(&benchResElement) @@ -128,8 +128,8 @@ func BenchmarkElementDouble(b *testing.B) { func BenchmarkElementAdd(b *testing.B) { var x Element - x.SetRandom() - benchResElement.SetRandom() + x.MustSetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { benchResElement.Add(&x, &benchResElement) @@ -138,8 +138,8 @@ func BenchmarkElementAdd(b *testing.B) { func BenchmarkElementSub(b *testing.B) { var x Element - x.SetRandom() - benchResElement.SetRandom() + x.MustSetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { benchResElement.Sub(&x, &benchResElement) @@ -147,7 +147,7 @@ func BenchmarkElementSub(b *testing.B) { } func BenchmarkElementNeg(b *testing.B) { - benchResElement.SetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { benchResElement.Neg(&benchResElement) @@ -156,8 +156,8 @@ func BenchmarkElementNeg(b *testing.B) { func BenchmarkElementDiv(b *testing.B) { var x Element - x.SetRandom() - benchResElement.SetRandom() + x.MustSetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { benchResElement.Div(&x, &benchResElement) @@ -165,7 +165,7 @@ func BenchmarkElementDiv(b *testing.B) { } func BenchmarkElementFromMont(b *testing.B) { - benchResElement.SetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { benchResElement.fromMont() @@ -173,7 +173,7 @@ func BenchmarkElementFromMont(b *testing.B) { } func BenchmarkElementSquare(b *testing.B) { - benchResElement.SetRandom() + benchResElement.MustSetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { benchResElement.Square(&benchResElement) @@ -248,7 +248,7 @@ func TestElementNegZero(t *testing.T) { var a, b Element b.SetZero() for a.IsZero() { - a.SetRandom() + a.MustSetRandom() } a.Neg(&b) if !a.IsZero() { diff --git a/internal/tinyfield/vector.go b/internal/tinyfield/vector.go index 0755dabf7e..dbdc94fb7f 100644 --- a/internal/tinyfield/vector.go +++ b/internal/tinyfield/vector.go @@ -185,6 +185,30 @@ func (vector Vector) Swap(i, j int) { vector[i], vector[j] = vector[j], vector[i] } +// SetRandom sets the elements in vector to independent uniform random values in [0, q). +// +// This might error only if reading from crypto/rand.Reader errors, +// in which case the values in vector are undefined. +func (vector Vector) SetRandom() error { + for i := range vector { + if _, err := vector[i].SetRandom(); err != nil { + return err + } + } + return nil +} + +// MustSetRandom sets the elements in vector to independent uniform random values in [0, q). +// +// It panics if reading from crypto/rand.Reader errors. +func (vector Vector) MustSetRandom() { + for i := range vector { + if _, err := vector[i].SetRandom(); err != nil { + panic(err) + } + } +} + func addVecGeneric(res, a, b Vector) { if len(a) != len(b) || len(a) != len(res) { panic("vector.Add: vectors don't have the same length") diff --git a/internal/tinyfield/vector_test.go b/internal/tinyfield/vector_test.go index 36ab4f9fb6..c923c0df53 100644 --- a/internal/tinyfield/vector_test.go +++ b/internal/tinyfield/vector_test.go @@ -234,7 +234,7 @@ func BenchmarkVectorOps(b *testing.B) { b1 := make(Vector, N) c1 := make(Vector, N) var mixer Element - mixer.SetRandom() + mixer.MustSetRandom() for i := 1; i < N; i++ { a1[i-1].SetUint64(uint64(i)). Mul(&a1[i-1], &mixer) From 218128acb6cd848a9f12e25a09610d6520b89ada Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 26 Mar 2025 19:49:51 -0500 Subject: [PATCH 11/62] refactor: to match GKR Gates@gnark-crypto --- constraint/bls12-377/gkr.go | 2 +- constraint/bls12-381/gkr.go | 2 +- constraint/bls24-315/gkr.go | 2 +- constraint/bls24-317/gkr.go | 2 +- constraint/bn254/gkr.go | 2 +- constraint/bw6-633/gkr.go | 2 +- constraint/bw6-761/gkr.go | 2 +- .../template/representations/gkr.go.tmpl | 2 +- std/gkr/gkr.go | 152 +----------------- std/gkr/registry.go | 119 ++++++++++++++ std/gkr/testing.go | 14 +- 11 files changed, 135 insertions(+), 166 deletions(-) create mode 100644 std/gkr/registry.go diff --git a/constraint/bls12-377/gkr.go b/constraint/bls12-377/gkr.go index 8c93902065..744f22525c 100644 --- a/constraint/bls12-377/gkr.go +++ b/constraint/bls12-377/gkr.go @@ -30,7 +30,7 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) for i := range noPtr { - if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(gkr.GateName(noPtr[i].Gate)); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bls12-381/gkr.go b/constraint/bls12-381/gkr.go index fa81371379..b3b22b9a95 100644 --- a/constraint/bls12-381/gkr.go +++ b/constraint/bls12-381/gkr.go @@ -30,7 +30,7 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) for i := range noPtr { - if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(gkr.GateName(noPtr[i].Gate)); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bls24-315/gkr.go b/constraint/bls24-315/gkr.go index 6a018868c1..ba328c8bb4 100644 --- a/constraint/bls24-315/gkr.go +++ b/constraint/bls24-315/gkr.go @@ -30,7 +30,7 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) for i := range noPtr { - if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(gkr.GateName(noPtr[i].Gate)); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bls24-317/gkr.go b/constraint/bls24-317/gkr.go index 346b397d48..be02e3455c 100644 --- a/constraint/bls24-317/gkr.go +++ b/constraint/bls24-317/gkr.go @@ -30,7 +30,7 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) for i := range noPtr { - if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(gkr.GateName(noPtr[i].Gate)); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bn254/gkr.go b/constraint/bn254/gkr.go index fcf064b696..21731b8ac9 100644 --- a/constraint/bn254/gkr.go +++ b/constraint/bn254/gkr.go @@ -30,7 +30,7 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) for i := range noPtr { - if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(gkr.GateName(noPtr[i].Gate)); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bw6-633/gkr.go b/constraint/bw6-633/gkr.go index 2e7d58eff8..125da817df 100644 --- a/constraint/bw6-633/gkr.go +++ b/constraint/bw6-633/gkr.go @@ -30,7 +30,7 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) for i := range noPtr { - if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(gkr.GateName(noPtr[i].Gate)); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/constraint/bw6-761/gkr.go b/constraint/bw6-761/gkr.go index 35dafd570f..f40856cc36 100644 --- a/constraint/bw6-761/gkr.go +++ b/constraint/bw6-761/gkr.go @@ -30,7 +30,7 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) for i := range noPtr { - if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(gkr.GateName(noPtr[i].Gate)); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/internal/generator/backend/template/representations/gkr.go.tmpl b/internal/generator/backend/template/representations/gkr.go.tmpl index 5d788bd570..89b4b8dbc5 100644 --- a/internal/generator/backend/template/representations/gkr.go.tmpl +++ b/internal/generator/backend/template/representations/gkr.go.tmpl @@ -23,7 +23,7 @@ type GkrSolvingData struct { func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { resCircuit := make(gkr.Circuit, len(noPtr)) for i := range noPtr { - if resCircuit[i].Gate = gkr.GetGate(noPtr[i].Gate); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { + if resCircuit[i].Gate = gkr.GetGate(gkr.GateName(noPtr[i].Gate)); resCircuit[i].Gate == nil && noPtr[i].Gate != "" { return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) } resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index 60efa50421..c6dd67515a 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -1,18 +1,13 @@ package gkr import ( - "crypto/rand" "errors" "fmt" - bn254Gkr "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" - "github.com/consensys/gnark/std/gkr/internal" - "strconv" - "sync" - "github.com/consensys/gnark/frontend" fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/polynomial" "github.com/consensys/gnark/std/sumcheck" + "strconv" ) // @tabaie TODO: Contains many things copy-pasted from gnark-crypto. Generify somehow? @@ -44,144 +39,6 @@ func (g *Gate) NbIn() int { return g.nbIn } -var ( - gates = make(map[string]*Gate) - gatesLock sync.Mutex -) - -/*type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -}*/ - -// here options are not defined as functions on settings to make translation to their field counterpart easier -// TODO @Tabaie once GKR is moved to gnark, use the same options/settings type for all curves, obviating this - -type registerGateOptionType byte - -const ( - registerGateOptionTypeWithSolvableVar registerGateOptionType = iota - registerGateOptionTypeWithUnverifiedSolvableVar - registerGateOptionTypeWithNoSolvableVar - registerGateOptionTypeWithUnverifiedDegree - registerGateOptionTypeWithDegree -) - -type registerGateOption struct { - tp registerGateOptionType - param int -} - -// WithSolvableVar gives the index of a variable of degree 1 in the gate's polynomial. RegisterGate will return an error if the given index is not correct. -func WithSolvableVar(linearVar int) *registerGateOption { - return ®isterGateOption{ - tp: registerGateOptionTypeWithSolvableVar, - param: linearVar, - } -} - -// WithUnverifiedSolvableVar sets the index of a variable of degree 1 in the gate's polynomial. RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(linearVar int) *registerGateOption { - return ®isterGateOption{ - tp: registerGateOptionTypeWithUnverifiedSolvableVar, - param: linearVar, - } -} - -// WithNoSolvableVar sets the gate as having no variable of degree 1. RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() *registerGateOption { - return ®isterGateOption{ - tp: registerGateOptionTypeWithNoSolvableVar, - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) *registerGateOption { - return ®isterGateOption{ - tp: registerGateOptionTypeWithUnverifiedDegree, - param: degree, - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) *registerGateOption { - return ®isterGateOption{ - tp: registerGateOptionTypeWithDegree, - param: degree, - } -} - -// RegisterGate creates a gate object and stores it in the gates registry -// name is a human-readable name for the gate -// f is the polynomial function defining the gate -// nbIn is the number of inputs to the gate -// NB! This package generally expects certain properties of the gate to be invariant across all curves. -// In particular the degree is computed and verified over BN254. If the leading coefficient is divided by -// the curve's order, the degree will be computed incorrectly. -func RegisterGate(name string, f GateFunction, nbIn int, options ...*registerGateOption) error { - frF := internal.ToBn254GateFunction(f) // delegate tests to bn254 - var nameRand [4]byte - if _, err := rand.Read(nameRand[:]); err != nil { - return err - } - frName := fmt.Sprintf("%s-test-%x", name, nameRand) - frOptions := make([]bn254Gkr.RegisterGateOption, 0, len(options)) - - // translate options - for _, opt := range options { - switch opt.tp { - case registerGateOptionTypeWithSolvableVar: - frOptions = append(frOptions, bn254Gkr.WithSolvableVar(opt.param)) - case registerGateOptionTypeWithUnverifiedSolvableVar: - frOptions = append(frOptions, bn254Gkr.WithUnverifiedSolvableVar(opt.param)) - case registerGateOptionTypeWithNoSolvableVar: - frOptions = append(frOptions, bn254Gkr.WithNoSolvableVar()) - case registerGateOptionTypeWithUnverifiedDegree: - frOptions = append(frOptions, bn254Gkr.WithUnverifiedDegree(opt.param)) - case registerGateOptionTypeWithDegree: - frOptions = append(frOptions, bn254Gkr.WithDegree(opt.param)) - default: - return fmt.Errorf("unknown option type %d", opt.tp) - } - } - - if err := bn254Gkr.RegisterGate(frName, frF, nbIn, frOptions...); err != nil { - return err - } - bn254Gate := bn254Gkr.GetGate(frName) - bn254Gkr.RemoveGate(frName) - - gatesLock.Lock() - defer gatesLock.Unlock() - - gates[name] = &Gate{ - Evaluate: f, - nbIn: nbIn, - degree: bn254Gate.Degree(), - solvableVar: bn254Gate.SolvableVar(), - } - - return nil -} - -func GetGate(name string) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -func RemoveGate(name string) bool { - gatesLock.Lock() - defer gatesLock.Unlock() - _, found := gates[name] - if found { - delete(gates, name) - } - return found -} - type Wire struct { Gate *Gate Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire @@ -385,13 +242,6 @@ func ProofSize(c Circuit, logNbInstances int) int { return nbUniqueInputs + nbPartialEvalPolys*logNbInstances } -func max(a, b int) int { - if a > b { - return a - } - return b -} - func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { // Pre-compute the size TODO: Consider not doing this and just grow the list by appending diff --git a/std/gkr/registry.go b/std/gkr/registry.go new file mode 100644 index 0000000000..db3fad3901 --- /dev/null +++ b/std/gkr/registry.go @@ -0,0 +1,119 @@ +package gkr + +import ( + "fmt" + "github.com/consensys/gnark/std/gkr/internal" + "sync" +) + +var ( + gates = make(map[string]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +// TODO @Tabaie once GKR is moved to gnark, use the same options/settings type for all curves, obviating this + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// RegisterGate creates a gate object and stores it in the gates registry +// name is a human-readable name for the gate +// f is the polynomial function defining the gate +// nbIn is the number of inputs to the gate +// NB! This package generally expects certain properties of the gate to be invariant across all curves. +// In particular the degree is computed and verified over BN254. If the leading coefficient is divided by +// the curve's order, the degree will be computed incorrectly. +func RegisterGate(name string, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + frF := internal.ToBn254GateFunction(f) + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = frF.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := frF.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = frF.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !frF.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name string) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} diff --git a/std/gkr/testing.go b/std/gkr/testing.go index 7ee8c5760a..ef6ebde0cf 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -83,7 +83,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { return errors.New("gate must have one output") } if ecc.BLS12_377.ScalarField().Cmp(mod) == 0 { - gate := gkrBls12377.GetGate(gateName) + gate := gkrBls12377.GetGate(gkrBls12377.GateName(gateName)) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } @@ -103,7 +103,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BN254.ScalarField().Cmp(mod) == 0 { - gate := gkrBn254.GetGate(gateName) + gate := gkrBn254.GetGate(gkrBn254.GateName(gateName)) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } @@ -123,7 +123,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BLS24_315.ScalarField().Cmp(mod) == 0 { - gate := gkrBls24315.GetGate(gateName) + gate := gkrBls24315.GetGate(gkrBls24315.GateName(gateName)) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } @@ -143,7 +143,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BW6_761.ScalarField().Cmp(mod) == 0 { - gate := gkrBw6761.GetGate(gateName) + gate := gkrBw6761.GetGate(gkrBw6761.GateName(gateName)) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } @@ -163,7 +163,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BLS12_381.ScalarField().Cmp(mod) == 0 { - gate := gkrBls12381.GetGate(gateName) + gate := gkrBls12381.GetGate(gkrBls12381.GateName(gateName)) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } @@ -183,7 +183,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BLS24_317.ScalarField().Cmp(mod) == 0 { - gate := gkrBls24317.GetGate(gateName) + gate := gkrBls24317.GetGate(gkrBls24317.GateName(gateName)) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } @@ -203,7 +203,7 @@ func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { y := gate.Evaluate(x...) y.BigInt(outs[0]) } else if ecc.BW6_633.ScalarField().Cmp(mod) == 0 { - gate := gkrBw6633.GetGate(gateName) + gate := gkrBw6633.GetGate(gkrBw6633.GateName(gateName)) if gate == nil { return fmt.Errorf("gate \"%s\" not found", gateName) } From 948f95c0192ec94ba9a3a20838aee4a1c056cd6d Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 26 Mar 2025 21:17:45 -0500 Subject: [PATCH 12/62] chore: gkr.GateName --- std/gkr/api.go | 6 +++--- std/gkr/compile.go | 2 +- std/gkr/gkr.go | 12 ----------- std/gkr/gkr_test.go | 2 +- std/gkr/registry.go | 35 +++++++++++++++++++++++++++++--- std/gkr/testing.go | 6 +++--- std/hash/poseidon2/poseidon2.go | 2 +- std/permutation/poseidon2/gkr.go | 17 ++++++---------- 8 files changed, 47 insertions(+), 35 deletions(-) diff --git a/std/gkr/api.go b/std/gkr/api.go index 2751f31d4c..7fe1cf073d 100644 --- a/std/gkr/api.go +++ b/std/gkr/api.go @@ -9,16 +9,16 @@ func frontendVarToInt(a constraint.GkrVariable) int { return int(a) } -func (api *API) NamedGate(gate string, in ...constraint.GkrVariable) constraint.GkrVariable { +func (api *API) NamedGate(gate GateName, in ...constraint.GkrVariable) constraint.GkrVariable { api.toStore.Circuit = append(api.toStore.Circuit, constraint.GkrWire{ - Gate: gate, + Gate: string(gate), Inputs: utils.Map(in, frontendVarToInt), }) api.assignments = append(api.assignments, nil) return constraint.GkrVariable(len(api.toStore.Circuit) - 1) } -func (api *API) namedGate2PlusIn(gate string, in1, in2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { +func (api *API) namedGate2PlusIn(gate GateName, in1, in2 constraint.GkrVariable, in ...constraint.GkrVariable) constraint.GkrVariable { inCombined := make([]constraint.GkrVariable, 2+len(in)) inCombined[0] = in1 inCombined[1] = in2 diff --git a/std/gkr/compile.go b/std/gkr/compile.go index b077063368..3d683fc9e3 100644 --- a/std/gkr/compile.go +++ b/std/gkr/compile.go @@ -223,7 +223,7 @@ func newCircuitDataForSnark(info constraint.GkrInfo, assignment assignment) circ for i := range circuit { w := info.Circuit[i] circuit[i] = Wire{ - Gate: GetGate(ite(w.IsInput(), w.Gate, "identity")), + Gate: GetGate(ite(w.IsInput(), GateName(w.Gate), Identity)), Inputs: utils.Map(w.Inputs, circuitAt), nbUniqueOutputs: w.NbUniqueOutputs, } diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index c6dd67515a..49be099ca1 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -534,18 +534,6 @@ func DeserializeProof(sorted []*Wire, serializedProof []frontend.Variable) (Proo return proof, nil } -func init() { - panicIfError(RegisterGate("mul2", func(api frontend.API, x ...frontend.Variable) frontend.Variable { - return api.Mul(x[0], x[1]) - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar())) - panicIfError(RegisterGate("add2", func(api frontend.API, x ...frontend.Variable) frontend.Variable { - return api.Add(x[0], x[1]) - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0))) - panicIfError(RegisterGate("identity", func(api frontend.API, x ...frontend.Variable) frontend.Variable { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0))) -} - func panicIfError(err error) { if err != nil { panic(err) diff --git a/std/gkr/gkr_test.go b/std/gkr/gkr_test.go index 324160a985..6a9eb8b4f8 100644 --- a/std/gkr/gkr_test.go +++ b/std/gkr/gkr_test.go @@ -249,7 +249,7 @@ func (c CircuitInfo) toCircuit() (circuit Circuit, err error) { circuit[i].Inputs[iAsInput] = input } - if circuit[i].Gate = GetGate(wireInfo.Gate); circuit[i].Gate == nil && wireInfo.Gate != "" { + if circuit[i].Gate = GetGate(GateName(wireInfo.Gate)); circuit[i].Gate == nil && wireInfo.Gate != "" { err = fmt.Errorf("undefined gate \"%s\"", wireInfo.Gate) } } diff --git a/std/gkr/registry.go b/std/gkr/registry.go index db3fad3901..40649cbf11 100644 --- a/std/gkr/registry.go +++ b/std/gkr/registry.go @@ -2,12 +2,15 @@ package gkr import ( "fmt" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/gkr/internal" "sync" ) +type GateName string + var ( - gates = make(map[string]*Gate) + gates = make(map[GateName]*Gate) gatesLock sync.Mutex ) @@ -70,7 +73,7 @@ func WithDegree(degree int) RegisterGateOption { // NB! This package generally expects certain properties of the gate to be invariant across all curves. // In particular the degree is computed and verified over BN254. If the leading coefficient is divided by // the curve's order, the degree will be computed incorrectly. -func RegisterGate(name string, f GateFunction, nbIn int, options ...RegisterGateOption) error { +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { s := registerGateSettings{degree: -1, solvableVar: -1} for _, option := range options { option(&s) @@ -112,8 +115,34 @@ func RegisterGate(name string, f GateFunction, nbIn int, options ...RegisterGate return nil } -func GetGate(name string) *Gate { +func GetGate(name GateName) *Gate { gatesLock.Lock() defer gatesLock.Unlock() return gates[name] } + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + panicIfError(RegisterGate(Mul2, func(api frontend.API, x ...frontend.Variable) frontend.Variable { + return api.Mul(x[0], x[1]) + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar())) + panicIfError(RegisterGate(Add2, func(api frontend.API, x ...frontend.Variable) frontend.Variable { + return api.Add(x[0], x[1]) + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0))) + panicIfError(RegisterGate(Identity, func(api frontend.API, x ...frontend.Variable) frontend.Variable { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0))) + panicIfError(RegisterGate(Neg, func(api frontend.API, x ...frontend.Variable) frontend.Variable { + return api.Neg(x[0]) + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0))) + panicIfError(RegisterGate(Sub2, func(api frontend.API, x ...frontend.Variable) frontend.Variable { + return api.Sub(x[0], x[1]) + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0))) +} diff --git a/std/gkr/testing.go b/std/gkr/testing.go index ef6ebde0cf..1c43a65e40 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -63,11 +63,11 @@ func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable for i, in := range w.Inputs { ins[i] = res[in][instanceI] } - expectedV, err := parentApi.Compiler().NewHint(frGateHint(w.Gate, °reeTestedGates), 1, ins...) + expectedV, err := parentApi.Compiler().NewHint(frGateHint(GateName(w.Gate), °reeTestedGates), 1, ins...) if err != nil { panic(err) } - res[wireI][instanceI] = GetGate(w.Gate).Evaluate(parentApi, ins...) + res[wireI][instanceI] = GetGate(GateName(w.Gate)).Evaluate(parentApi, ins...) parentApi.AssertIsEqual(expectedV[0], res[wireI][instanceI]) // snark and raw gate evaluations must agree } } @@ -75,7 +75,7 @@ func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable return res } -func frGateHint(gateName string, degreeTestedGates *sync.Map) hint.Hint { +func frGateHint(gateName GateName, degreeTestedGates *sync.Map) hint.Hint { return func(mod *big.Int, ins, outs []*big.Int) error { const dummyGateName = "dummy-solve-in-test-engine-gate" var degreeFr, nbInFr, solvableVarFr int diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go index d0427d2ca0..dbe32a85d8 100644 --- a/std/hash/poseidon2/poseidon2.go +++ b/std/hash/poseidon2/poseidon2.go @@ -5,7 +5,7 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/hash" - poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2" + "github.com/consensys/gnark/std/permutation/poseidon2" ) // NewMerkleDamgardHasher returns a Poseidon2 hasher using the Merkle-Damgard diff --git a/std/permutation/poseidon2/gkr.go b/std/permutation/poseidon2/gkr.go index 63ab381878..b217d39084 100644 --- a/std/permutation/poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr.go @@ -156,7 +156,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // poseidon2 parameters roundKeysFr := poseidon2Bls12377.GetDefaultParameters().RoundKeys - params := poseidon2Bls12377.GetDefaultParameters().String() + gateNamer := gkrPoseidon2Bls12377.RoundGateNamer(poseidon2Bls12377.GetDefaultParameters()) rF := poseidon2Bls12377.GetDefaultParameters().NbFullRounds rP := poseidon2Bls12377.GetDefaultParameters().NbPartialRounds halfRf := rF / 2 @@ -173,11 +173,6 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. return nil, -1, err } - // unique names for linear rounds - gateNameSolvable := func(varI, round int) string { - return fmt.Sprintf("x%d-l-op-round=%d;%s", varI, round, params) - } - // the s-Box gates: u¹⁷ = (u⁴)⁴ * u if err = gkr.RegisterGate("pow4", pow4Gate, 1, gkr.WithUnverifiedDegree(4), gkr.WithNoSolvableVar()); err != nil { return nil, -1, err @@ -203,7 +198,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // register and apply external matrix multiplication and round key addition // round dependent due to the round key extKeySBox := func(round, varI int, a, b constraint.GkrVariable) constraint.GkrVariable { - gate := gateNameSolvable(varI, round) + gate := gkr.GateName(gateNamer.Linear(varI, round)) if err = gkr.RegisterGate(gate, extKeyGate(frToInt(&roundKeysFr[round][varI])), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return -1 } @@ -215,7 +210,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // for the second variable // round independent due to the round key intKeySBox2 := func(round int, a, b constraint.GkrVariable) constraint.GkrVariable { - gate := gateNameSolvable(yI, round) + gate := gkr.GateName(gateNamer.Linear(yI, round)) if err = gkr.RegisterGate(gate, intKeyGate2(frToInt(&roundKeysFr[round][1])), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return -1 } @@ -239,7 +234,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // still using the external matrix, since the linear operation still belongs to a full (canonical) round x1 := extKeySBox(halfRf, xI, x, y) - gate := gateNameSolvable(yI, halfRf) + gate := gkr.GateName(gateNamer.Linear(yI, halfRf)) if err = gkr.RegisterGate(gate, extGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return nil, -1, err } @@ -250,7 +245,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. for i := halfRf + 1; i < halfRf+rP; i++ { x1 := extKeySBox(i, xI, x, y) // the first row of the internal matrix is the same as that of the external matrix - gate := gateNameSolvable(yI, i) + gate := gkr.GateName(gateNamer.Linear(yI, i)) if err = gkr.RegisterGate(gate, intKeyGate2(zero), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return nil, -1, err } @@ -270,7 +265,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. } // apply the external matrix one last time to obtain the final value of y - gate := gateNameSolvable(yI, rP+rF) + gate := gkr.GateName(gateNamer.Linear(yI, rP+rF)) if err = gkr.RegisterGate(gate, extAddGate, 3, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return nil, -1, err } From 25e482da04fdab2c0bfb7140a41fbeca41a519c2 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 26 Mar 2025 21:32:34 -0500 Subject: [PATCH 13/62] chore: address copilot PR feedback --- std/gkr/api_test.go | 27 ++------------------------- std/gkr/internal/bn254_wrapper_api.go | 12 +++++++++--- 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go index 8371acaa2c..40fa252291 100644 --- a/std/gkr/api_test.go +++ b/std/gkr/api_test.go @@ -433,6 +433,7 @@ func init() { } func registerMiMCGate() { + // register mimc gate panicIfError(RegisterGate("mimc", func(api frontend.API, input ...frontend.Variable) frontend.Variable { mimcSnarkTotalCalls++ @@ -445,6 +446,7 @@ func registerMiMCGate() { return api.Mul(sumCubed, sumCubed, sum) }, 2, WithDegree(7))) + // register fr version of mimc gate panicIfError(gkr.RegisterGate("mimc", func(input ...fr.Element) (res fr.Element) { var sum fr.Element @@ -473,31 +475,6 @@ func (c constPseudoHash) Reset() {} var mimcFrTotalCalls = 0 -// Copied from gnark-crypto TODO: Make public? -type mimcCipherGate struct { - ark fr.Element -} - -func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { - var sum fr.Element - - sum. - Add(&input[0], &input[1]). - Add(&sum, &m.ark) - - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - mimcFrTotalCalls++ - return -} - -func (m mimcCipherGate) Degree() int { - return 7 -} - type mimcNoGkrCircuit struct { X []frontend.Variable Y []frontend.Variable diff --git a/std/gkr/internal/bn254_wrapper_api.go b/std/gkr/internal/bn254_wrapper_api.go index 270c26201a..3857311ed2 100644 --- a/std/gkr/internal/bn254_wrapper_api.go +++ b/std/gkr/internal/bn254_wrapper_api.go @@ -154,10 +154,16 @@ func (w *bn254WrapperApi) AssertIsLessOrEqual(frontend.Variable, frontend.Variab func (w *bn254WrapperApi) Println(a ...frontend.Variable) { toPrint := make([]any, len(a)) for i, v := range a { - if x := w.cast(v); w.err == nil { - toPrint[i] = x[i] + var x fr.Element + if _, err := x.SetInterface(v); err != nil { + if s, ok := v.(string); ok { + toPrint[i] = s + continue + } else { + w.newError("not numeric or string") + } } else { - return + toPrint[i] = x.String() } } fmt.Println(toPrint...) From a12e75964d42663e59e1552368bc5d4454fba845 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 26 Mar 2025 21:37:56 -0500 Subject: [PATCH 14/62] nitpick: "identity" -> Identity --- std/gkr/gkr.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index 49be099ca1..b68c1c3e7c 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -367,7 +367,7 @@ func outputsList(c Circuit, indexes map[*Wire]int) [][]int { res[i] = make([]int, 0) c[i].nbUniqueOutputs = 0 if c[i].IsInput() { - c[i].Gate = GetGate("identity") + c[i].Gate = GetGate(Identity) } } ins := make(map[int]struct{}, len(c)) From cbaf5fd8488a0af7dde3403c23fbf743d9d1b1be Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 23 Mar 2025 12:54:25 -0500 Subject: [PATCH 15/62] checkpoint: up to line 418 --- std/gkr/api_test.go | 6 +- std/gkr/compile.go | 10 +- std/gkr/example_test.go | 235 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 245 insertions(+), 6 deletions(-) create mode 100644 std/gkr/example_test.go diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go index 40fa252291..a69ef6239b 100644 --- a/std/gkr/api_test.go +++ b/std/gkr/api_test.go @@ -2,6 +2,7 @@ package gkr import ( "fmt" + gcHash "github.com/consensys/gnark-crypto/hash" "hash" "math/rand" "strconv" @@ -22,7 +23,6 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" - bn254MiMC "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" @@ -386,9 +386,7 @@ func (c *benchMiMCMerkleTreeCircuit) Define(api frontend.API) error { } func registerMiMC() { - bn254.RegisterHashBuilder("mimc", func() hash.Hash { - return bn254MiMC.NewMiMC() - }) + bn254.RegisterHashBuilder("mimc", gcHash.MIMC_BN254.New) stdHash.Register("mimc", func(api frontend.API) (stdHash.FieldHasher, error) { m, err := mimc.NewMiMC(api) return &m, err diff --git a/std/gkr/compile.go b/std/gkr/compile.go index 3d683fc9e3..ddb0a0764b 100644 --- a/std/gkr/compile.go +++ b/std/gkr/compile.go @@ -2,6 +2,7 @@ package gkr import ( "errors" + "fmt" "math/bits" "github.com/consensys/gnark/constraint" @@ -108,9 +109,14 @@ func (api *API) Solve(parentApi frontend.API) (Solution, error) { for i := range circuit { v := &circuit[i] - if v.IsInput() { + in, out := v.IsInput(), v.IsOutput() + if in && out { + return Solution{}, fmt.Errorf("unused input (variable #%d)", i) + } + + if in { solveHintNIn += nbInstances - len(v.Dependencies) - } else if v.IsOutput() { + } else if out { solveHintNOut += nbInstances } } diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go new file mode 100644 index 0000000000..abef121bdc --- /dev/null +++ b/std/gkr/example_test.go @@ -0,0 +1,235 @@ +package gkr_test + +import ( + "encoding/binary" + "errors" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + gkrBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" + gcHash "github.com/consensys/gnark-crypto/hash" + bw6761 "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/gkr" + stdHash "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/consensys/gnark/test" +) + +func Example() { + // This example computes the double of multiple BLS12-377 G1 points, which can be computed natively over BW6-761. + // This means that the imported fr and fp packages are the same, being from BW6-761 and BLS12-377 respectively. TODO @Tabaie delete if no longer have fp imported + // It is based on the function DoubleAssign() of type G1Jac in gnark-crypto v0.17.0. + // github.com/consensys/gnark-crypto/ecc/bls12-377 + const gateNamePrefix = "bls12-377-jac-double-" + + // Every gate needs to be defined over a concrete field, used by the GKR prover, + // and over a frontend.API, used by the in-SNARK GKR verifier. + // + // Note that the SNARK prover will need both of these: + // The GKR prover will provide a proof as private input to the SNARK prover, + // wherein the embedded GKR verifier will verify it, establishing the correctness + // of our claimed values. + + // This function will contain the concrete implementations of the gates. + // The SNARK implementations will be defined in the Define() method of the circuit. + + // combine the operations that define the first value assigned to variable S + // input = [X, YY, XX, YYYY] + // S = 2 * [(X + YY)² - XX - YYYY] + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"s1", func(input ...fr.Element) (S fr.Element) { + S. + Add(&input[0], &input[1]). // 409: S.Add(&p.X, &YY) + Square(&S). // 410: S.Square(&S). + Sub(&S, &input[2]). // 411: Sub(&S, &XX). + Sub(&S, &input[3]). // 412: Sub(&S, &YYYY). + Double(&S) // 413: Double(&S) + + return + }, 4)) + + // combine the operations that define the first change to p.Z + // input = [p.Z, p.Y, YY, ZZ] + // Z = (p.Z + p.Y)² - YY - ZZ + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"z1", func(input ...fr.Element) (Z fr.Element) { + Z.Add(&input[0], &input[1]) // 415: p.Z.Add(&p.Z, &p.Y). + Z.Square(&Z) // 416: p.Z.Square(&p.Z). + Z.Sub(&Z, &input[2]) // 417: Sub(&p.Z, &YY). + Z.Sub(&Z, &input[3]) // 418: Sub(&p.Z, &ZZ) + return + }, 4)) + + // we have a lot of squaring operations, which we'd rather look at as single-input + assertNoError(gkrBw6761.RegisterGate("square", func(input ...fr.Element) (res fr.Element) { + res.Square(&input[0]) + return + }, 1)) + + const nbInstances = 2 + // create instances + assignment := exampleCircuit{ + X: make([]frontend.Variable, nbInstances), + Y: make([]frontend.Variable, nbInstances), + Z: make([]frontend.Variable, nbInstances), + XOut: make([]frontend.Variable, nbInstances), + YOut: make([]frontend.Variable, nbInstances), + ZOut: make([]frontend.Variable, nbInstances), + } + + for i := range nbInstances { + // create a "random" point + var b [8]byte + binary.BigEndian.PutUint64(b[:], uint64(i)) + a, err := bls12377.HashToG1(b[:], nil) + assertNoError(err) + var p bls12377.G1Jac + p.FromAffine(&a) + + assignment.X[i] = p.X + assignment.Y[i] = p.Y + assignment.Z[i] = p.Z + + p.DoubleAssign() + assignment.XOut[i] = p.X + assignment.YOut[i] = p.Y + assignment.ZOut[i] = p.Z + + // TODO delete this + { + p.X = assignment.X[i].(fp.Element) + p.Y = assignment.Y[i].(fp.Element) + p.Z = assignment.Z[i].(fp.Element) + + var XX, YY, YYYY, ZZ, S, M, T fp.Element + + _, _ = M, T + + XX.Square(&p.X) + YY.Square(&p.Y) + YYYY.Square(&YY) + ZZ.Square(&p.Z) + S.Add(&p.X, &YY). + Square(&S). + Sub(&S, &XX). + Sub(&S, &YYYY). + Double(&S) + + assignment.XOut[i] = S + } + } + + circuit := exampleCircuit{ + X: make([]frontend.Variable, nbInstances), + Y: make([]frontend.Variable, nbInstances), + Z: make([]frontend.Variable, nbInstances), + XOut: make([]frontend.Variable, nbInstances), + YOut: make([]frontend.Variable, nbInstances), + ZOut: make([]frontend.Variable, nbInstances), + gateNamePrefix: gateNamePrefix, + } + + // register the hash function used for verifying the GKR proof (prover side) + bw6761.RegisterHashBuilder("mimc", gcHash.MIMC_BW6_761.New) + + assertNoError(test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) + + // Output: +} + +type exampleCircuit struct { + X, Y, Z []frontend.Variable // Jacobian coordinates for each point (input) + XOut, YOut, ZOut []frontend.Variable // Jacobian coordinates for the double of each point (expected output) + gateNamePrefix string +} + +func (c *exampleCircuit) Define(api frontend.API) error { + if len(c.X) != len(c.Y) || len(c.X) != len(c.Z) || len(c.X) != len(c.XOut) || len(c.X) != len(c.YOut) || len(c.X) != len(c.ZOut) { + return errors.New("all inputs/outputs must have the same length (i.e. the number of instances)") + } + + gkrApi := gkr.NewApi() + + assertNoError(gkr.RegisterGate("square", func(api frontend.API, input ...frontend.Variable) (res frontend.Variable) { + return api.Mul(input[0], input[0]) + }, 1)) + + // define the GKR circuit + + // create GKR circuit variables based on the given assignments + X, err := gkrApi.Import(c.X) + if err != nil { + return err + } + + Y, err := gkrApi.Import(c.Y) + if err != nil { + return err + } + + Z, err := gkrApi.Import(c.Z) + if err != nil { + return err + } + + XX := gkrApi.NamedGate("square", X) // 405: XX.Square(&p.X) TODO See if anything changes (perf-wise) if we use gkrApi.Mul(X, X) instead + YY := gkrApi.NamedGate("square", Y) // 406: YY.Square(&p.Y) + YYYY := gkrApi.NamedGate("square", YY) // 407: YYYY.Square(&YY) + ZZ := gkrApi.NamedGate("square", Z) // 408: ZZ.Square(&p.Z) + + // define the SNARK version of the custom gates, similarly to the ones in Example + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"s1", func(api frontend.API, input ...frontend.Variable) (S frontend.Variable) { + S = api.Add(input[0], input[1]) // 409: S.Add(&p.X, &YY) + S = api.Mul(S, S) // 410: S.Square(&S). + S = api.Sub(S, input[2], input[3]) // 411: Sub(&S, &XX). + // 412: Sub(&S, &YYYY). + return api.Add(S, S) // 413: Double(&S) + }, 4)) + S := gkrApi.NamedGate(c.gateNamePrefix+"s1", X, YY, XX, YYYY) // 409 - 413 + // 414: M.Double(&XX).Add(&M, &XX) + // Note that (but don't explicitly compute) that M = 3XX + + // p.Z.Add(&p.Z, &p.Y). + // Square(&p.Z). + // Sub(&p.Z, &YY). + // Sub(&p.Z, &ZZ) + + // combine the operations that define the first change to p.Z + // input = [p.Z, p.Y, YY, ZZ] + // Z = (p.Z + p.Y)² - YY - ZZ + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"z1", func(api frontend.API, input ...frontend.Variable) (Z frontend.Variable) { + Z = api.Add(input[0], input[1]) // 415: p.Z.Add(&p.Z, &p.Y). + Z = api.Mul(Z, Z) // 416: p.Z.Square(&p.Z). + Z = api.Sub(Z, input[2], input[3]) // 417: Sub(&p.Z, &YY). + // 418: Sub(&p.Z, &ZZ). + return + }, 4)) + Z = gkrApi.NamedGate(c.gateNamePrefix+"z1", Z, Y, YY, ZZ) // 415 - 418 + + // solve and prove the circuit + solution, err := gkrApi.Solve(api) + if err != nil { + return err + } + + // check the output + XOut := solution.Export(S) // TODO do this with actual output values + for i := range XOut { + api.AssertIsEqual(XOut[i], c.XOut[i]) + } + + // register the hash function used for verification (fiat shamir) + stdHash.Register("mimc", func(api frontend.API) (stdHash.FieldHasher, error) { + m, err := mimc.NewMiMC(api) + return &m, err + }) + + // verify the proof + return solution.Verify("mimc") +} + +func assertNoError(err error) { + if err != nil { + panic(err) + } +} From 509942be200e093b2cf9bc93acd4a4e379710be2 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 23 Mar 2025 13:14:59 -0500 Subject: [PATCH 16/62] checkpoint: line 422 --- std/gkr/example_test.go | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index abef121bdc..e29ec58a40 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -60,6 +60,19 @@ func Example() { return }, 4)) + // combine the operations that define the first change to p.X + // input = [XX, S] + // p.X = 9XX² - 2S + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"x1", func(input ...fr.Element) (X fr.Element) { + var M, T fr.Element + M.Double(&input[0]).Add(&M, &input[0]) // 414: M.Double(&XX).Add(&M, &XX) + T.Square(&M) // 419: T.Square(&M) + X = T // 420: p.X = T + T.Double(&input[1]) // 421: T.Double(&S) + X.Sub(&X, &T) // 422: p.X.Sub(&p.X, &T) + return + }, 2)) + // we have a lot of squaring operations, which we'd rather look at as single-input assertNoError(gkrBw6761.RegisterGate("square", func(input ...fr.Element) (res fr.Element) { res.Square(&input[0]) @@ -187,12 +200,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { }, 4)) S := gkrApi.NamedGate(c.gateNamePrefix+"s1", X, YY, XX, YYYY) // 409 - 413 // 414: M.Double(&XX).Add(&M, &XX) - // Note that (but don't explicitly compute) that M = 3XX - - // p.Z.Add(&p.Z, &p.Y). - // Square(&p.Z). - // Sub(&p.Z, &YY). - // Sub(&p.Z, &ZZ) + // Note (but don't explicitly compute) that M = 3XX // combine the operations that define the first change to p.Z // input = [p.Z, p.Y, YY, ZZ] @@ -206,6 +214,19 @@ func (c *exampleCircuit) Define(api frontend.API) error { }, 4)) Z = gkrApi.NamedGate(c.gateNamePrefix+"z1", Z, Y, YY, ZZ) // 415 - 418 + // combine the operations that define the first change to p.X + // input = [XX, S] + // p.X = 9XX² - 2S + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"x1", func(api frontend.API, input ...frontend.Variable) (X frontend.Variable) { + M := api.Mul(input[0], 3) // 414: M.Double(&XX).Add(&M, &XX) + T := api.Mul(M, M) // 419: T.Square(&M) + X = api.Sub(T, api.Mul(input[1], 2)) // 420: p.X = T + // 421: T.Double(&S) + // 422: p.X.Sub(&p.X, &T) + return + }, 2)) + X = gkrApi.NamedGate(c.gateNamePrefix+"x1", XX, S) // 419-422 + // solve and prove the circuit solution, err := gkrApi.Solve(api) if err != nil { From 7d4e0633dc98afa6d9304dca36300112ab963bfa Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 23 Mar 2025 13:46:17 -0500 Subject: [PATCH 17/62] docs: complete test for DoubleAssign with GKR --- std/gkr/example_test.go | 56 ++++++++++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index e29ec58a40..24019bf946 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -49,10 +49,10 @@ func Example() { return }, 4)) - // combine the operations that define the first change to p.Z + // combine the operations that define the assignment to p.Z // input = [p.Z, p.Y, YY, ZZ] // Z = (p.Z + p.Y)² - YY - ZZ - assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"z1", func(input ...fr.Element) (Z fr.Element) { + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"z", func(input ...fr.Element) (Z fr.Element) { Z.Add(&input[0], &input[1]) // 415: p.Z.Add(&p.Z, &p.Y). Z.Square(&Z) // 416: p.Z.Square(&p.Z). Z.Sub(&Z, &input[2]) // 417: Sub(&p.Z, &YY). @@ -60,10 +60,10 @@ func Example() { return }, 4)) - // combine the operations that define the first change to p.X + // combine the operations that define the assignment to p.X // input = [XX, S] // p.X = 9XX² - 2S - assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"x1", func(input ...fr.Element) (X fr.Element) { + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"x", func(input ...fr.Element) (X fr.Element) { var M, T fr.Element M.Double(&input[0]).Add(&M, &input[0]) // 414: M.Double(&XX).Add(&M, &XX) T.Square(&M) // 419: T.Square(&M) @@ -73,6 +73,20 @@ func Example() { return }, 2)) + // combine the operations that define the assignment to p.Y + // input = [S, p.X, XX, YYYY] + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"y", func(input ...fr.Element) (Y fr.Element) { + Y.Double(&input[2]).Add(&Y, &input[2]) // 414: M.Double(&XX).Add(&M, &XX) + input[2] = Y + + Y.Sub(&input[0], &input[1]). // 423: p.Y.Sub(&S, &p.X). + Mul(&Y, &input[2]) // 424: p.Y.Mul(&p.Y, &M). + input[3].Double(&input[3]).Double(&input[3]).Double(&input[3]) // 425: M.Double(&YYYY).Double(&M).Double(&M) + Y.Sub(&Y, &input[3]) // 426: p.Y.Sub(&p.Y, &YYYY) + + return + }, 4)) + // we have a lot of squaring operations, which we'd rather look at as single-input assertNoError(gkrBw6761.RegisterGate("square", func(input ...fr.Element) (res fr.Element) { res.Square(&input[0]) @@ -191,41 +205,55 @@ func (c *exampleCircuit) Define(api frontend.API) error { ZZ := gkrApi.NamedGate("square", Z) // 408: ZZ.Square(&p.Z) // define the SNARK version of the custom gates, similarly to the ones in Example - assertNoError(gkr.RegisterGate(c.gateNamePrefix+"s1", func(api frontend.API, input ...frontend.Variable) (S frontend.Variable) { + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"s", func(api frontend.API, input ...frontend.Variable) (S frontend.Variable) { S = api.Add(input[0], input[1]) // 409: S.Add(&p.X, &YY) S = api.Mul(S, S) // 410: S.Square(&S). S = api.Sub(S, input[2], input[3]) // 411: Sub(&S, &XX). // 412: Sub(&S, &YYYY). return api.Add(S, S) // 413: Double(&S) }, 4)) - S := gkrApi.NamedGate(c.gateNamePrefix+"s1", X, YY, XX, YYYY) // 409 - 413 + S := gkrApi.NamedGate(c.gateNamePrefix+"s", X, YY, XX, YYYY) // 409 - 413 // 414: M.Double(&XX).Add(&M, &XX) // Note (but don't explicitly compute) that M = 3XX - // combine the operations that define the first change to p.Z + // combine the operations that define the assignment to p.Z // input = [p.Z, p.Y, YY, ZZ] // Z = (p.Z + p.Y)² - YY - ZZ - assertNoError(gkr.RegisterGate(c.gateNamePrefix+"z1", func(api frontend.API, input ...frontend.Variable) (Z frontend.Variable) { + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"z", func(api frontend.API, input ...frontend.Variable) (Z frontend.Variable) { Z = api.Add(input[0], input[1]) // 415: p.Z.Add(&p.Z, &p.Y). Z = api.Mul(Z, Z) // 416: p.Z.Square(&p.Z). Z = api.Sub(Z, input[2], input[3]) // 417: Sub(&p.Z, &YY). // 418: Sub(&p.Z, &ZZ). return }, 4)) - Z = gkrApi.NamedGate(c.gateNamePrefix+"z1", Z, Y, YY, ZZ) // 415 - 418 + Z = gkrApi.NamedGate(c.gateNamePrefix+"z", Z, Y, YY, ZZ) // 415 - 418 - // combine the operations that define the first change to p.X + // combine the operations that define the assignment to p.X // input = [XX, S] // p.X = 9XX² - 2S - assertNoError(gkr.RegisterGate(c.gateNamePrefix+"x1", func(api frontend.API, input ...frontend.Variable) (X frontend.Variable) { + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"x", func(api frontend.API, input ...frontend.Variable) (X frontend.Variable) { M := api.Mul(input[0], 3) // 414: M.Double(&XX).Add(&M, &XX) - T := api.Mul(M, M) // 419: T.Square(&M) + T := api.Mul(M, M) // 419: T.Square(&M) X = api.Sub(T, api.Mul(input[1], 2)) // 420: p.X = T // 421: T.Double(&S) // 422: p.X.Sub(&p.X, &T) return }, 2)) - X = gkrApi.NamedGate(c.gateNamePrefix+"x1", XX, S) // 419-422 + X = gkrApi.NamedGate(c.gateNamePrefix+"x", XX, S) // 419-422 + + // combine the operations that define the assignment to p.Y + // input = [S, p.X, XX, YYYY] + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"y", func(api frontend.API, input ...frontend.Variable) (Y frontend.Variable) { + input[2] = api.Mul(3, input[2]) // 414: M.Double(&XX).Add(&M, &XX) + + Y = api.Sub(input[0], input[1]) // 423: p.Y.Sub(&S, &p.X). + Y = api.Mul(Y, input[2]) // 424: Mul(&p.Y, &M) + // 425: M.Double(&YYYY).Double(&M).Double(&M) + Y = api.Sub(Y, api.Mul(input[3], 8)) // 426: p.Y.Sub(&p.Y, &YYYY) + + return + }, 4)) + Y = gkrApi.NamedGate(c.gateNamePrefix+"y", S, X, XX, YYYY) // 423 - 426 // solve and prove the circuit solution, err := gkrApi.Solve(api) @@ -234,7 +262,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { } // check the output - XOut := solution.Export(S) // TODO do this with actual output values + XOut := solution.Export(X) // TODO do this with actual output values for i := range XOut { api.AssertIsEqual(XOut[i], c.XOut[i]) } From e030c5077dd844c56a1f44bf8a422092f9b0e7da Mon Sep 17 00:00:00 2001 From: Tabaie Date: Sun, 23 Mar 2025 13:48:01 -0500 Subject: [PATCH 18/62] fix gate name --- std/gkr/example_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index 24019bf946..ddca409226 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -38,7 +38,7 @@ func Example() { // combine the operations that define the first value assigned to variable S // input = [X, YY, XX, YYYY] // S = 2 * [(X + YY)² - XX - YYYY] - assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"s1", func(input ...fr.Element) (S fr.Element) { + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"s", func(input ...fr.Element) (S fr.Element) { S. Add(&input[0], &input[1]). // 409: S.Add(&p.X, &YY) Square(&S). // 410: S.Square(&S). From da5c15f437dc799874dddb3337c090903edaf507 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 24 Mar 2025 10:29:45 -0500 Subject: [PATCH 19/62] test: works under test engine, proof fails --- std/gkr/compile.go | 4 ++-- std/gkr/example_test.go | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/std/gkr/compile.go b/std/gkr/compile.go index ddb0a0764b..4a86ac3784 100644 --- a/std/gkr/compile.go +++ b/std/gkr/compile.go @@ -158,8 +158,8 @@ func (api *API) Solve(parentApi frontend.API) (Solution, error) { } // Export returns the values of an output variable across all instances -func (s Solution) Export(v frontend.Variable) []frontend.Variable { - return utils.Map(s.permutations.SortedInstances, utils.SliceAt(s.assignments[v.(constraint.GkrVariable)])) +func (s Solution) Export(v constraint.GkrVariable) []frontend.Variable { + return utils.Map(s.permutations.SortedInstances, utils.SliceAt(s.assignments[v])) } // Verify encodes the verification circuitry for the GKR circuit diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index ddca409226..08fa6a49bc 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -102,6 +102,7 @@ func Example() { XOut: make([]frontend.Variable, nbInstances), YOut: make([]frontend.Variable, nbInstances), ZOut: make([]frontend.Variable, nbInstances), + SOut: make([]frontend.Variable, nbInstances), } for i := range nbInstances { @@ -124,6 +125,7 @@ func Example() { // TODO delete this { + p.X = assignment.X[i].(fp.Element) p.Y = assignment.Y[i].(fp.Element) p.Z = assignment.Z[i].(fp.Element) @@ -142,7 +144,7 @@ func Example() { Sub(&S, &YYYY). Double(&S) - assignment.XOut[i] = S + assignment.SOut[i] = S } } @@ -153,6 +155,7 @@ func Example() { XOut: make([]frontend.Variable, nbInstances), YOut: make([]frontend.Variable, nbInstances), ZOut: make([]frontend.Variable, nbInstances), + SOut: make([]frontend.Variable, nbInstances), gateNamePrefix: gateNamePrefix, } @@ -167,6 +170,7 @@ func Example() { type exampleCircuit struct { X, Y, Z []frontend.Variable // Jacobian coordinates for each point (input) XOut, YOut, ZOut []frontend.Variable // Jacobian coordinates for the double of each point (expected output) + SOut []frontend.Variable // temporary gateNamePrefix string } @@ -213,6 +217,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { return api.Add(S, S) // 413: Double(&S) }, 4)) S := gkrApi.NamedGate(c.gateNamePrefix+"s", X, YY, XX, YYYY) // 409 - 413 + scp := gkrApi.NamedGate("identity", S) // 414: M.Double(&XX).Add(&M, &XX) // Note (but don't explicitly compute) that M = 3XX @@ -255,6 +260,18 @@ func (c *exampleCircuit) Define(api frontend.API) error { }, 4)) Y = gkrApi.NamedGate(c.gateNamePrefix+"y", S, X, XX, YYYY) // 423 - 426 + // have to duplicate X for it to be considered an output variable + // TODO remove once https://github.com/Consensys/gnark/issues/1452 is addressed + X = gkrApi.NamedGate("identity", X) + + res := gkrApi.SolveInTestEngine(api) + for i := range c.XOut { + api.AssertIsEqual(res[scp][i], c.SOut[i]) + api.AssertIsEqual(res[Z][i], c.ZOut[i]) + api.AssertIsEqual(res[X][i], c.XOut[i]) + api.AssertIsEqual(res[Y][i], c.YOut[i]) + } + // solve and prove the circuit solution, err := gkrApi.Solve(api) if err != nil { @@ -262,9 +279,21 @@ func (c *exampleCircuit) Define(api frontend.API) error { } // check the output + // TODO merge loops + SOut := solution.Export(scp) + for i := range SOut { + api.AssertIsEqual(SOut[i], c.SOut[i]) + } + + ZOut := solution.Export(Z) + for i := range ZOut { + api.AssertIsEqual(ZOut[i], c.ZOut[i]) + } + XOut := solution.Export(X) // TODO do this with actual output values for i := range XOut { - api.AssertIsEqual(XOut[i], c.XOut[i]) + _ = i + //api.AssertIsEqual(XOut[i], c.XOut[i]) } // register the hash function used for verification (fiat shamir) From f9f576be197e575e703c81bca15bb0d2d666aadc Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 24 Mar 2025 18:56:04 -0500 Subject: [PATCH 20/62] test with const hash --- std/gkr/example_test.go | 81 +++++++++++++++++++++++++++++++++++++---- std/gkr/hints.go | 50 +++++++++++++++++++++++++ std/gkr/testing.go | 42 ++++++++++++++++++++- 3 files changed, 164 insertions(+), 9 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index 08fa6a49bc..b19bb9a110 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -3,18 +3,19 @@ package gkr_test import ( "encoding/binary" "errors" + "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" gkrBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" - gcHash "github.com/consensys/gnark-crypto/hash" bw6761 "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/gkr" stdHash "github.com/consensys/gnark/std/hash" - "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" + "hash" + "math/big" ) func Example() { @@ -22,7 +23,10 @@ func Example() { // This means that the imported fr and fp packages are the same, being from BW6-761 and BLS12-377 respectively. TODO @Tabaie delete if no longer have fp imported // It is based on the function DoubleAssign() of type G1Jac in gnark-crypto v0.17.0. // github.com/consensys/gnark-crypto/ecc/bls12-377 - const gateNamePrefix = "bls12-377-jac-double-" + const ( + gateNamePrefix = "bls12-377-jac-double-" + fsHashName = "const" + ) // Every gate needs to be defined over a concrete field, used by the GKR prover, // and over a frontend.API, used by the in-SNARK GKR verifier. @@ -157,10 +161,12 @@ func Example() { ZOut: make([]frontend.Variable, nbInstances), SOut: make([]frontend.Variable, nbInstances), gateNamePrefix: gateNamePrefix, + fsHashName: fsHashName, } // register the hash function used for verifying the GKR proof (prover side) - bw6761.RegisterHashBuilder("mimc", gcHash.MIMC_BW6_761.New) + //bw6761.RegisterHashBuilder("mimc", gcHash.MIMC_BW6_761.New) + bw6761.RegisterHashBuilder(fsHashName, func() hash.Hash { return constHasherBw6761{} }) assertNoError(test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) @@ -172,6 +178,7 @@ type exampleCircuit struct { XOut, YOut, ZOut []frontend.Variable // Jacobian coordinates for the double of each point (expected output) SOut []frontend.Variable // temporary gateNamePrefix string + fsHashName string // name of the hash function used for Fiat-Shamir in the GKR verifier } func (c *exampleCircuit) Define(api frontend.API) error { @@ -292,18 +299,26 @@ func (c *exampleCircuit) Define(api frontend.API) error { XOut := solution.Export(X) // TODO do this with actual output values for i := range XOut { - _ = i - //api.AssertIsEqual(XOut[i], c.XOut[i]) + api.AssertIsEqual(XOut[i], c.XOut[i]) + } + + YOut := solution.Export(Y) + for i := range YOut { + api.AssertIsEqual(YOut[i], c.YOut[i]) } // register the hash function used for verification (fiat shamir) - stdHash.Register("mimc", func(api frontend.API) (stdHash.FieldHasher, error) { + /*stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { m, err := mimc.NewMiMC(api) return &m, err + })*/ + + stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { + return &constHasherSnark{api: api}, nil }) // verify the proof - return solution.Verify("mimc") + return solution.Verify(c.fsHashName) } func assertNoError(err error) { @@ -311,3 +326,53 @@ func assertNoError(err error) { panic(err) } } + +type constHasherBw6761 struct{} + +func (constHasherBw6761) Write(p []byte) (int, error) { + for i := 0; i < len(p); i += fr.Bytes { + var I big.Int + I.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + fmt.Print(I.Text(10), " ") + } + return len(p), nil +} + +func (constHasherBw6761) Sum(p []byte) []byte { + if p != nil { + panic("unexpected input") + } + fmt.Println() + var b [fr.Bytes]byte + b[len(b)-1] = 1 + return b[:] +} + +func (constHasherBw6761) Reset() { +} + +func (constHasherBw6761) Size() int { + return fr.Bytes +} + +func (constHasherBw6761) BlockSize() int { + return fr.Bytes +} + +type constHasherSnark struct { + api frontend.API + v []frontend.Variable +} + +func (h *constHasherSnark) Sum() frontend.Variable { + h.api.Println(h.v...) + return 1 +} + +func (h *constHasherSnark) Write(v ...frontend.Variable) { + h.v = append(h.v, v...) +} + +func (h *constHasherSnark) Reset() { + h.v = h.v[:0] +} diff --git a/std/gkr/hints.go b/std/gkr/hints.go index 8cfdaa6421..027c86b9c2 100644 --- a/std/gkr/hints.go +++ b/std/gkr/hints.go @@ -1,7 +1,9 @@ package gkr import ( + "bytes" "errors" + "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint" bls12377 "github.com/consensys/gnark/constraint/bls12-377" @@ -12,6 +14,7 @@ import ( bw6633 "github.com/consensys/gnark/constraint/bw6-633" bw6761 "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark/constraint/solver" + "hash" "math/big" ) @@ -100,3 +103,50 @@ func ProveHintPlaceholder(hashName string) solver.Hint { return errors.New("unsupported modulus") } } + +func CheckHashHint(hashName string) solver.Hint { + return func(mod *big.Int, ins, outs []*big.Int) error { + if len(ins) != 2 || len(outs) != 1 { + return errors.New("invalid number of inputs/outputs") + } + + var ( + builder func() hash.Hash + err error + ) + if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { + builder, err = bls12377.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + builder, err = bls12381.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { + builder, err = bls24315.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { + builder, err = bls24317.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BN254.ScalarField()) == 0 { + builder, err = bn254.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { + builder, err = bw6633.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { + builder, err = bw6761.GetHashBuilder(hashName) + } else { + return errors.New("unsupported modulus") + } + + if err != nil { + return err + } + + toHash := ins[0].Bytes() + expectedHash := ins[1].Bytes() + + hsh := builder() + hsh.Write(toHash) + hashed := hsh.Sum(nil) + + if !bytes.Equal(hashed, expectedHash) { + return fmt.Errorf("hash mismatch: expected %x, got %x", expectedHash, hashed) + } + + return nil + } +} diff --git a/std/gkr/testing.go b/std/gkr/testing.go index 1c43a65e40..aefb3f92ac 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -3,6 +3,7 @@ package gkr import ( "errors" "fmt" + stdHash "github.com/consensys/gnark/std/hash" "math/big" "sync" @@ -25,10 +26,49 @@ import ( "github.com/consensys/gnark/frontend" ) +type solveInTestEngineSettings struct { + hashName string +} + +type SolveInTestEngineOption func(*solveInTestEngineSettings) + +func WithHashName(name string) SolveInTestEngineOption { + return func(s *solveInTestEngineSettings) { + s.hashName = name + } +} + // SolveInTestEngine solves the defined circuit directly inside the SNARK circuit. This means that the method does not compute the GKR proof of the circuit and does not embed the GKR proof verifier inside a SNARK. // The output is the values of all variables, across all instances; i.e. indexed variable-first, instance-second. // This method only works under the test engine and should only be called to debug a GKR circuit, as the GKR prover's errors can be obscure. -func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable { +func (api *API) SolveInTestEngine(parentApi frontend.API, options ...SolveInTestEngineOption) [][]frontend.Variable { + var s solveInTestEngineSettings + for _, o := range options { + o(&s) + } + if s.hashName != "" { + // hash something and make sure it gives the same answer both on prover and verifier sides + // TODO @Tabaie If indeed cheap, move this feature to Verify so that it is always run + h, err := stdHash.GetFieldHasher(s.hashName, parentApi) + if err != nil { + panic(err) + } + nbBytes := (parentApi.Compiler().FieldBitLen() + 7) / 8 + toHash := frontend.Variable(0) + for i := range nbBytes { + toHash = parentApi.Add(toHash, toHash, i%256) + } + h.Reset() + h.Write(toHash) + hashed := h.Sum() + + hintOut, err := parentApi.Compiler().NewHint(CheckHashHint(s.hashName), 1, toHash, hashed) + if err != nil { + panic(err) + } + parentApi.AssertIsEqual(hintOut[0], hashed) // the hint already checks this + } + res := make([][]frontend.Variable, len(api.toStore.Circuit)) var degreeTestedGates sync.Map for i, w := range api.toStore.Circuit { From b07cba30bdd1cc46e9054b67a65e5309cba33794 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 24 Mar 2025 19:02:00 -0500 Subject: [PATCH 21/62] revert: use mimc hash --- std/gkr/example_test.go | 54 ++++++++++------------------------------- std/gkr/hints.go | 2 ++ 2 files changed, 15 insertions(+), 41 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index b19bb9a110..aec1c1dbde 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -9,12 +9,13 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" gkrBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" + gcHash "github.com/consensys/gnark-crypto/hash" bw6761 "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/gkr" stdHash "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" - "hash" "math/big" ) @@ -165,8 +166,7 @@ func Example() { } // register the hash function used for verifying the GKR proof (prover side) - //bw6761.RegisterHashBuilder("mimc", gcHash.MIMC_BW6_761.New) - bw6761.RegisterHashBuilder(fsHashName, func() hash.Hash { return constHasherBw6761{} }) + bw6761.RegisterHashBuilder(fsHashName, gcHash.MIMC_BW6_761.New) assertNoError(test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) @@ -271,7 +271,13 @@ func (c *exampleCircuit) Define(api frontend.API) error { // TODO remove once https://github.com/Consensys/gnark/issues/1452 is addressed X = gkrApi.NamedGate("identity", X) - res := gkrApi.SolveInTestEngine(api) + // register the hash function used for verification (fiat shamir) + stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { + m, err := mimc.NewMiMC(api) + return &m, err + }) + + res := gkrApi.SolveInTestEngine(api, gkr.WithHashName(c.fsHashName)) for i := range c.XOut { api.AssertIsEqual(res[scp][i], c.SOut[i]) api.AssertIsEqual(res[Z][i], c.ZOut[i]) @@ -292,31 +298,15 @@ func (c *exampleCircuit) Define(api frontend.API) error { api.AssertIsEqual(SOut[i], c.SOut[i]) } + XOut := solution.Export(X) + YOut := solution.Export(Y) ZOut := solution.Export(Z) - for i := range ZOut { - api.AssertIsEqual(ZOut[i], c.ZOut[i]) - } - - XOut := solution.Export(X) // TODO do this with actual output values for i := range XOut { api.AssertIsEqual(XOut[i], c.XOut[i]) - } - - YOut := solution.Export(Y) - for i := range YOut { api.AssertIsEqual(YOut[i], c.YOut[i]) + api.AssertIsEqual(ZOut[i], c.ZOut[i]) } - // register the hash function used for verification (fiat shamir) - /*stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { - m, err := mimc.NewMiMC(api) - return &m, err - })*/ - - stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { - return &constHasherSnark{api: api}, nil - }) - // verify the proof return solution.Verify(c.fsHashName) } @@ -358,21 +348,3 @@ func (constHasherBw6761) Size() int { func (constHasherBw6761) BlockSize() int { return fr.Bytes } - -type constHasherSnark struct { - api frontend.API - v []frontend.Variable -} - -func (h *constHasherSnark) Sum() frontend.Variable { - h.api.Println(h.v...) - return 1 -} - -func (h *constHasherSnark) Write(v ...frontend.Variable) { - h.v = append(h.v, v...) -} - -func (h *constHasherSnark) Reset() { - h.v = h.v[:0] -} diff --git a/std/gkr/hints.go b/std/gkr/hints.go index 027c86b9c2..5b6ec9c756 100644 --- a/std/gkr/hints.go +++ b/std/gkr/hints.go @@ -147,6 +147,8 @@ func CheckHashHint(hashName string) solver.Hint { return fmt.Errorf("hash mismatch: expected %x, got %x", expectedHash, hashed) } + outs[0].SetBytes(hashed) + return nil } } From 5e222e4ccbacca931408659140babffa2bdb3f55 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 24 Mar 2025 19:38:26 -0500 Subject: [PATCH 22/62] replicating error with const hash --- std/gkr/example_test.go | 95 ++++++++++++++++++++++++++++++++++++++--- std/gkr/gkr.go | 1 + std/gkr/hints.go | 7 ++- std/gkr/testing.go | 2 +- 4 files changed, 93 insertions(+), 12 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index aec1c1dbde..69afd2a648 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -9,17 +9,17 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" gkrBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" - gcHash "github.com/consensys/gnark-crypto/hash" bw6761 "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/gkr" stdHash "github.com/consensys/gnark/std/hash" - "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" + "hash" "math/big" + "testing" ) -func Example() { +func TestExample(*testing.T) { // This example computes the double of multiple BLS12-377 G1 points, which can be computed natively over BW6-761. // This means that the imported fr and fp packages are the same, being from BW6-761 and BLS12-377 respectively. TODO @Tabaie delete if no longer have fp imported // It is based on the function DoubleAssign() of type G1Jac in gnark-crypto v0.17.0. @@ -166,7 +166,8 @@ func Example() { } // register the hash function used for verifying the GKR proof (prover side) - bw6761.RegisterHashBuilder(fsHashName, gcHash.MIMC_BW6_761.New) + //bw6761.RegisterHashBuilder(fsHashName, func() hash.Hash { return hashReporter{gcHash.MIMC_BW6_761.New()} }) + bw6761.RegisterHashBuilder("const", func() hash.Hash { return constHasherBw6761{} }) assertNoError(test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) @@ -272,9 +273,12 @@ func (c *exampleCircuit) Define(api frontend.API) error { X = gkrApi.NamedGate("identity", X) // register the hash function used for verification (fiat shamir) - stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { + /*stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { m, err := mimc.NewMiMC(api) - return &m, err + return &hashReporterSnark{h: &m, api: api}, err + })*/ + stdHash.Register("const", func(api frontend.API) (stdHash.FieldHasher, error) { + return &constHasherSnark{api: api}, nil }) res := gkrApi.SolveInTestEngine(api, gkr.WithHashName(c.fsHashName)) @@ -317,6 +321,65 @@ func assertNoError(err error) { } } +type hashReporter struct { + h hash.Hash +} + +func (h hashReporter) Write(p []byte) (n int, err error) { + for i := 0; i < len(p); i += fr.Bytes { + var I big.Int + I.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + fmt.Print(I.Text(10), " ") + } + return h.h.Write(p) +} + +func (h hashReporter) Sum(b []byte) []byte { + if b != nil { + panic("unexpected input") + } + b = h.h.Sum(b) + fmt.Println("\n<-", new(big.Int).SetBytes(b).Text(10)) + return b +} + +func (h hashReporter) Reset() { + h.h.Reset() +} + +func (h hashReporter) Size() int { + return h.h.Size() +} + +func (h hashReporter) BlockSize() int { + return h.h.BlockSize() +} + +type hashReporterSnark struct { + h stdHash.FieldHasher + api frontend.API + v []frontend.Variable +} + +func (h *hashReporterSnark) Sum() frontend.Variable { + h.api.Println(h.v...) + res := h.h.Sum() + h.api.Println("<-", res) + return res +} + +func (h *hashReporterSnark) Write(v ...frontend.Variable) { + h.v = append(h.v, v...) + h.h.Write(v...) +} + +func (h *hashReporterSnark) Reset() { + h.v = h.v[:0] + h.h.Reset() +} + +const constHash byte = 3 + type constHasherBw6761 struct{} func (constHasherBw6761) Write(p []byte) (int, error) { @@ -334,7 +397,7 @@ func (constHasherBw6761) Sum(p []byte) []byte { } fmt.Println() var b [fr.Bytes]byte - b[len(b)-1] = 1 + b[len(b)-1] = constHash return b[:] } @@ -348,3 +411,21 @@ func (constHasherBw6761) Size() int { func (constHasherBw6761) BlockSize() int { return fr.Bytes } + +type constHasherSnark struct { + api frontend.API + v []frontend.Variable +} + +func (h *constHasherSnark) Sum() frontend.Variable { + h.api.Println(h.v...) + return constHash +} + +func (h *constHasherSnark) Write(v ...frontend.Variable) { + h.v = append(h.v, v...) +} + +func (h *constHasherSnark) Reset() { + h.v = h.v[:0] +} diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index b68c1c3e7c..9f0151a1ae 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -328,6 +328,7 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, wirePrefix := o.transcriptPrefix + "w" var baseChallenge []frontend.Variable for i := len(c) - 1; i >= 0; i-- { + api.Println("verifying wire", i) wire := o.sorted[i] if wire.IsOutput() { diff --git a/std/gkr/hints.go b/std/gkr/hints.go index 5b6ec9c756..63b57b3402 100644 --- a/std/gkr/hints.go +++ b/std/gkr/hints.go @@ -1,7 +1,6 @@ package gkr import ( - "bytes" "errors" "fmt" "github.com/consensys/gnark-crypto/ecc" @@ -137,14 +136,14 @@ func CheckHashHint(hashName string) solver.Hint { } toHash := ins[0].Bytes() - expectedHash := ins[1].Bytes() + expectedHash := ins[1] hsh := builder() hsh.Write(toHash) hashed := hsh.Sum(nil) - if !bytes.Equal(hashed, expectedHash) { - return fmt.Errorf("hash mismatch: expected %x, got %x", expectedHash, hashed) + if hashed := new(big.Int).SetBytes(hashed); hashed.Cmp(expectedHash) != 0 { + return fmt.Errorf("hash mismatch: expected %s, got %s", expectedHash.String(), hashed.String()) } outs[0].SetBytes(hashed) diff --git a/std/gkr/testing.go b/std/gkr/testing.go index aefb3f92ac..74a8fc1f5c 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -56,7 +56,7 @@ func (api *API) SolveInTestEngine(parentApi frontend.API, options ...SolveInTest nbBytes := (parentApi.Compiler().FieldBitLen() + 7) / 8 toHash := frontend.Variable(0) for i := range nbBytes { - toHash = parentApi.Add(toHash, toHash, i%256) + toHash = parentApi.Add(parentApi.Mul(toHash, 256), i%256) } h.Reset() h.Write(toHash) From f5e741076fa5ac93e1d380c79b6c12678a29d56c Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 27 Mar 2025 09:51:53 -0500 Subject: [PATCH 23/62] checkpoint --- std/gkr/example_test.go | 34 ++++++++++++++++++++------- std/gkr/gkr.go | 3 ++- std/gkr/internal/bn254_wrapper_api.go | 7 +++--- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index 69afd2a648..962905341e 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -81,16 +81,19 @@ func TestExample(*testing.T) { // combine the operations that define the assignment to p.Y // input = [S, p.X, XX, YYYY] assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"y", func(input ...fr.Element) (Y fr.Element) { + fmt.Println("in", input[0].String(), input[1].String(), input[2].String(), input[3].String()) Y.Double(&input[2]).Add(&Y, &input[2]) // 414: M.Double(&XX).Add(&M, &XX) input[2] = Y Y.Sub(&input[0], &input[1]). // 423: p.Y.Sub(&S, &p.X). - Mul(&Y, &input[2]) // 424: p.Y.Mul(&p.Y, &M). - input[3].Double(&input[3]).Double(&input[3]).Double(&input[3]) // 425: M.Double(&YYYY).Double(&M).Double(&M) + Mul(&Y, &input[2]) // 424: Mul(&p.Y, &M). + input[3].Double(&input[3]).Double(&input[3]).Double(&input[3]) // 425: YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) Y.Sub(&Y, &input[3]) // 426: p.Y.Sub(&p.Y, &YYYY) + fmt.Println("out", Y.String()) return }, 4)) + fmt.Println("y gate registered") // we have a lot of squaring operations, which we'd rather look at as single-input assertNoError(gkrBw6761.RegisterGate("square", func(input ...fr.Element) (res fr.Element) { @@ -256,16 +259,22 @@ func (c *exampleCircuit) Define(api frontend.API) error { // combine the operations that define the assignment to p.Y // input = [S, p.X, XX, YYYY] + // p.Y = assertNoError(gkr.RegisterGate(c.gateNamePrefix+"y", func(api frontend.API, input ...frontend.Variable) (Y frontend.Variable) { - input[2] = api.Mul(3, input[2]) // 414: M.Double(&XX).Add(&M, &XX) + + api.Println("SNARK in", input[0], input[1], input[2], input[3]) Y = api.Sub(input[0], input[1]) // 423: p.Y.Sub(&S, &p.X). - Y = api.Mul(Y, input[2]) // 424: Mul(&p.Y, &M) - // 425: M.Double(&YYYY).Double(&M).Double(&M) - Y = api.Sub(Y, api.Mul(input[3], 8)) // 426: p.Y.Sub(&p.Y, &YYYY) + Y = api.Mul(Y, input[2], 3) // 414: M.Double(&XX).Add(&M, &XX) + // 424:Mul(&p.Y, &M) + Y = api.Sub(Y, api.Mul(input[3], 8)) // 425: YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) + // 426: p.Y.Sub(&p.Y, &YYYY) + + api.Println("SNARK out", Y) return }, 4)) + fmt.Println("y gate registered") Y = gkrApi.NamedGate(c.gateNamePrefix+"y", S, X, XX, YYYY) // 423 - 426 // have to duplicate X for it to be considered an output variable @@ -379,6 +388,7 @@ func (h *hashReporterSnark) Reset() { } const constHash byte = 3 +const printHashes = false type constHasherBw6761 struct{} @@ -386,7 +396,9 @@ func (constHasherBw6761) Write(p []byte) (int, error) { for i := 0; i < len(p); i += fr.Bytes { var I big.Int I.SetBytes(p[i:min(len(p), i+fr.Bytes)]) - fmt.Print(I.Text(10), " ") + if printHashes { + fmt.Print(I.Text(10), " ") + } } return len(p), nil } @@ -395,7 +407,9 @@ func (constHasherBw6761) Sum(p []byte) []byte { if p != nil { panic("unexpected input") } - fmt.Println() + if printHashes { + fmt.Println() + } var b [fr.Bytes]byte b[len(b)-1] = constHash return b[:] @@ -418,7 +432,9 @@ type constHasherSnark struct { } func (h *constHasherSnark) Sum() frontend.Variable { - h.api.Println(h.v...) + if printHashes { + h.api.Println(h.v...) + } return constHash } diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index 9f0151a1ae..2f35dbf708 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -3,11 +3,12 @@ package gkr import ( "errors" "fmt" + "strconv" + "github.com/consensys/gnark/frontend" fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/polynomial" "github.com/consensys/gnark/std/sumcheck" - "strconv" ) // @tabaie TODO: Contains many things copy-pasted from gnark-crypto. Generify somehow? diff --git a/std/gkr/internal/bn254_wrapper_api.go b/std/gkr/internal/bn254_wrapper_api.go index 3857311ed2..cb12a81b86 100644 --- a/std/gkr/internal/bn254_wrapper_api.go +++ b/std/gkr/internal/bn254_wrapper_api.go @@ -21,13 +21,12 @@ func ToBn254GateFunction(f func(frontend.API, ...frontend.Variable) frontend.Var var wrapper bn254WrapperApi return func(x ...fr.Element) fr.Element { - if wrapper.err != nil { - return fr.Element{} - } res := f(&wrapper, utils.Map(x, func(x fr.Element) frontend.Variable { return &x })...).(*fr.Element) - + if wrapper.err != nil { + panic(wrapper.err) + } return *res } } From a2fc86f974332ff9a2c911aa41c378b67c4d22b5 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 27 Mar 2025 09:56:27 -0500 Subject: [PATCH 24/62] chore: example works --- std/gkr/example_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index 962905341e..f366e7857d 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -181,7 +181,7 @@ type exampleCircuit struct { X, Y, Z []frontend.Variable // Jacobian coordinates for each point (input) XOut, YOut, ZOut []frontend.Variable // Jacobian coordinates for the double of each point (expected output) SOut []frontend.Variable // temporary - gateNamePrefix string + gateNamePrefix gkr.GateName fsHashName string // name of the hash function used for Fiat-Shamir in the GKR verifier } From f9df20e5cebef22e6e3f9fa475f43842b3698528 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 27 Mar 2025 10:04:40 -0500 Subject: [PATCH 25/62] chore: remove prints --- std/gkr/example_test.go | 108 +++------------------------------------- std/gkr/gkr.go | 1 - 2 files changed, 7 insertions(+), 102 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index f366e7857d..8a3ce361db 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -9,24 +9,25 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" gkrBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" + gcHash "github.com/consensys/gnark-crypto/hash" bw6761 "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/gkr" stdHash "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" "hash" "math/big" - "testing" ) -func TestExample(*testing.T) { +func Example() { // This example computes the double of multiple BLS12-377 G1 points, which can be computed natively over BW6-761. // This means that the imported fr and fp packages are the same, being from BW6-761 and BLS12-377 respectively. TODO @Tabaie delete if no longer have fp imported // It is based on the function DoubleAssign() of type G1Jac in gnark-crypto v0.17.0. // github.com/consensys/gnark-crypto/ecc/bls12-377 const ( gateNamePrefix = "bls12-377-jac-double-" - fsHashName = "const" + fsHashName = "mimc" ) // Every gate needs to be defined over a concrete field, used by the GKR prover, @@ -81,7 +82,6 @@ func TestExample(*testing.T) { // combine the operations that define the assignment to p.Y // input = [S, p.X, XX, YYYY] assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"y", func(input ...fr.Element) (Y fr.Element) { - fmt.Println("in", input[0].String(), input[1].String(), input[2].String(), input[3].String()) Y.Double(&input[2]).Add(&Y, &input[2]) // 414: M.Double(&XX).Add(&M, &XX) input[2] = Y @@ -90,10 +90,8 @@ func TestExample(*testing.T) { input[3].Double(&input[3]).Double(&input[3]).Double(&input[3]) // 425: YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) Y.Sub(&Y, &input[3]) // 426: p.Y.Sub(&p.Y, &YYYY) - fmt.Println("out", Y.String()) return }, 4)) - fmt.Println("y gate registered") // we have a lot of squaring operations, which we'd rather look at as single-input assertNoError(gkrBw6761.RegisterGate("square", func(input ...fr.Element) (res fr.Element) { @@ -169,8 +167,7 @@ func TestExample(*testing.T) { } // register the hash function used for verifying the GKR proof (prover side) - //bw6761.RegisterHashBuilder(fsHashName, func() hash.Hash { return hashReporter{gcHash.MIMC_BW6_761.New()} }) - bw6761.RegisterHashBuilder("const", func() hash.Hash { return constHasherBw6761{} }) + bw6761.RegisterHashBuilder(fsHashName, gcHash.MIMC_BW6_761.New) assertNoError(test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) @@ -261,20 +258,14 @@ func (c *exampleCircuit) Define(api frontend.API) error { // input = [S, p.X, XX, YYYY] // p.Y = assertNoError(gkr.RegisterGate(c.gateNamePrefix+"y", func(api frontend.API, input ...frontend.Variable) (Y frontend.Variable) { - - api.Println("SNARK in", input[0], input[1], input[2], input[3]) - Y = api.Sub(input[0], input[1]) // 423: p.Y.Sub(&S, &p.X). Y = api.Mul(Y, input[2], 3) // 414: M.Double(&XX).Add(&M, &XX) // 424:Mul(&p.Y, &M) Y = api.Sub(Y, api.Mul(input[3], 8)) // 425: YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) // 426: p.Y.Sub(&p.Y, &YYYY) - api.Println("SNARK out", Y) - return }, 4)) - fmt.Println("y gate registered") Y = gkrApi.NamedGate(c.gateNamePrefix+"y", S, X, XX, YYYY) // 423 - 426 // have to duplicate X for it to be considered an output variable @@ -282,12 +273,9 @@ func (c *exampleCircuit) Define(api frontend.API) error { X = gkrApi.NamedGate("identity", X) // register the hash function used for verification (fiat shamir) - /*stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { + stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { m, err := mimc.NewMiMC(api) - return &hashReporterSnark{h: &m, api: api}, err - })*/ - stdHash.Register("const", func(api frontend.API) (stdHash.FieldHasher, error) { - return &constHasherSnark{api: api}, nil + return &m, err }) res := gkrApi.SolveInTestEngine(api, gkr.WithHashName(c.fsHashName)) @@ -363,85 +351,3 @@ func (h hashReporter) Size() int { func (h hashReporter) BlockSize() int { return h.h.BlockSize() } - -type hashReporterSnark struct { - h stdHash.FieldHasher - api frontend.API - v []frontend.Variable -} - -func (h *hashReporterSnark) Sum() frontend.Variable { - h.api.Println(h.v...) - res := h.h.Sum() - h.api.Println("<-", res) - return res -} - -func (h *hashReporterSnark) Write(v ...frontend.Variable) { - h.v = append(h.v, v...) - h.h.Write(v...) -} - -func (h *hashReporterSnark) Reset() { - h.v = h.v[:0] - h.h.Reset() -} - -const constHash byte = 3 -const printHashes = false - -type constHasherBw6761 struct{} - -func (constHasherBw6761) Write(p []byte) (int, error) { - for i := 0; i < len(p); i += fr.Bytes { - var I big.Int - I.SetBytes(p[i:min(len(p), i+fr.Bytes)]) - if printHashes { - fmt.Print(I.Text(10), " ") - } - } - return len(p), nil -} - -func (constHasherBw6761) Sum(p []byte) []byte { - if p != nil { - panic("unexpected input") - } - if printHashes { - fmt.Println() - } - var b [fr.Bytes]byte - b[len(b)-1] = constHash - return b[:] -} - -func (constHasherBw6761) Reset() { -} - -func (constHasherBw6761) Size() int { - return fr.Bytes -} - -func (constHasherBw6761) BlockSize() int { - return fr.Bytes -} - -type constHasherSnark struct { - api frontend.API - v []frontend.Variable -} - -func (h *constHasherSnark) Sum() frontend.Variable { - if printHashes { - h.api.Println(h.v...) - } - return constHash -} - -func (h *constHasherSnark) Write(v ...frontend.Variable) { - h.v = append(h.v, v...) -} - -func (h *constHasherSnark) Reset() { - h.v = h.v[:0] -} diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index 2f35dbf708..9465358ac7 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -329,7 +329,6 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, wirePrefix := o.transcriptPrefix + "w" var baseChallenge []frontend.Variable for i := len(c) - 1; i >= 0; i-- { - api.Println("verifying wire", i) wire := o.sorted[i] if wire.IsOutput() { From ca23cc90aaeb9d22c0eecb843052d708e35062a3 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 27 Mar 2025 10:14:48 -0500 Subject: [PATCH 26/62] fix: gofmt --- std/gkr/example_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index 8a3ce361db..bd0d1ce5c6 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -47,10 +47,10 @@ func Example() { assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"s", func(input ...fr.Element) (S fr.Element) { S. Add(&input[0], &input[1]). // 409: S.Add(&p.X, &YY) - Square(&S). // 410: S.Square(&S). - Sub(&S, &input[2]). // 411: Sub(&S, &XX). - Sub(&S, &input[3]). // 412: Sub(&S, &YYYY). - Double(&S) // 413: Double(&S) + Square(&S). // 410: S.Square(&S). + Sub(&S, &input[2]). // 411: Sub(&S, &XX). + Sub(&S, &input[3]). // 412: Sub(&S, &YYYY). + Double(&S) // 413: Double(&S) return }, 4)) @@ -86,7 +86,7 @@ func Example() { input[2] = Y Y.Sub(&input[0], &input[1]). // 423: p.Y.Sub(&S, &p.X). - Mul(&Y, &input[2]) // 424: Mul(&p.Y, &M). + Mul(&Y, &input[2]) // 424: Mul(&p.Y, &M). input[3].Double(&input[3]).Double(&input[3]).Double(&input[3]) // 425: YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) Y.Sub(&Y, &input[3]) // 426: p.Y.Sub(&p.Y, &YYYY) From 7257006c7b99fbc8935e11cb2e3cb94aee3c4da0 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 27 Mar 2025 10:23:12 -0500 Subject: [PATCH 27/62] remove hash reporter --- std/gkr/example_test.go | 47 +++++------------------------------------ 1 file changed, 5 insertions(+), 42 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index bd0d1ce5c6..525f6535e9 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -3,7 +3,6 @@ package gkr_test import ( "encoding/binary" "errors" - "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" @@ -16,8 +15,6 @@ import ( stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/test" - "hash" - "math/big" ) func Example() { @@ -47,10 +44,10 @@ func Example() { assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"s", func(input ...fr.Element) (S fr.Element) { S. Add(&input[0], &input[1]). // 409: S.Add(&p.X, &YY) - Square(&S). // 410: S.Square(&S). - Sub(&S, &input[2]). // 411: Sub(&S, &XX). - Sub(&S, &input[3]). // 412: Sub(&S, &YYYY). - Double(&S) // 413: Double(&S) + Square(&S). // 410: S.Square(&S). + Sub(&S, &input[2]). // 411: Sub(&S, &XX). + Sub(&S, &input[3]). // 412: Sub(&S, &YYYY). + Double(&S) // 413: Double(&S) return }, 4)) @@ -86,7 +83,7 @@ func Example() { input[2] = Y Y.Sub(&input[0], &input[1]). // 423: p.Y.Sub(&S, &p.X). - Mul(&Y, &input[2]) // 424: Mul(&p.Y, &M). + Mul(&Y, &input[2]) // 424: Mul(&p.Y, &M). input[3].Double(&input[3]).Double(&input[3]).Double(&input[3]) // 425: YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) Y.Sub(&Y, &input[3]) // 426: p.Y.Sub(&p.Y, &YYYY) @@ -317,37 +314,3 @@ func assertNoError(err error) { panic(err) } } - -type hashReporter struct { - h hash.Hash -} - -func (h hashReporter) Write(p []byte) (n int, err error) { - for i := 0; i < len(p); i += fr.Bytes { - var I big.Int - I.SetBytes(p[i:min(len(p), i+fr.Bytes)]) - fmt.Print(I.Text(10), " ") - } - return h.h.Write(p) -} - -func (h hashReporter) Sum(b []byte) []byte { - if b != nil { - panic("unexpected input") - } - b = h.h.Sum(b) - fmt.Println("\n<-", new(big.Int).SetBytes(b).Text(10)) - return b -} - -func (h hashReporter) Reset() { - h.h.Reset() -} - -func (h hashReporter) Size() int { - return h.h.Size() -} - -func (h hashReporter) BlockSize() int { - return h.h.BlockSize() -} From 6997e71184e3e5b717920155b9281dade8fb9a8f Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 27 Mar 2025 10:31:41 -0500 Subject: [PATCH 28/62] gofmt --- std/gkr/example_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index 525f6535e9..c02845929e 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -44,10 +44,10 @@ func Example() { assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"s", func(input ...fr.Element) (S fr.Element) { S. Add(&input[0], &input[1]). // 409: S.Add(&p.X, &YY) - Square(&S). // 410: S.Square(&S). - Sub(&S, &input[2]). // 411: Sub(&S, &XX). - Sub(&S, &input[3]). // 412: Sub(&S, &YYYY). - Double(&S) // 413: Double(&S) + Square(&S). // 410: S.Square(&S). + Sub(&S, &input[2]). // 411: Sub(&S, &XX). + Sub(&S, &input[3]). // 412: Sub(&S, &YYYY). + Double(&S) // 413: Double(&S) return }, 4)) @@ -83,7 +83,7 @@ func Example() { input[2] = Y Y.Sub(&input[0], &input[1]). // 423: p.Y.Sub(&S, &p.X). - Mul(&Y, &input[2]) // 424: Mul(&p.Y, &M). + Mul(&Y, &input[2]) // 424: Mul(&p.Y, &M). input[3].Double(&input[3]).Double(&input[3]).Double(&input[3]) // 425: YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) Y.Sub(&Y, &input[3]) // 426: p.Y.Sub(&p.Y, &YYYY) From f4852d93e1b9c7c9bdb2a6ad240b4faf9e27bfae Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 27 Mar 2025 10:49:08 -0500 Subject: [PATCH 29/62] remove checks for S --- std/gkr/example_test.go | 48 ++++------------------------------------- 1 file changed, 4 insertions(+), 44 deletions(-) diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index c02845929e..f0099f23a5 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -5,7 +5,6 @@ import ( "errors" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" gkrBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" gcHash "github.com/consensys/gnark-crypto/hash" @@ -54,7 +53,7 @@ func Example() { // combine the operations that define the assignment to p.Z // input = [p.Z, p.Y, YY, ZZ] - // Z = (p.Z + p.Y)² - YY - ZZ + // p.Z = (p.Z + p.Y)² - YY - ZZ assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"z", func(input ...fr.Element) (Z fr.Element) { Z.Add(&input[0], &input[1]) // 415: p.Z.Add(&p.Z, &p.Y). Z.Square(&Z) // 416: p.Z.Square(&p.Z). @@ -78,6 +77,7 @@ func Example() { // combine the operations that define the assignment to p.Y // input = [S, p.X, XX, YYYY] + // p.Y = (S - p.X) * 3 * XX - 8 * YYYY assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"y", func(input ...fr.Element) (Y fr.Element) { Y.Double(&input[2]).Add(&Y, &input[2]) // 414: M.Double(&XX).Add(&M, &XX) input[2] = Y @@ -105,7 +105,6 @@ func Example() { XOut: make([]frontend.Variable, nbInstances), YOut: make([]frontend.Variable, nbInstances), ZOut: make([]frontend.Variable, nbInstances), - SOut: make([]frontend.Variable, nbInstances), } for i := range nbInstances { @@ -125,30 +124,6 @@ func Example() { assignment.XOut[i] = p.X assignment.YOut[i] = p.Y assignment.ZOut[i] = p.Z - - // TODO delete this - { - - p.X = assignment.X[i].(fp.Element) - p.Y = assignment.Y[i].(fp.Element) - p.Z = assignment.Z[i].(fp.Element) - - var XX, YY, YYYY, ZZ, S, M, T fp.Element - - _, _ = M, T - - XX.Square(&p.X) - YY.Square(&p.Y) - YYYY.Square(&YY) - ZZ.Square(&p.Z) - S.Add(&p.X, &YY). - Square(&S). - Sub(&S, &XX). - Sub(&S, &YYYY). - Double(&S) - - assignment.SOut[i] = S - } } circuit := exampleCircuit{ @@ -158,7 +133,6 @@ func Example() { XOut: make([]frontend.Variable, nbInstances), YOut: make([]frontend.Variable, nbInstances), ZOut: make([]frontend.Variable, nbInstances), - SOut: make([]frontend.Variable, nbInstances), gateNamePrefix: gateNamePrefix, fsHashName: fsHashName, } @@ -174,7 +148,6 @@ func Example() { type exampleCircuit struct { X, Y, Z []frontend.Variable // Jacobian coordinates for each point (input) XOut, YOut, ZOut []frontend.Variable // Jacobian coordinates for the double of each point (expected output) - SOut []frontend.Variable // temporary gateNamePrefix gkr.GateName fsHashName string // name of the hash function used for Fiat-Shamir in the GKR verifier } @@ -222,7 +195,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { return api.Add(S, S) // 413: Double(&S) }, 4)) S := gkrApi.NamedGate(c.gateNamePrefix+"s", X, YY, XX, YYYY) // 409 - 413 - scp := gkrApi.NamedGate("identity", S) + // 414: M.Double(&XX).Add(&M, &XX) // Note (but don't explicitly compute) that M = 3XX @@ -253,7 +226,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { // combine the operations that define the assignment to p.Y // input = [S, p.X, XX, YYYY] - // p.Y = + // p.Y = (S - p.X) * 3 * XX - 8 * YYYY assertNoError(gkr.RegisterGate(c.gateNamePrefix+"y", func(api frontend.API, input ...frontend.Variable) (Y frontend.Variable) { Y = api.Sub(input[0], input[1]) // 423: p.Y.Sub(&S, &p.X). Y = api.Mul(Y, input[2], 3) // 414: M.Double(&XX).Add(&M, &XX) @@ -275,14 +248,6 @@ func (c *exampleCircuit) Define(api frontend.API) error { return &m, err }) - res := gkrApi.SolveInTestEngine(api, gkr.WithHashName(c.fsHashName)) - for i := range c.XOut { - api.AssertIsEqual(res[scp][i], c.SOut[i]) - api.AssertIsEqual(res[Z][i], c.ZOut[i]) - api.AssertIsEqual(res[X][i], c.XOut[i]) - api.AssertIsEqual(res[Y][i], c.YOut[i]) - } - // solve and prove the circuit solution, err := gkrApi.Solve(api) if err != nil { @@ -290,11 +255,6 @@ func (c *exampleCircuit) Define(api frontend.API) error { } // check the output - // TODO merge loops - SOut := solution.Export(scp) - for i := range SOut { - api.AssertIsEqual(SOut[i], c.SOut[i]) - } XOut := solution.Export(X) YOut := solution.Export(Y) From 235f1259eb498b333a0cd3840400c2fed88054ea Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 27 Mar 2025 12:02:28 -0500 Subject: [PATCH 30/62] build use nogkr from gnark-crypto --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index e3aa584208..821d689e28 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.31-0.20250314194434-b30d4344e6d4 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1 + github.com/consensys/gnark-crypto v0.17.1-0.20250327163404-2fc9f58298e2 github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index 7e436d0c47..b26a6e7b42 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,8 @@ github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19 github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1 h1:6cK71BoMAjWHNl+EpvBh2PDDa0PIeoz1KFJ/6R16DjQ= github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1/go.mod h1:uV1HwfBwGRj50DGK3LbDLeCvq0RX/vFXST3CRSAu0Fs= +github.com/consensys/gnark-crypto v0.17.1-0.20250327163404-2fc9f58298e2 h1:vWEj2nIXK3dG2oyNlqLXVOFRGi1E3BQn+YVkwqf7GnM= +github.com/consensys/gnark-crypto v0.17.1-0.20250327163404-2fc9f58298e2/go.mod h1:uV1HwfBwGRj50DGK3LbDLeCvq0RX/vFXST3CRSAu0Fs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= From a3088112e8b776f58286589197c36e8ddaeb9eee Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 1 Apr 2025 16:13:19 -0500 Subject: [PATCH 31/62] generate code for sumcheck --- go.mod | 2 +- go.sum | 2 + internal/generator/backend/gkr/generate.go | 29 + .../backend/gkr/test_vectors/main.go | 349 +++++++ .../mimc_five_levels_two_instances._json | 7 + .../resources/mimc_five_levels.json | 36 + .../resources/single_identity_gate.json | 10 + .../single_input_two_identity_gates.json | 14 + .../resources/single_input_two_outs.json | 14 + .../resources/single_mimc_gate.json | 7 + .../resources/single_mul_gate.json | 14 + ..._identity_gates_composed_single_input.json | 14 + .../two_inputs_select-input-3_gate.json | 14 + .../single_identity_gate_two_instances.json | 36 + ...nput_two_identity_gates_two_instances.json | 56 ++ .../single_input_two_outs_two_instances.json | 57 ++ .../single_mimc_gate_four_instances.json | 67 ++ .../single_mimc_gate_two_instances.json | 51 ++ .../single_mul_gate_two_instances.json | 46 + ...s_composed_single_input_two_instances.json | 47 + ...uts_select-input-3_gate_two_instances.json | 45 + internal/generator/backend/main.go | 44 +- .../backend/sumcheck/test_vectors/main.go | 199 ++++ .../sumcheck/test_vectors/vectors.json | 56 ++ .../backend/template/gkr/gkr.go.tmpl | 863 ++++++++++++++++++ .../backend/template/gkr/gkr.test.go.tmpl | 611 +++++++++++++ .../template/gkr/gkr.test.vectors.gen.go.tmpl | 123 +++ .../template/gkr/gkr.test.vectors.go.tmpl | 254 ++++++ .../backend/template/gkr/registry.go.tmpl | 390 ++++++++ .../backend/template/gkr/sumcheck.go.tmpl | 163 ++++ .../template/gkr/sumcheck.test.go.tmpl | 143 +++ .../template/gkr/test_vector_utils.go.tmpl | 220 +++++ internal/gkr/bls12-377/sumcheck/sumcheck.go | 170 ++++ .../gkr/bls12-377/sumcheck/sumcheck_test.go | 150 +++ .../test_vector_utils/test_vector_utils.go | 216 +++++ internal/gkr/bls12-381/sumcheck/sumcheck.go | 170 ++++ .../gkr/bls12-381/sumcheck/sumcheck_test.go | 150 +++ .../test_vector_utils/test_vector_utils.go | 216 +++++ internal/gkr/bls24-315/sumcheck/sumcheck.go | 170 ++++ .../gkr/bls24-315/sumcheck/sumcheck_test.go | 150 +++ .../test_vector_utils/test_vector_utils.go | 216 +++++ internal/gkr/bls24-317/sumcheck/sumcheck.go | 170 ++++ .../gkr/bls24-317/sumcheck/sumcheck_test.go | 150 +++ .../test_vector_utils/test_vector_utils.go | 216 +++++ internal/gkr/bn254/sumcheck/sumcheck.go | 170 ++++ internal/gkr/bn254/sumcheck/sumcheck_test.go | 150 +++ .../test_vector_utils/test_vector_utils.go | 216 +++++ internal/gkr/bw6-633/sumcheck/sumcheck.go | 170 ++++ .../gkr/bw6-633/sumcheck/sumcheck_test.go | 150 +++ .../test_vector_utils/test_vector_utils.go | 216 +++++ internal/gkr/bw6-761/sumcheck/sumcheck.go | 170 ++++ .../gkr/bw6-761/sumcheck/sumcheck_test.go | 150 +++ .../test_vector_utils/test_vector_utils.go | 216 +++++ 53 files changed, 7731 insertions(+), 4 deletions(-) create mode 100644 internal/generator/backend/gkr/generate.go create mode 100644 internal/generator/backend/gkr/test_vectors/main.go create mode 100644 internal/generator/backend/gkr/test_vectors/mimc_five_levels_two_instances._json create mode 100644 internal/generator/backend/gkr/test_vectors/resources/mimc_five_levels.json create mode 100644 internal/generator/backend/gkr/test_vectors/resources/single_identity_gate.json create mode 100644 internal/generator/backend/gkr/test_vectors/resources/single_input_two_identity_gates.json create mode 100644 internal/generator/backend/gkr/test_vectors/resources/single_input_two_outs.json create mode 100644 internal/generator/backend/gkr/test_vectors/resources/single_mimc_gate.json create mode 100644 internal/generator/backend/gkr/test_vectors/resources/single_mul_gate.json create mode 100644 internal/generator/backend/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json create mode 100644 internal/generator/backend/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json create mode 100644 internal/generator/backend/gkr/test_vectors/single_identity_gate_two_instances.json create mode 100644 internal/generator/backend/gkr/test_vectors/single_input_two_identity_gates_two_instances.json create mode 100644 internal/generator/backend/gkr/test_vectors/single_input_two_outs_two_instances.json create mode 100644 internal/generator/backend/gkr/test_vectors/single_mimc_gate_four_instances.json create mode 100644 internal/generator/backend/gkr/test_vectors/single_mimc_gate_two_instances.json create mode 100644 internal/generator/backend/gkr/test_vectors/single_mul_gate_two_instances.json create mode 100644 internal/generator/backend/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json create mode 100644 internal/generator/backend/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json create mode 100644 internal/generator/backend/sumcheck/test_vectors/main.go create mode 100644 internal/generator/backend/sumcheck/test_vectors/vectors.json create mode 100644 internal/generator/backend/template/gkr/gkr.go.tmpl create mode 100644 internal/generator/backend/template/gkr/gkr.test.go.tmpl create mode 100644 internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl create mode 100644 internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl create mode 100644 internal/generator/backend/template/gkr/registry.go.tmpl create mode 100644 internal/generator/backend/template/gkr/sumcheck.go.tmpl create mode 100644 internal/generator/backend/template/gkr/sumcheck.test.go.tmpl create mode 100644 internal/generator/backend/template/gkr/test_vector_utils.go.tmpl create mode 100644 internal/gkr/bls12-377/sumcheck/sumcheck.go create mode 100644 internal/gkr/bls12-377/sumcheck/sumcheck_test.go create mode 100644 internal/gkr/bls12-377/test_vector_utils/test_vector_utils.go create mode 100644 internal/gkr/bls12-381/sumcheck/sumcheck.go create mode 100644 internal/gkr/bls12-381/sumcheck/sumcheck_test.go create mode 100644 internal/gkr/bls12-381/test_vector_utils/test_vector_utils.go create mode 100644 internal/gkr/bls24-315/sumcheck/sumcheck.go create mode 100644 internal/gkr/bls24-315/sumcheck/sumcheck_test.go create mode 100644 internal/gkr/bls24-315/test_vector_utils/test_vector_utils.go create mode 100644 internal/gkr/bls24-317/sumcheck/sumcheck.go create mode 100644 internal/gkr/bls24-317/sumcheck/sumcheck_test.go create mode 100644 internal/gkr/bls24-317/test_vector_utils/test_vector_utils.go create mode 100644 internal/gkr/bn254/sumcheck/sumcheck.go create mode 100644 internal/gkr/bn254/sumcheck/sumcheck_test.go create mode 100644 internal/gkr/bn254/test_vector_utils/test_vector_utils.go create mode 100644 internal/gkr/bw6-633/sumcheck/sumcheck.go create mode 100644 internal/gkr/bw6-633/sumcheck/sumcheck_test.go create mode 100644 internal/gkr/bw6-633/test_vector_utils/test_vector_utils.go create mode 100644 internal/gkr/bw6-761/sumcheck/sumcheck.go create mode 100644 internal/gkr/bw6-761/sumcheck/sumcheck_test.go create mode 100644 internal/gkr/bw6-761/test_vector_utils/test_vector_utils.go diff --git a/go.mod b/go.mod index 821d689e28..a988755b1b 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.31-0.20250314194434-b30d4344e6d4 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.17.1-0.20250327163404-2fc9f58298e2 + github.com/consensys/gnark-crypto v0.17.1-0.20250331132656-820ac1d108bd github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index b26a6e7b42..e3624d7e94 100644 --- a/go.sum +++ b/go.sum @@ -65,6 +65,8 @@ github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1 h1:6cK71 github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1/go.mod h1:uV1HwfBwGRj50DGK3LbDLeCvq0RX/vFXST3CRSAu0Fs= github.com/consensys/gnark-crypto v0.17.1-0.20250327163404-2fc9f58298e2 h1:vWEj2nIXK3dG2oyNlqLXVOFRGi1E3BQn+YVkwqf7GnM= github.com/consensys/gnark-crypto v0.17.1-0.20250327163404-2fc9f58298e2/go.mod h1:uV1HwfBwGRj50DGK3LbDLeCvq0RX/vFXST3CRSAu0Fs= +github.com/consensys/gnark-crypto v0.17.1-0.20250331132656-820ac1d108bd h1:og4X8KhBpFv37u0PuXLz7yOLz9vwOcpTX4pKfbbwtgM= +github.com/consensys/gnark-crypto v0.17.1-0.20250331132656-820ac1d108bd/go.mod h1:uV1HwfBwGRj50DGK3LbDLeCvq0RX/vFXST3CRSAu0Fs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/internal/generator/backend/gkr/generate.go b/internal/generator/backend/gkr/generate.go new file mode 100644 index 0000000000..3b679276d4 --- /dev/null +++ b/internal/generator/backend/gkr/generate.go @@ -0,0 +1,29 @@ +package gkr + +import ( + "path/filepath" + + "github.com/consensys/bavard" +) + +type Config struct { + GenerateTests bool + RetainTestCaseRawInfo bool + CanUseFFT bool + OutsideGkrPackage bool + TestVectorsRelativePath string +} + +func Generate(config Config, baseDir string, bgen *bavard.BatchGenerator) error { + entries := []bavard.Entry{ + {File: filepath.Join(baseDir, "gkr.go"), Templates: []string{"gkr.go.tmpl"}}, + {File: filepath.Join(baseDir, "registry.go"), Templates: []string{"registry.go.tmpl"}}, + } + + if config.GenerateTests { + entries = append(entries, + bavard.Entry{File: filepath.Join(baseDir, "gkr_test.go"), Templates: []string{"gkr.test.go.tmpl", "gkr.test.vectors.go.tmpl"}}) + } + + return bgen.Generate(config, "gkr", "./gkr/template/", entries...) +} diff --git a/internal/generator/backend/gkr/test_vectors/main.go b/internal/generator/backend/gkr/test_vectors/main.go new file mode 100644 index 0000000000..0bb86739af --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/main.go @@ -0,0 +1,349 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package main + +import ( + "encoding/json" + "fmt" + "hash" + "os" + "path/filepath" + "reflect" + + "github.com/consensys/bavard" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" +) + +func main() { + if err := GenerateVectors(); err != nil { + fmt.Println(err.Error()) + os.Exit(-1) + } +} + +func GenerateVectors() error { + testDirPath, err := filepath.Abs("gkr/test_vectors") + if err != nil { + return err + } + + fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) + + dirEntries, err := os.ReadDir(testDirPath) + if err != nil { + return err + } + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + if !bavard.ShouldGenerate(path) { + continue + } + fmt.Println("\tprocessing", dirEntry.Name()) + if err = run(path); err != nil { + return err + } + } + } + } + + return nil +} + +func run(absPath string) error { + testCase, err := newTestCase(absPath) + if err != nil { + return err + } + + transcriptSetting := fiatshamir.WithHash(testCase.Hash) + + var proof gkr.Proof + proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) + if err != nil { + return err + } + + if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { + return err + } + var outBytes []byte + if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { + if err = os.WriteFile(absPath, outBytes, 0); err != nil { + return err + } + } else { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) + if err != nil { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + if err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { + res := make(PrintableProof, len(proof)) + + for i := range proof { + + partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) + for k, partialK := range proof[i].PartialSumPolys { + partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) + } + + res[i] = PrintableSumcheckProof{ + FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), + PartialSumPolys: partialSumPolys, + } + } + return res, nil +} + +type WireInfo struct { + Gate gkr.GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]gkr.Circuit) + +func getCircuit(path string) (gkr.Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit gkr.Circuit) { + circuit = make(gkr.Circuit, len(c)) + for i := range c { + circuit[i].Gate = gkr.GetGate(c[i].Gate) + circuit[i].Inputs = make([]*gkr.Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { + var sum small_rational.SmallRational + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC gkr.GateName = "mimc" + SelectInput3 gkr.GateName = "select-input-3" +) + +func init() { + if err := gkr.RegisterGate(MiMC, mimcRound, 2, gkr.WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := gkr.RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { + return input[2] + }, 3, gkr.WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (gkr.Proof, error) { + proof := make(gkr.Proof, len(printable)) + for i := range printable { + finalEvalProof := []small_rational.SmallRational(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit gkr.Circuit + Hash hash.Hash + Proof gkr.Proof + FullAssignment gkr.WireAssignment + InOutAssignment gkr.WireAssignment + Info TestCaseInfo +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit gkr.Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof gkr.Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(gkr.WireAssignment) + inOutAssignment := make(gkr.WireAssignment) + + sorted := gkr.TopologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []small_rational.SmallRational + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + info.Output = make([][]interface{}, 0, outI) + + for _, w := range sorted { + if w.IsOutput() { + + info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + Info: info, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} diff --git a/internal/generator/backend/gkr/test_vectors/mimc_five_levels_two_instances._json b/internal/generator/backend/gkr/test_vectors/mimc_five_levels_two_instances._json new file mode 100644 index 0000000000..446d23fdb2 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/mimc_five_levels_two_instances._json @@ -0,0 +1,7 @@ +{ + "hash": {"type": "const", "val": -1}, + "circuit": "resources/mimc_five_levels.json", + "input": [[1, 3], [1, 3], [1, 3], [1, 3], [1, 3], [1, 3]], + "output": [[4, 3]], + "proof": [[{"partialSumPolys":[[3,4]],"finalEvalProof":[3]}],[{"partialSumPolys":null,"finalEvalProof":null}]] +} \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/resources/mimc_five_levels.json b/internal/generator/backend/gkr/test_vectors/resources/mimc_five_levels.json new file mode 100644 index 0000000000..3dd74f42b5 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/resources/mimc_five_levels.json @@ -0,0 +1,36 @@ +[ + [ + { + "gate": "mimc", + "inputs": [[1,0], [5,5]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[2,0], [5,4]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[3,0], [5,3]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[4,0], [5,2]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[5,0], [5,1]] + } + ], + [ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []} + ] +] \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_identity_gate.json b/internal/generator/backend/gkr/test_vectors/resources/single_identity_gate.json new file mode 100644 index 0000000000..a44066c7b4 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/resources/single_identity_gate.json @@ -0,0 +1,10 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_input_two_identity_gates.json b/internal/generator/backend/gkr/test_vectors/resources/single_input_two_identity_gates.json new file mode 100644 index 0000000000..6181784fa8 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/resources/single_input_two_identity_gates.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_input_two_outs.json b/internal/generator/backend/gkr/test_vectors/resources/single_input_two_outs.json new file mode 100644 index 0000000000..3a39e5625f --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/resources/single_input_two_outs.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul2", + "inputs": [0, 0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_mimc_gate.json b/internal/generator/backend/gkr/test_vectors/resources/single_mimc_gate.json new file mode 100644 index 0000000000..c89e7d52ae --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/resources/single_mimc_gate.json @@ -0,0 +1,7 @@ +[ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + { + "gate": "mimc", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_mul_gate.json b/internal/generator/backend/gkr/test_vectors/resources/single_mul_gate.json new file mode 100644 index 0000000000..d009ebe03d --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/resources/single_mul_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul2", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json b/internal/generator/backend/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json new file mode 100644 index 0000000000..26681c2f89 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [1] + } +] \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json b/internal/generator/backend/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json new file mode 100644 index 0000000000..cdbdb3b471 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "select-input-3", + "inputs": [0,0,1] + } +] \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/single_identity_gate_two_instances.json b/internal/generator/backend/gkr/test_vectors/single_identity_gate_two_instances.json new file mode 100644 index 0000000000..ce326d0a63 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/single_identity_gate_two_instances.json @@ -0,0 +1,36 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_identity_gate.json", + "input": [ + [ + 4, + 3 + ] + ], + "output": [ + [ + 4, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 5 + ], + "partialSumPolys": [ + [ + -3, + -8 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/single_input_two_identity_gates_two_instances.json b/internal/generator/backend/gkr/test_vectors/single_input_two_identity_gates_two_instances.json new file mode 100644 index 0000000000..2c95f044f2 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/single_input_two_identity_gates_two_instances.json @@ -0,0 +1,56 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_identity_gates.json", + "input": [ + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ], + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + 0, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + }, + { + "finalEvalProof": [ + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/single_input_two_outs_two_instances.json b/internal/generator/backend/gkr/test_vectors/single_input_two_outs_two_instances.json new file mode 100644 index 0000000000..d348303d0e --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/single_input_two_outs_two_instances.json @@ -0,0 +1,57 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_input_two_outs.json", + "input": [ + [ + 1, + 2 + ] + ], + "output": [ + [ + 1, + 4 + ], + [ + 1, + 2 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + 0, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 0 + ], + "partialSumPolys": [ + [ + -4, + -36, + -112 + ] + ] + }, + { + "finalEvalProof": [ + 0 + ], + "partialSumPolys": [ + [ + -2, + -12 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/single_mimc_gate_four_instances.json b/internal/generator/backend/gkr/test_vectors/single_mimc_gate_four_instances.json new file mode 100644 index 0000000000..ff275c9cb4 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/single_mimc_gate_four_instances.json @@ -0,0 +1,67 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mimc_gate.json", + "input": [ + [ + 1, + 1, + 2, + 1 + ], + [ + 1, + 2, + 2, + 1 + ] + ], + "output": [ + [ + 128, + 2187, + 16384, + 128 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + -1, + -3 + ], + "partialSumPolys": [ + [ + -32640, + -2239484, + -29360128, + -200000010, + -931628672, + -3373267120, + -10200858624, + -26939400158 + ], + [ + -81920, + -41943040, + -1254113280, + -13421772800, + -83200000000, + -366917713920, + -1281828208640, + -3779571220480 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/single_mimc_gate_two_instances.json b/internal/generator/backend/gkr/test_vectors/single_mimc_gate_two_instances.json new file mode 100644 index 0000000000..369297dbd5 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/single_mimc_gate_two_instances.json @@ -0,0 +1,51 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mimc_gate.json", + "input": [ + [ + 1, + 1 + ], + [ + 1, + 2 + ] + ], + "output": [ + [ + 128, + 2187 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 1, + 0 + ], + "partialSumPolys": [ + [ + -2187, + -65536, + -546875, + -2799360, + -10706059, + -33554432, + -90876411, + -220000000 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/single_mul_gate_two_instances.json b/internal/generator/backend/gkr/test_vectors/single_mul_gate_two_instances.json new file mode 100644 index 0000000000..75c1d59c3d --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/single_mul_gate_two_instances.json @@ -0,0 +1,46 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/single_mul_gate.json", + "input": [ + [ + 4, + 3 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 8, + 9 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 5, + 1 + ], + "partialSumPolys": [ + [ + -9, + -32, + -35 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/internal/generator/backend/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json new file mode 100644 index 0000000000..10e5f1ff3c --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json @@ -0,0 +1,47 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_identity_gates_composed_single_input.json", + "input": [ + [ + 2, + 1 + ] + ], + "output": [ + [ + 2, + 1 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 3 + ], + "partialSumPolys": [ + [ + -1, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 3 + ], + "partialSumPolys": [ + [ + -1, + 0 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/backend/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json b/internal/generator/backend/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json new file mode 100644 index 0000000000..19e127df71 --- /dev/null +++ b/internal/generator/backend/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json @@ -0,0 +1,45 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "resources/two_inputs_select-input-3_gate.json", + "input": [ + [ + 0, + 1 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + -1, + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 027ed4c4e3..81ebfecfa1 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -79,7 +79,7 @@ func main() { panic(err) } - datas := []templateData{ + data := []templateData{ bls12_377, bls12_381, bn254, @@ -90,11 +90,14 @@ func main() { tiny_field, } - const importCurve = "../imports.go.tmpl" + const ( + importCurve = "../imports.go.tmpl" + repoRoot = "../../../" + ) var wg sync.WaitGroup - for _, d := range datas { + for _, d := range data { wg.Add(1) @@ -129,10 +132,45 @@ func main() { // gkr backend if d.Curve != "tinyfield" { + // solver and proof delegator TODO merge with "backend" below entries = []bavard.Entry{{File: filepath.Join(csDir, "gkr.go"), Templates: []string{"gkr.go.tmpl", importCurve}}} if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { panic(err) } + + curvePackageName := strings.ToLower(d.Curve) + cfg := struct { + config.FieldDependency + GkrPackagePath string + }{ + config.FieldDependency{ + ElementType: "fr.Element", + FieldPackageName: "fr", + FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + curvePackageName + "/fr", + }, + "github.com/consensys/gnark/internal/gkr/" + curvePackageName, + } + gkrPackageDirRelPath := filepath.Join(repoRoot+"internal/gkr/", curvePackageName) + + // test vector utils + packagePath := filepath.Join(gkrPackageDirRelPath, "test_vector_utils") + entries = []bavard.Entry{ + {File: filepath.Join(packagePath, "test_vector_utils.go"), Templates: []string{"test_vector_utils.go.tmpl"}}, + } + + if err := bgen.Generate(cfg, "test_vector_utils", "./template/gkr/", entries...); err != nil { + panic(err) + } + + // sumcheck backend + packagePath = filepath.Join(gkrPackageDirRelPath, "sumcheck") + entries = []bavard.Entry{ + {File: filepath.Join(packagePath, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, + {File: filepath.Join(packagePath, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl"}}, + } + if err := bgen.Generate(cfg, "sumcheck", "./template/gkr/", entries...); err != nil { + panic(err) + } } entries = []bavard.Entry{ diff --git a/internal/generator/backend/sumcheck/test_vectors/main.go b/internal/generator/backend/sumcheck/test_vectors/main.go new file mode 100644 index 0000000000..798f5a4f3f --- /dev/null +++ b/internal/generator/backend/sumcheck/test_vectors/main.go @@ -0,0 +1,199 @@ +package main + +import ( + "encoding/json" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" + "hash" + "math/bits" + "os" + "path/filepath" +) + +func runMultilin(testCaseInfo *TestCaseInfo) error { + + var poly polynomial.MultiLin + if v, err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); err == nil { + poly = v + } else { + return err + } + + var hsh hash.Hash + var err error + if hsh, err = test_vector_utils.HashFromDescription(testCaseInfo.Hash); err != nil { + return err + } + + proof, err := sumcheck.Prove( + &singleMultilinClaim{poly}, fiatshamir.WithHash(hsh)) + if err != nil { + return err + } + testCaseInfo.Proof = toPrintableProof(proof) + + // Verification + if v, _err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); _err == nil { + poly = v + } else { + return _err + } + var claimedSum small_rational.SmallRational + if _, err = claimedSum.SetInterface(testCaseInfo.ClaimedSum); err != nil { + return err + } + + if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err != nil { + return fmt.Errorf("proof rejected: %v", err) + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func run(testCaseInfo *TestCaseInfo) error { + switch testCaseInfo.Type { + case "multilin": + return runMultilin(testCaseInfo) + default: + return fmt.Errorf("type \"%s\" unrecognized", testCaseInfo.Type) + } +} + +func runAll(relPath string) error { + var filename string + var err error + if filename, err = filepath.Abs(relPath); err != nil { + return err + } + + var bytes []byte + + if bytes, err = os.ReadFile(filename); err != nil { + return err + } + + var testCasesInfo TestCasesInfo + if err = json.Unmarshal(bytes, &testCasesInfo); err != nil { + return err + } + + failed := false + for name, testCase := range testCasesInfo { + if err = run(testCase); err != nil { + fmt.Println(name, ":", err) + failed = true + } + } + + if failed { + return fmt.Errorf("test case failed") + } + + if bytes, err = json.MarshalIndent(testCasesInfo, "", "\t"); err != nil { + return err + } + + return os.WriteFile(filename, bytes, 0) +} + +func main() { + if err := runAll("sumcheck/test_vectors/vectors.json"); err != nil { + fmt.Println(err) + os.Exit(-1) + } +} + +type TestCasesInfo map[string]*TestCaseInfo + +type TestCaseInfo struct { + Type string `json:"type"` + Hash test_vector_utils.HashDescription `json:"hash"` + Values []interface{} `json:"values"` + Description string `json:"description"` + Proof PrintableProof `json:"proof"` + ClaimedSum interface{} `json:"claimedSum"` +} + +type PrintableProof struct { + PartialSumPolys [][]interface{} `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` +} + +func toPrintableProof(proof sumcheck.Proof) (printable PrintableProof) { + if proof.FinalEvalProof != nil { + panic("null expected") + } + printable.FinalEvalProof = struct{}{} + printable.PartialSumPolys = test_vector_utils.ElementSliceSliceToInterfaceSliceSlice(proof.PartialSumPolys) + return +} + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval([]small_rational.SmallRational) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []small_rational.SmallRational{sum} +} + +func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum small_rational.SmallRational +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, _ small_rational.SmallRational, purportedValue small_rational.SmallRational, _ interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(small_rational.SmallRational) small_rational.SmallRational { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} diff --git a/internal/generator/backend/sumcheck/test_vectors/vectors.json b/internal/generator/backend/sumcheck/test_vectors/vectors.json new file mode 100644 index 0000000000..64b8e3fb2d --- /dev/null +++ b/internal/generator/backend/sumcheck/test_vectors/vectors.json @@ -0,0 +1,56 @@ +{ + "linear_univariate_single_claim": { + "type": "multilin", + "hash": { + "type": "const", + "val": -1 + }, + "values": [ + 1, + 3 + ], + "description": "X ↦ 2X + 1", + "proof": { + "partialSumPolys": [ + [ + 3 + ] + ], + "finalEvalProof": {} + }, + "claimedSum": 4 + }, + "trilinear_single_claim": { + "type": "multilin", + "hash": { + "type": "const", + "val": -1 + }, + "values": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8 + ], + "description": "X₁, X₂, X₃ ↦ 1 + 4X₁ + 2X₂ + X₃", + "proof": { + "partialSumPolys": [ + [ + 26 + ], + [ + -1 + ], + [ + -4 + ] + ], + "finalEvalProof": {} + }, + "claimedSum": 36 + } +} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl new file mode 100644 index 0000000000..c27daa9b59 --- /dev/null +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -0,0 +1,863 @@ +import ( + "errors" + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + "{{.FieldPackagePath}}/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/parallel" + "github.com/consensys/gnark-crypto/utils" + "math/big" + "strconv" + "sync" +) + +{{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...{{.ElementType}}) {{.ElementType}} + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]{{.ElementType}} + claimedEvaluations []{{.ElementType}} + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a {{.ElementType}}) {{.ElementType}} { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]{{.ElementType}}) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation {{.ElementType}} + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]{{.ElementType}}, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]{{.ElementType}} // x in the paper + claimedEvaluations []{{.ElementType}} // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff {{.ElementType}}) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq,c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.ElementType}}) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]{{.ElementType}}, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step {{.ElementType}} + + res := make([]{{.ElementType}}, degGJ) + operands := make([]{{.ElementType}}, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element {{.ElementType}}) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) interface{} { + + //defer the proof, return list of claims + evaluations := make([]{{.ElementType}}, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]{{.ElementType}}, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []{{.ElementType}}, evaluation {{.ElementType}}) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func (options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]{{.ElementType}}, error) { + res := make([]{{.ElementType}}, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []{{.ElementType}} + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []{{.ElementType}}{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]{{.ElementType}}) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []{{.ElementType}} + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]{{.ElementType}}) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// {{$topologicalSort}} sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func {{$topologicalSort}}(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := {{$topologicalSort}}(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]{{.ElementType}}, nbInstances) + } + } + + parallel.Execute(nbInstances, func(start, end int) { + ins := make([]{{.ElementType}}, maxNbIns) + for i := start; i < end; i++ { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + }) + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]{{.ElementType}}) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []{{.ElementType}}) { + for i := range src { + src[i].BigInt(dst[i]) + } +} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/gkr.test.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.go.tmpl new file mode 100644 index 0000000000..378cb813e0 --- /dev/null +++ b/internal/generator/backend/template/gkr/gkr.test.go.tmpl @@ -0,0 +1,611 @@ + +import ( + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/mimc" + "{{.FieldPackagePath}}/polynomial" + "{{.FieldPackagePath}}/sumcheck" + "{{.FieldPackagePath}}/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "fmt" + "hash" + "os" + "strconv" + "testing" + "path/filepath" + "encoding/json" + "reflect" + "time" +) + +{{$GenerateLargeTests := .GenerateTests}} {{/* this is redundant. soon to be removed if a use case for it doesn't come back */}} +{{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []{{.ElementType}}{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []{{.ElementType}}{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []{{.ElementType}}{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +{{- if $GenerateLargeTests}} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]{{.ElementType}}) { + return func(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + testMimc(t, numRounds, inputAssignments...) + } +} + +{{- end}} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{ Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + } } + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []{{.ElementType}}{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []{{.ElementType}}{three}, five) + manager.add(wire, []{{.ElementType}}{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six {{.ElementType}} + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]{{.ElementType}})) { + fullAssignments := make([][]{{.ElementType}}, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]{{.ElementType}}, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t,err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []{{.ElementType}}) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "{{.TestVectorsRelativePath}}" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := {{$topologicalSort}}(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := {{$topologicalSort}}(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := {{$topologicalSort}}(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +{{template "gkrTestVectors" .}} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl new file mode 100644 index 0000000000..832188f3d3 --- /dev/null +++ b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl @@ -0,0 +1,123 @@ +import ( + "encoding/json" + "fmt" + "hash" + "os" + "path/filepath" + "reflect" + + "github.com/consensys/bavard" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" + +) + +func main() { + if err := GenerateVectors(); err != nil { + fmt.Println(err.Error()) + os.Exit(-1) + } +} + +func GenerateVectors() error { + testDirPath, err := filepath.Abs("gkr/test_vectors") + if err != nil { + return err + } + + fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) + + dirEntries, err := os.ReadDir(testDirPath) + if err != nil { + return err + } + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + if !bavard.ShouldGenerate(path) { + continue + } + fmt.Println("\tprocessing", dirEntry.Name()) + if err = run(path); err != nil { + return err + } + } + } + } + + return nil +} + +func run(absPath string) error { + testCase, err := newTestCase(absPath) + if err != nil { + return err + } + + transcriptSetting := fiatshamir.WithHash(testCase.Hash) + + var proof gkr.Proof + proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) + if err != nil { + return err + } + + if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { + return err + } + var outBytes []byte + if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { + if err = os.WriteFile(absPath, outBytes, 0); err != nil { + return err + } + } else { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) + if err != nil { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + if err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { + res := make(PrintableProof, len(proof)) + + for i := range proof { + + partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) + for k, partialK := range proof[i].PartialSumPolys { + partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) + } + + res[i] = PrintableSumcheckProof{ + FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), + PartialSumPolys: partialSumPolys, + } + } + return res, nil +} + +{{template "gkrTestVectors" .}} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl new file mode 100644 index 0000000000..0025b0164a --- /dev/null +++ b/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl @@ -0,0 +1,254 @@ +{{define "gkrTestVectors"}} + +{{$GkrPackagePrefix := select .OutsideGkrPackage "" "gkr."}} +{{$CheckOutputCorrectness := not .OutsideGkrPackage}} + +{{$Circuit := print $GkrPackagePrefix "Circuit"}} +{{$Gate := print $GkrPackagePrefix "Gate"}} +{{$Proof := print $GkrPackagePrefix "Proof"}} +{{$WireAssignment := print $GkrPackagePrefix "WireAssignment"}} +{{$Wire := print $GkrPackagePrefix "Wire"}} +{{$CircuitLayer := print $GkrPackagePrefix "CircuitLayer"}} + +{{$PackagePrefix := ""}} +{{- if .OutsideGkrPackage}} + {{$PackagePrefix = "gkr."}} +{{end}} + +type WireInfo struct { + Gate {{$PackagePrefix}}GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]{{$Circuit}}) + +func getCircuit(path string) ({{$Circuit}}, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit {{$Circuit}}) { + circuit = make({{$Circuit}}, len(c)) + for i := range c { + circuit[i].Gate = {{$PackagePrefix}}GetGate(c[i].Gate) + circuit[i].Inputs = make([]*{{$Wire}}, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...{{.ElementType}}) (res {{.ElementType}}) { + var sum {{.ElementType}} + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC {{$PackagePrefix}}GateName = "mimc" + SelectInput3 {{$PackagePrefix}}GateName = "select-input-3" +) + +func init() { + if err := {{$PackagePrefix}}RegisterGate(MiMC, mimcRound, 2, {{$PackagePrefix}}WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := {{$PackagePrefix}}RegisterGate(SelectInput3, func(input ...{{.ElementType}}) {{.ElementType}} { + return input[2] + }, 3, {{$PackagePrefix}}WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) ({{$Proof}}, error) { + proof := make({{$Proof}}, len(printable)) + for i := range printable { + finalEvalProof := []{{.ElementType}}(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]{{.ElementType}}, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := {{ setElement "finalEvalProof[k]" "finalEvalSlice.Index(k).Interface()" .ElementType}}; err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit {{$Circuit}} + Hash hash.Hash + Proof {{$Proof}} + FullAssignment {{$WireAssignment}} + InOutAssignment {{$WireAssignment}} + {{if .RetainTestCaseRawInfo}}Info TestCaseInfo{{end}} +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit {{$Circuit}} + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof {{$Proof}} + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make({{$WireAssignment}}) + inOutAssignment := make({{$WireAssignment}}) + + sorted := {{select .OutsideGkrPackage "t" "gkr.T"}}opologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []{{.ElementType}} + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + {{if not $CheckOutputCorrectness}} + info.Output = make([][]interface{}, 0, outI) + {{end}} + + for _, w := range sorted { + if w.IsOutput() { + {{if $CheckOutputCorrectness}} + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + {{else}} + info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) + {{end}} + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + {{if .RetainTestCaseRawInfo }}Info: info,{{end}} + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +{{end}} + +{{- define "setElement element value elementType"}} +{{- if eq .elementType "fr.Element"}} test_vector_utils.SetElement(&{{.element}}, {{.value}}) +{{- else if eq .elementType "small_rational.SmallRational"}} {{.element}}.SetInterface({{.value}}) +{{- else}} +{{print "\"UNEXPECTED TYPE" .elementType "\""}} +{{- end}} +{{- end}} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/registry.go.tmpl b/internal/generator/backend/template/gkr/registry.go.tmpl new file mode 100644 index 0000000000..75ca8d0267 --- /dev/null +++ b/internal/generator/backend/template/gkr/registry.go.tmpl @@ -0,0 +1,390 @@ +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "{{.FieldPackagePath}}" + {{- if .CanUseFFT }} + "{{.FieldPackagePath}}/fft"{{- else}} + "errors"{{- end }} + "{{.FieldPackagePath}}/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make({{.FieldPackageName}}.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]{{.ElementType}}, nbIn) + consts := make({{.FieldPackageName}}.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + {{- if .CanUseFFT }} + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := {{.FieldPackageName}}.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + {{- else }} + x := make({{.FieldPackageName}}.Vector, degreeBound) + x.MustSetRandom() + for i := range x { + fIn[0] = x[i] + for j := range consts { + fIn[j+1].Mul(&x[i], &consts[j]) + } + p[i] = f(fIn...) + } + + // obtain p's coefficients + p, err := interpolate(x, p) + if err != nil { + panic(err) + } + {{- end }} + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max)+1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +{{- if not .CanUseFFT }} +// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) +// Note that the runtime is O(len(X)³) +func interpolate(X, Y []{{.ElementType}}) (polynomial.Polynomial, error) { + if len(X) != len(Y) { + return nil, errors.New("X and Y must have the same length") + } + + // solve the system of equations by Gaussian elimination + augmentedRows := make([][]{{.ElementType}}, len(X)) // the last column is the Y values + for i := range augmentedRows { + augmentedRows[i] = make([]{{.ElementType}}, len(X)+1) + augmentedRows[i][0].SetOne() + augmentedRows[i][1].Set(&X[i]) + for j := 2; j < len(augmentedRows[i])-1; j++ { + augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) + } + augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) + } + + // make the upper triangle + for i := range len(augmentedRows) - 1 { + // use row i to eliminate the ith element in all rows below + var negInv {{.ElementType}} + if augmentedRows[i][i].IsZero() { + return nil, errors.New("singular matrix") + } + negInv.Inverse(&augmentedRows[i][i]) + negInv.Neg(&negInv) + for j := i + 1; j < len(augmentedRows); j++ { + var c {{.ElementType}} + c.Mul(&augmentedRows[j][i], &negInv) + // augmentedRows[j][i].SetZero() omitted + for k := i + 1; k < len(augmentedRows[i]); k++ { + var t {{.ElementType}} + t.Mul(&augmentedRows[i][k], &c) + augmentedRows[j][k].Add(&augmentedRows[j][k], &t) + } + } + } + + // back substitution + res := make(polynomial.Polynomial, len(X)) + for i := len(augmentedRows) - 1; i >= 0; i-- { + res[i] = augmentedRows[i][len(augmentedRows[i])-1] + for j := i + 1; j < len(augmentedRows[i])-1; j++ { + var t {{.ElementType}} + t.Mul(&res[j], &augmentedRows[i][j]) + res[i].Sub(&res[i], &t) + } + res[i].Div(&res[i], &augmentedRows[i][i]) + } + + return res, nil +} +{{- end }} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...{{.ElementType}}) {{.ElementType}} { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/sumcheck.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.go.tmpl new file mode 100644 index 0000000000..2ca7ec4975 --- /dev/null +++ b/internal/generator/backend/template/gkr/sumcheck.go.tmpl @@ -0,0 +1,163 @@ +import ( + "errors" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a {{.ElementType}}) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next({{.ElementType}}) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []{{.ElementType}}) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a {{.ElementType}}) {{.ElementType}} // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []{{.ElementType}}, remainingChallengeNames *[]string) ({{.ElementType}}, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return {{.ElementType}}{}, err + } + } + var res {{.ElementType}} + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff {{.ElementType}} + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]{{.ElementType}}, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff {{.ElementType}} + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]{{.ElementType}}, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl new file mode 100644 index 0000000000..2197d763a9 --- /dev/null +++ b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl @@ -0,0 +1,143 @@ +import ( + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "{{.GkrPackagePath}}/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []{{.ElementType}}) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []{{.ElementType}}{sum} +} + +func (c singleMultilinClaim) Combine({{.ElementType}}) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r {{.ElementType}}) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum {{.ElementType}} +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs {{.ElementType}}) {{.ElementType}} { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/generator/backend/template/gkr/test_vector_utils.go.tmpl b/internal/generator/backend/template/gkr/test_vector_utils.go.tmpl new file mode 100644 index 0000000000..5b7495eec3 --- /dev/null +++ b/internal/generator/backend/template/gkr/test_vector_utils.go.tmpl @@ -0,0 +1,220 @@ +import ( + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + "hash" + "reflect" + {{if eq .ElementType "fr.Element"}}"strings"{{- end}} +) + +func ToElement(i int64) *{{.ElementType}} { + var res {{.ElementType}} + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter {startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/{{.FieldPackageName}}.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/{{.FieldPackageName}}.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res {{.ElementType}} + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return {{.FieldPackageName}}.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return {{.FieldPackageName}}.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []{{.ElementType}} + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return {{.FieldPackageName}}.Bytes +} + +func (h *ListHash) BlockSize() int { +return {{.FieldPackageName}}.Bytes +} + +{{- if eq .ElementType "fr.Element"}} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} +{{- end}} + +{{- define "setElement element value elementType"}} +{{- if eq .elementType "fr.Element"}} SetElement(&{{.element}}, {{.value}}) +{{- else if eq .elementType "small_rational.SmallRational"}} {{.element}}.SetInterface({{.value}}) +{{- else}} + {{print "\"UNEXPECTED TYPE" .elementType "\""}} +{{- end}} +{{- end}} + +func SliceToElementSlice[T any](slice []T) ([]{{.ElementType}}, error) { + elementSlice := make([]{{.ElementType}}, len(slice)) + for i, v := range slice { + if _, err := {{setElement "elementSlice[i]" "v" .ElementType}}; err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []{{.ElementType}}, b []{{.ElementType}}) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]{{.ElementType}}, b [][]{{.ElementType}}) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i],b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i],b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *{{.ElementType}}) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().({{.ElementType}}) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bls12-377/sumcheck/sumcheck.go b/internal/gkr/bls12-377/sumcheck/sumcheck.go new file mode 100644 index 0000000000..d7be95ccb8 --- /dev/null +++ b/internal/gkr/bls12-377/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bls12-377/sumcheck/sumcheck_test.go b/internal/gkr/bls12-377/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..00d6ffdf28 --- /dev/null +++ b/internal/gkr/bls12-377/sumcheck/sumcheck_test.go @@ -0,0 +1,150 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bls12-377/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bls12-377/test_vector_utils/test_vector_utils.go b/internal/gkr/bls12-377/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..da958ce237 --- /dev/null +++ b/internal/gkr/bls12-377/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bls12-381/sumcheck/sumcheck.go b/internal/gkr/bls12-381/sumcheck/sumcheck.go new file mode 100644 index 0000000000..6ecb1722a6 --- /dev/null +++ b/internal/gkr/bls12-381/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bls12-381/sumcheck/sumcheck_test.go b/internal/gkr/bls12-381/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..40664ee4eb --- /dev/null +++ b/internal/gkr/bls12-381/sumcheck/sumcheck_test.go @@ -0,0 +1,150 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bls12-381/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bls12-381/test_vector_utils/test_vector_utils.go b/internal/gkr/bls12-381/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..b1a74e1bae --- /dev/null +++ b/internal/gkr/bls12-381/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bls24-315/sumcheck/sumcheck.go b/internal/gkr/bls24-315/sumcheck/sumcheck.go new file mode 100644 index 0000000000..4d6fd2a15a --- /dev/null +++ b/internal/gkr/bls24-315/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bls24-315/sumcheck/sumcheck_test.go b/internal/gkr/bls24-315/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..f1a86c12f4 --- /dev/null +++ b/internal/gkr/bls24-315/sumcheck/sumcheck_test.go @@ -0,0 +1,150 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bls24-315/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bls24-315/test_vector_utils/test_vector_utils.go b/internal/gkr/bls24-315/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..59836542d5 --- /dev/null +++ b/internal/gkr/bls24-315/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bls24-317/sumcheck/sumcheck.go b/internal/gkr/bls24-317/sumcheck/sumcheck.go new file mode 100644 index 0000000000..90dc85ffdf --- /dev/null +++ b/internal/gkr/bls24-317/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bls24-317/sumcheck/sumcheck_test.go b/internal/gkr/bls24-317/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..0efca63df7 --- /dev/null +++ b/internal/gkr/bls24-317/sumcheck/sumcheck_test.go @@ -0,0 +1,150 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bls24-317/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bls24-317/test_vector_utils/test_vector_utils.go b/internal/gkr/bls24-317/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..eef6e7dea9 --- /dev/null +++ b/internal/gkr/bls24-317/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bn254/sumcheck/sumcheck.go b/internal/gkr/bn254/sumcheck/sumcheck.go new file mode 100644 index 0000000000..821399b4f3 --- /dev/null +++ b/internal/gkr/bn254/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bn254/sumcheck/sumcheck_test.go b/internal/gkr/bn254/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..cd7259736e --- /dev/null +++ b/internal/gkr/bn254/sumcheck/sumcheck_test.go @@ -0,0 +1,150 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bn254/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bn254/test_vector_utils/test_vector_utils.go b/internal/gkr/bn254/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..93679ae858 --- /dev/null +++ b/internal/gkr/bn254/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bw6-633/sumcheck/sumcheck.go b/internal/gkr/bw6-633/sumcheck/sumcheck.go new file mode 100644 index 0000000000..8a8c25f3c5 --- /dev/null +++ b/internal/gkr/bw6-633/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bw6-633/sumcheck/sumcheck_test.go b/internal/gkr/bw6-633/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..403839293f --- /dev/null +++ b/internal/gkr/bw6-633/sumcheck/sumcheck_test.go @@ -0,0 +1,150 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bw6-633/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bw6-633/test_vector_utils/test_vector_utils.go b/internal/gkr/bw6-633/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..ea84f4f255 --- /dev/null +++ b/internal/gkr/bw6-633/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bw6-761/sumcheck/sumcheck.go b/internal/gkr/bw6-761/sumcheck/sumcheck.go new file mode 100644 index 0000000000..ce9800a258 --- /dev/null +++ b/internal/gkr/bw6-761/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bw6-761/sumcheck/sumcheck_test.go b/internal/gkr/bw6-761/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..2f95dc376e --- /dev/null +++ b/internal/gkr/bw6-761/sumcheck/sumcheck_test.go @@ -0,0 +1,150 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bw6-761/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bw6-761/test_vector_utils/test_vector_utils.go b/internal/gkr/bw6-761/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..c3f063ff58 --- /dev/null +++ b/internal/gkr/bw6-761/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} From cb17ed92de0493f2cf96fe6c12eb68c05d5ef27c Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 1 Apr 2025 16:47:35 -0500 Subject: [PATCH 32/62] small-rational --- .../template/gkr/sumcheck.test.go.tmpl | 1 - internal/small_rational/polynomial/doc.go | 5 + .../small_rational/polynomial/multilin.go | 176 +++++++ .../small_rational/polynomial/polynomial.go | 307 ++++++++++++ internal/small_rational/polynomial/pool.go | 29 ++ internal/small_rational/small-rational.go | 454 ++++++++++++++++++ .../small_rational/small_rational_test.go | 115 +++++ .../test_vector_utils/test_vector_utils.go | 185 +++++++ internal/small_rational/vector.go | 9 + 9 files changed, 1280 insertions(+), 1 deletion(-) create mode 100644 internal/small_rational/polynomial/doc.go create mode 100644 internal/small_rational/polynomial/multilin.go create mode 100644 internal/small_rational/polynomial/polynomial.go create mode 100644 internal/small_rational/polynomial/pool.go create mode 100644 internal/small_rational/small-rational.go create mode 100644 internal/small_rational/small_rational_test.go create mode 100644 internal/small_rational/test_vector_utils/test_vector_utils.go create mode 100644 internal/small_rational/vector.go diff --git a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl index 2197d763a9..e599869be2 100644 --- a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl @@ -113,7 +113,6 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) polys := [][]uint64{ {1, 2, 3, 4}, // 1 + 2X₁ + X₂ diff --git a/internal/small_rational/polynomial/doc.go b/internal/small_rational/polynomial/doc.go new file mode 100644 index 0000000000..95ba2f135f --- /dev/null +++ b/internal/small_rational/polynomial/doc.go @@ -0,0 +1,5 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Package polynomial provides polynomial methods and commitment schemes. +package polynomial diff --git a/internal/small_rational/polynomial/multilin.go b/internal/small_rational/polynomial/multilin.go new file mode 100644 index 0000000000..7002cdc811 --- /dev/null +++ b/internal/small_rational/polynomial/multilin.go @@ -0,0 +1,176 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +package polynomial + +import ( + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/utils" + "math/bits" +) + +// MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial +// The variables are X₁ through Xₙ where n = log(len(.)) +// .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ) +// It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial +type MultiLin []small_rational.SmallRational + +// Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r +func (m *MultiLin) Fold(r small_rational.SmallRational) { + mid := len(*m) / 2 + + bottom, top := (*m)[:mid], (*m)[mid:] + + var t small_rational.SmallRational // no need to update the top part + + // updating bookkeeping table + // knowing that the polynomial f ∈ (k[X₂, ..., Xₙ])[X₁] is linear, we would get f(r) = f(0) + r(f(1) - f(0)) + // the following loop computes the evaluations of f(r) accordingly: + // f(r, b₂, ..., bₙ) = f(0, b₂, ..., bₙ) + r(f(1, b₂, ..., bₙ) - f(0, b₂, ..., bₙ)) + for i := 0; i < mid; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + t.Sub(&top[i], &bottom[i]) + t.Mul(&t, &r) + bottom[i].Add(&bottom[i], &t) + } + + *m = (*m)[:mid] +} + +func (m *MultiLin) FoldParallel(r small_rational.SmallRational) utils.Task { + mid := len(*m) / 2 + bottom, top := (*m)[:mid], (*m)[mid:] + + *m = bottom + + return func(start, end int) { + var t small_rational.SmallRational // no need to update the top part + for i := start; i < end; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + t.Sub(&top[i], &bottom[i]) + t.Mul(&t, &r) + bottom[i].Add(&bottom[i], &t) + } + } +} + +func (m MultiLin) Sum() small_rational.SmallRational { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + +// Evaluate extrapolate the value of the multilinear polynomial corresponding to m +// on the given coordinates +func (m MultiLin) Evaluate(coordinates []small_rational.SmallRational, p *Pool) small_rational.SmallRational { + // Folding is a mutating operation + bkCopy := _clone(m, p) + + // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) + for _, r := range coordinates { + bkCopy.Fold(r) + } + + result := bkCopy[0] + + _dump(bkCopy, p) + return result +} + +// Clone creates a deep copy of a bookkeeping table. +// Both multilinear interpolation and sumcheck require folding an underlying +// array, but folding changes the array. To do both one requires a deep copy +// of the bookkeeping table. +func (m MultiLin) Clone() MultiLin { + res := make(MultiLin, len(m)) + copy(res, m) + return res +} + +// Add two bookKeepingTables +func (m *MultiLin) Add(left, right MultiLin) { + size := len(left) + // Check that left and right have the same size + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") + } + + // Add elementwise + for i := 0; i < size; i++ { + (*m)[i].Add(&left[i], &right[i]) + } +} + +// EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) +// where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates +// +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// +// In other words the polynomial evaluated here is the multilinear extrapolation of +// one that evaluates to q' == h' for vectors q', h' of binary values +func EvalEq(q, h []small_rational.SmallRational) small_rational.SmallRational { + var res, nxt, one, sum small_rational.SmallRational + one.SetOne() + for i := 0; i < len(q); i++ { + nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ + nxt.Double(&nxt) // nxt <- 2 * qᵢ * hᵢ + nxt.Add(&nxt, &one) // nxt <- 1 + 2 * qᵢ * hᵢ + sum.Add(&q[i], &h[i]) // sum <- qᵢ + hᵢ TODO: Why not subtract one by one from nxt? More parallel? + + if i == 0 { + res.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + } else { + nxt.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + res.Mul(&res, &nxt) // res <- res * nxt + } + } + return res +} + +// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] +func (m *MultiLin) Eq(q []small_rational.SmallRational) { + n := len(q) + + if len(*m) != 1<= 0; i-- { + res.Mul(&res, v) + res.Add(&res, &(*p)[i]) + } + + return res +} + +// Clone returns a copy of the polynomial +func (p *Polynomial) Clone() Polynomial { + _p := make(Polynomial, len(*p)) + copy(_p, *p) + return _p +} + +// Set to another polynomial +func (p *Polynomial) Set(p1 Polynomial) { + if len(*p) != len(p1) { + *p = p1.Clone() + return + } + + for i := 0; i < len(p1); i++ { + (*p)[i].Set(&p1[i]) + } +} + +// AddConstantInPlace adds a constant to the polynomial, modifying p +func (p *Polynomial) AddConstantInPlace(c *small_rational.SmallRational) { + for i := 0; i < len(*p); i++ { + (*p)[i].Add(&(*p)[i], c) + } +} + +// SubConstantInPlace subs a constant to the polynomial, modifying p +func (p *Polynomial) SubConstantInPlace(c *small_rational.SmallRational) { + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&(*p)[i], c) + } +} + +// ScaleInPlace multiplies p by v, modifying p +func (p *Polynomial) ScaleInPlace(c *small_rational.SmallRational) { + for i := 0; i < len(*p); i++ { + (*p)[i].Mul(&(*p)[i], c) + } +} + +// Scale multiplies p0 by v, storing the result in p +func (p *Polynomial) Scale(c *small_rational.SmallRational, p0 Polynomial) { + if len(*p) != len(p0) { + *p = make(Polynomial, len(p0)) + } + for i := 0; i < len(p0); i++ { + (*p)[i].Mul(c, &p0[i]) + } +} + +// Add adds p1 to p2 +// This function allocates a new slice unless p == p1 or p == p2 +func (p *Polynomial) Add(p1, p2 Polynomial) *Polynomial { + + bigger := p1 + smaller := p2 + if len(bigger) < len(smaller) { + bigger, smaller = smaller, bigger + } + + if len(*p) == len(bigger) && (&(*p)[0] == &bigger[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &smaller[i]) + } + return p + } + + if len(*p) == len(smaller) && (&(*p)[0] == &smaller[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &bigger[i]) + } + *p = append(*p, bigger[len(smaller):]...) + return p + } + + res := make(Polynomial, len(bigger)) + copy(res, bigger) + for i := 0; i < len(smaller); i++ { + res[i].Add(&res[i], &smaller[i]) + } + *p = res + return p +} + +// Sub subtracts p2 from p1 +// TODO make interface more consistent with Add +func (p *Polynomial) Sub(p1, p2 Polynomial) *Polynomial { + if len(p1) != len(p2) || len(p2) != len(*p) { + return nil + } + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&p1[i], &p2[i]) + } + return p +} + +// Equal checks equality between two polynomials +func (p *Polynomial) Equal(p1 Polynomial) bool { + if (*p == nil) != (p1 == nil) { + return false + } + + if len(*p) != len(p1) { + return false + } + + for i := range p1 { + if !(*p)[i].Equal(&p1[i]) { + return false + } + } + + return true +} + +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() + } +} + +func (p Polynomial) Text(base int) string { + + var builder strings.Builder + + first := true + for d := len(p) - 1; d >= 0; d-- { + if p[d].IsZero() { + continue + } + + pD := p[d] + pDText := pD.Text(base) + + initialLen := builder.Len() + + if pDText[0] == '-' { + pDText = pDText[1:] + if first { + builder.WriteString("-") + } else { + builder.WriteString(" - ") + } + } else if !first { + builder.WriteString(" + ") + } + + first = false + + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) + } + + if builder.Len()-initialLen > 10 { + builder.WriteString("×") + } + + if d != 0 { + builder.WriteString("X") + } + if d > 1 { + builder.WriteString( + utils.ToSuperscript(strconv.Itoa(d)), + ) + } + + } + + if first { + return "0" + } + + return builder.String() +} + +// InterpolateOnRange maps vector v to polynomial f +// such that f(i) = v[i] for 0 ≤ i < len(v). +// len(f) = len(v) and deg(f) ≤ len(v) - 1 +func InterpolateOnRange(v []small_rational.SmallRational) Polynomial { + nEvals := uint8(len(v)) + if int(nEvals) != len(v) { + panic("interpolation method too inefficient for nEvals > 255") + } + lagrange := getLagrangeBasis(nEvals) + + var res Polynomial + res.Scale(&v[0], lagrange[0]) + + temp := make(Polynomial, nEvals) + + for i := uint8(1); i < nEvals; i++ { + temp.Scale(&v[i], lagrange[i]) + res.Add(res, temp) + } + + return res +} + +// lagrange bases used by InterpolateOnRange +var lagrangeBasis sync.Map + +func getLagrangeBasis(domainSize uint8) []Polynomial { + if res, ok := lagrangeBasis.Load(domainSize); ok { + return res.([]Polynomial) + } + + // not found. compute + var res []Polynomial + if domainSize >= 2 { + res = computeLagrangeBasis(domainSize) + } else if domainSize == 1 { + res = []Polynomial{make(Polynomial, 1)} + res[0][0].SetOne() + } + lagrangeBasis.Store(domainSize, res) + + return res +} + +// computeLagrangeBasis precomputes in explicit coefficient form for each 0 ≤ l < domainSize the polynomial +// pₗ := X (X-1) ... (X-l-1) (X-l+1) ... (X - domainSize + 1) / ( l (l-1) ... 2 (-1) ... (l - domainSize +1) ) +// Note that pₗ(l) = 1 and pₗ(n) = 0 if 0 ≤ l < domainSize, n ≠ l +func computeLagrangeBasis(domainSize uint8) []Polynomial { + + constTerms := make([]small_rational.SmallRational, domainSize) + for i := uint8(0); i < domainSize; i++ { + constTerms[i].SetInt64(-int64(i)) + } + + res := make([]Polynomial, domainSize) + multScratch := make(Polynomial, domainSize-1) + + // compute pₗ + for l := uint8(0); l < domainSize; l++ { + + // TODO @Tabaie Optimize this with some trees? O(log(domainSize)) polynomial mults instead of O(domainSize)? Then again it would be fewer big poly mults vs many small poly mults + d := uint8(0) //d is the current degree of res + for i := uint8(0); i < domainSize; i++ { + if i == l { + continue + } + if d == 0 { + res[l] = make(Polynomial, domainSize) + res[l][domainSize-2] = constTerms[i] + res[l][domainSize-1].SetOne() + } else { + current := res[l][domainSize-d-2:] + timesConst := multScratch[domainSize-d-2:] + + timesConst.Scale(&constTerms[i], current[1:]) //TODO: Directly double and add since constTerms are tiny? (even less than 4 bits) + nonLeading := current[0 : d+1] + + nonLeading.Add(nonLeading, timesConst) + + } + d++ + } + + } + + // We have pₗ(i≠l)=0. Now scale so that pₗ(l)=1 + // Replace the constTerms with norms + for l := uint8(0); l < domainSize; l++ { + constTerms[l].Neg(&constTerms[l]) + constTerms[l] = res[l].Eval(&constTerms[l]) + } + constTerms = small_rational.BatchInvert(constTerms) + for l := uint8(0); l < domainSize; l++ { + res[l].ScaleInPlace(&constTerms[l]) + } + + return res +} diff --git a/internal/small_rational/polynomial/pool.go b/internal/small_rational/polynomial/pool.go new file mode 100644 index 0000000000..333029f17d --- /dev/null +++ b/internal/small_rational/polynomial/pool.go @@ -0,0 +1,29 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +package polynomial + +import ( + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" +) + +// Do as little as possible to instantiate the interface +type Pool struct { +} + +func NewPool(...int) (pool Pool) { + return Pool{} +} + +func (p *Pool) Make(n int) []small_rational.SmallRational { + return make([]small_rational.SmallRational, n) +} + +func (p *Pool) Dump(...[]small_rational.SmallRational) { +} + +func (p *Pool) Clone(slice []small_rational.SmallRational) []small_rational.SmallRational { + res := p.Make(len(slice)) + copy(res, slice) + return res +} diff --git a/internal/small_rational/small-rational.go b/internal/small_rational/small-rational.go new file mode 100644 index 0000000000..39ee2bfe2e --- /dev/null +++ b/internal/small_rational/small-rational.go @@ -0,0 +1,454 @@ +package small_rational + +import ( + "crypto/rand" + "fmt" + "math/big" + "strconv" + "strings" +) + +const Bytes = 64 + +// SmallRational implements the rational field, used to generate field agnostic test vectors. +// It is not optimized for performance, so it is best used sparingly. +type SmallRational struct { + text string //For debugging purposes + numerator big.Int + denominator big.Int // By convention, denominator == 0 also indicates zero +} + +var smallPrimes = []*big.Int{ + big.NewInt(2), big.NewInt(3), big.NewInt(5), + big.NewInt(7), big.NewInt(11), big.NewInt(13), +} + +func bigDivides(p, a *big.Int) bool { + var remainder big.Int + remainder.Mod(a, p) + return remainder.BitLen() == 0 +} + +func (z *SmallRational) UpdateText() { + z.text = z.Text(10) +} + +func (z *SmallRational) simplify() { + + if z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 { + return + } + + var num, den big.Int + + num.Set(&z.numerator) + den.Set(&z.denominator) + + for _, p := range smallPrimes { + for bigDivides(p, &num) && bigDivides(p, &den) { + num.Div(&num, p) + den.Div(&den, p) + } + } + + if bigDivides(&den, &num) { + num.Div(&num, &den) + den.SetInt64(1) + } + + z.numerator = num + z.denominator = den + +} +func (z *SmallRational) Square(x *SmallRational) *SmallRational { + var num, den big.Int + num.Mul(&x.numerator, &x.numerator) + den.Mul(&x.denominator, &x.denominator) + + z.numerator = num + z.denominator = den + + z.UpdateText() + + return z +} + +func (z *SmallRational) String() string { + z.text = z.Text(10) + return z.text +} + +func (z *SmallRational) Add(x, y *SmallRational) *SmallRational { + if x.denominator.BitLen() == 0 { + *z = *y + } else if y.denominator.BitLen() == 0 { + *z = *x + } else { + //TODO: Exploit cases where one denom divides the other + var numDen, denNum big.Int + numDen.Mul(&x.numerator, &y.denominator) + denNum.Mul(&x.denominator, &y.numerator) + + numDen.Add(&denNum, &numDen) + z.numerator = numDen //to avoid shallow copy problems + + denNum.Mul(&x.denominator, &y.denominator) + z.denominator = denNum + z.simplify() + } + + z.UpdateText() + + return z +} + +func (z *SmallRational) IsZero() bool { + return z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 +} + +func (z *SmallRational) Inverse(x *SmallRational) *SmallRational { + if x.IsZero() { + *z = *x + } else { + *z = SmallRational{numerator: x.denominator, denominator: x.numerator} + z.UpdateText() + } + + return z +} + +func (z *SmallRational) Neg(x *SmallRational) *SmallRational { + z.numerator.Neg(&x.numerator) + z.denominator = x.denominator + + if x.text == "" { + x.UpdateText() + } + + if x.text[0] == '-' { + z.text = x.text[1:] + } else { + z.text = "-" + x.text + } + + return z +} + +func (z *SmallRational) Double(x *SmallRational) *SmallRational { + + var y big.Int + + if x.denominator.Bit(0) == 0 { + z.numerator = x.numerator + y.Rsh(&x.denominator, 1) + z.denominator = y + } else { + y.Lsh(&x.numerator, 1) + z.numerator = y + z.denominator = x.denominator + } + + z.UpdateText() + + return z +} + +func (z *SmallRational) Sign() int { + return z.numerator.Sign() * z.denominator.Sign() +} + +func (z *SmallRational) MarshalJSON() ([]byte, error) { + return []byte(z.String()), nil +} + +func (z *SmallRational) UnmarshalJson(data []byte) error { + _, err := z.SetInterface(string(data)) + return err +} + +func (z *SmallRational) Equal(x *SmallRational) bool { + return z.Cmp(x) == 0 +} + +func (z *SmallRational) Sub(x, y *SmallRational) *SmallRational { + var yNeg SmallRational + yNeg.Neg(y) + z.Add(x, &yNeg) + + z.UpdateText() + return z +} + +func (z *SmallRational) Cmp(x *SmallRational) int { + zSign, xSign := z.Sign(), x.Sign() + + if zSign > xSign { + return 1 + } + if zSign < xSign { + return -1 + } + + var Z, X big.Int + Z.Mul(&z.numerator, &x.denominator) + X.Mul(&x.numerator, &z.denominator) + + Z.Abs(&Z) + X.Abs(&X) + + return Z.Cmp(&X) * zSign + +} + +func BatchInvert(a []SmallRational) []SmallRational { + res := make([]SmallRational, len(a)) + for i := range a { + res[i].Inverse(&a[i]) + } + return res +} + +func (z *SmallRational) Mul(x, y *SmallRational) *SmallRational { + var num, den big.Int + + num.Mul(&x.numerator, &y.numerator) + den.Mul(&x.denominator, &y.denominator) + + z.numerator = num + z.denominator = den + + z.simplify() + z.UpdateText() + return z +} + +func (z *SmallRational) Div(x, y *SmallRational) *SmallRational { + var num, den big.Int + + num.Mul(&x.numerator, &y.denominator) + den.Mul(&x.denominator, &y.numerator) + + z.numerator = num + z.denominator = den + + z.simplify() + z.UpdateText() + return z +} + +func (z *SmallRational) Halve() *SmallRational { + if z.numerator.Bit(0) == 0 { + z.numerator.Rsh(&z.numerator, 1) + } else { + z.denominator.Lsh(&z.denominator, 1) + } + + z.simplify() + z.UpdateText() + return z +} + +func (z *SmallRational) SetOne() *SmallRational { + return z.SetInt64(1) +} + +func (z *SmallRational) SetZero() *SmallRational { + return z.SetInt64(0) +} + +func (z *SmallRational) SetInt64(i int64) *SmallRational { + z.numerator = *big.NewInt(i) + z.denominator = *big.NewInt(1) + z.text = strconv.FormatInt(i, 10) + return z +} + +func (z *SmallRational) SetRandom() (*SmallRational, error) { + + bytes := make([]byte, 1) + n, err := rand.Read(bytes) + if err != nil { + return nil, err + } + if n != len(bytes) { + return nil, fmt.Errorf("%d bytes read instead of %d", n, len(bytes)) + } + + z.numerator = *big.NewInt(int64(bytes[0]%16) - 8) + z.denominator = *big.NewInt(int64((bytes[0]) / 16)) + + z.simplify() + z.UpdateText() + + return z, nil +} + +func (z *SmallRational) MustSetRandom() *SmallRational { + if _, err := z.SetRandom(); err != nil { + panic(err) + } + return z +} + +func (z *SmallRational) SetUint64(i uint64) { + var num big.Int + num.SetUint64(i) + z.numerator = num + z.denominator = *big.NewInt(1) + z.text = strconv.FormatUint(i, 10) +} + +func (z *SmallRational) IsOne() bool { + return z.numerator.Cmp(&z.denominator) == 0 && z.denominator.BitLen() != 0 +} + +func (z *SmallRational) Text(base int) string { + + if z.denominator.BitLen() == 0 { + return "0" + } + + if z.denominator.Sign() < 0 { + var num, den big.Int + num.Neg(&z.numerator) + den.Neg(&z.denominator) + z.numerator = num + z.denominator = den + } + + if bigDivides(&z.denominator, &z.numerator) { + var num big.Int + num.Div(&z.numerator, &z.denominator) + z.numerator = num + z.denominator = *big.NewInt(1) + } + + numerator := z.numerator.Text(base) + + if z.denominator.IsInt64() && z.denominator.Int64() == 1 { + return numerator + } + + return numerator + "/" + z.denominator.Text(base) +} + +func (z *SmallRational) Set(x *SmallRational) *SmallRational { + *z = *x // shallow copy is safe because ops are never in place + return z +} + +func (z *SmallRational) SetInterface(x interface{}) (*SmallRational, error) { + + switch v := x.(type) { + case *SmallRational: + *z = *v + case SmallRational: + *z = v + case int64: + z.SetInt64(v) + case int: + z.SetInt64(int64(v)) + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + case string: + z.text = v + sep := strings.Split(v, "/") + switch len(sep) { + case 1: + if asInt, err := strconv.Atoi(sep[0]); err == nil { + z.SetInt64(int64(asInt)) + } else { + return nil, err + } + case 2: + var err error + var num, denom int + num, err = strconv.Atoi(sep[0]) + if err != nil { + return nil, err + } + denom, err = strconv.Atoi(sep[1]) + if err != nil { + return nil, err + } + z.numerator = *big.NewInt(int64(num)) + z.denominator = *big.NewInt(int64(denom)) + default: + return nil, fmt.Errorf("cannot parse \"%s\"", v) + } + default: + return nil, fmt.Errorf("cannot parse %T", x) + } + + return z, nil +} + +func bigIntToBytesSigned(dst []byte, src big.Int) { + src.FillBytes(dst[1:]) + dst[0] = 0 + if src.Sign() < 0 { + dst[0] = 255 + } +} + +func (z *SmallRational) Bytes() [Bytes]byte { + var res [Bytes]byte + bigIntToBytesSigned(res[:Bytes/2], z.numerator) + bigIntToBytesSigned(res[Bytes/2:], z.denominator) + return res +} + +func bytesToBigIntSigned(src []byte) big.Int { + var res big.Int + res.SetBytes(src[1:]) + if src[0] != 0 { + res.Neg(&res) + } + return res +} + +// BigInt returns sets dst to the value of z if it is an integer. +// if z is not an integer, nil is returned. +// if the given dst is nil, the address of the numerator is returned. +// if the given dst is non-nil, it is returned. +func (z *SmallRational) BigInt(dst *big.Int) *big.Int { + if z.denominator.Cmp(big.NewInt(1)) != 0 { + return nil + } + if dst == nil { + return &z.numerator + } + dst.Set(&z.numerator) + return dst +} + +func (z *SmallRational) SetBytes(b []byte) { + if len(b) > Bytes/2 { + z.numerator = bytesToBigIntSigned(b[:Bytes/2]) + z.denominator = bytesToBigIntSigned(b[Bytes/2:]) + } else { + z.numerator.SetBytes(b) + z.denominator.SetInt64(1) + } + z.simplify() + z.UpdateText() +} + +func One() SmallRational { + res := SmallRational{ + text: "1", + } + res.numerator.SetInt64(1) + res.denominator.SetInt64(1) + return res +} + +func Modulus() *big.Int { + res := big.NewInt(1) + res.Lsh(res, 64) + return res +} diff --git a/internal/small_rational/small_rational_test.go b/internal/small_rational/small_rational_test.go new file mode 100644 index 0000000000..6d7733ea76 --- /dev/null +++ b/internal/small_rational/small_rational_test.go @@ -0,0 +1,115 @@ +package small_rational + +import ( + "github.com/stretchr/testify/assert" + "math/big" + "testing" +) + +func TestBigDivides(t *testing.T) { + assert.True(t, bigDivides(big.NewInt(-1), big.NewInt(4))) + assert.False(t, bigDivides(big.NewInt(-3), big.NewInt(4))) +} + +func TestCmp(t *testing.T) { + + cases := make([]SmallRational, 36) + + for i := int64(0); i < 9; i++ { + if i%2 == 0 { + cases[4*i].numerator.SetInt64((i - 4) / 2) + cases[4*i].denominator.SetInt64(1) + } else { + cases[4*i].numerator.SetInt64(i - 4) + cases[4*i].denominator.SetInt64(2) + } + + cases[4*i+1].numerator.Neg(&cases[4*i].numerator) + cases[4*i+1].denominator.Neg(&cases[4*i].denominator) + + cases[4*i+2].numerator.Lsh(&cases[4*i].numerator, 1) + cases[4*i+2].denominator.Lsh(&cases[4*i].denominator, 1) + + cases[4*i+3].numerator.Neg(&cases[4*i+2].numerator) + cases[4*i+3].denominator.Neg(&cases[4*i+2].denominator) + } + + for i := range cases { + for j := range cases { + I, J := i/4, j/4 + var expectedCmp int + cmp := cases[i].Cmp(&cases[j]) + if I < J { + expectedCmp = -1 + } else if I == J { + expectedCmp = 0 + } else { + expectedCmp = 1 + } + assert.Equal(t, expectedCmp, cmp, "comparing index %d, index %d", i, j) + } + } + + zeroIndex := len(cases) / 8 + var weirdZero SmallRational + for i := range cases { + I := i / 4 + var expectedCmp int + cmp := cases[i].Cmp(&weirdZero) + cmpNeg := weirdZero.Cmp(&cases[i]) + if I < zeroIndex { + expectedCmp = -1 + } else if I == zeroIndex { + expectedCmp = 0 + } else { + expectedCmp = 1 + } + + assert.Equal(t, expectedCmp, cmp, "comparing index %d, 0/0", i) + assert.Equal(t, -expectedCmp, cmpNeg, "comparing 0/0, index %d", i) + } +} + +func TestDouble(t *testing.T) { + values := []interface{}{1, 2, 3, 4, 5, "2/3", "3/2", "-3/-2"} + valsDoubled := []interface{}{2, 4, 6, 8, 10, "-4/-3", 3, 3} + + for i := range values { + var v, vDoubled, vDoubledExpected SmallRational + _, err := v.SetInterface(values[i]) + assert.NoError(t, err) + _, err = vDoubledExpected.SetInterface(valsDoubled[i]) + assert.NoError(t, err) + vDoubled.Double(&v) + assert.True(t, vDoubled.Equal(&vDoubledExpected), + "mismatch at %d: expected 2×%s = %s, saw %s", i, v.text, vDoubledExpected.text, vDoubled.text) + + } +} + +func TestOperandConstancy(t *testing.T) { + var p0, p, pPure SmallRational + p0.SetInt64(1) + p.SetInt64(-3) + pPure.SetInt64(-3) + + res := p + res.Add(&res, &p0) + assert.True(t, p.Equal(&pPure)) +} + +func TestSquare(t *testing.T) { + var two, four, x SmallRational + two.SetInt64(2) + four.SetInt64(4) + + x.Square(&two) + + assert.True(t, x.Equal(&four), "expected 4, saw %s", x.Text(10)) +} + +func TestSetBytes(t *testing.T) { + var c SmallRational + c.SetBytes([]byte("firstChallenge.0")) + +} diff --git a/internal/small_rational/test_vector_utils/test_vector_utils.go b/internal/small_rational/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..9e91fe7c67 --- /dev/null +++ b/internal/small_rational/test_vector_utils/test_vector_utils.go @@ -0,0 +1,185 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "hash" + "reflect" +) + +func ToElement(i int64) *small_rational.SmallRational { + var res small_rational.SmallRational + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res small_rational.SmallRational + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return small_rational.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return small_rational.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []small_rational.SmallRational + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return small_rational.Bytes +} + +func (h *ListHash) BlockSize() int { + return small_rational.Bytes +} + +func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *small_rational.SmallRational) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(small_rational.SmallRational) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/small_rational/vector.go b/internal/small_rational/vector.go new file mode 100644 index 0000000000..07fcc3afff --- /dev/null +++ b/internal/small_rational/vector.go @@ -0,0 +1,9 @@ +package small_rational + +type Vector []SmallRational + +func (v Vector) MustSetRandom() { + for i := range v { + v[i].MustSetRandom() + } +} From a02d4b33d1eaf0448b38bf4763cd5016084f1619 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 1 Apr 2025 17:29:25 -0500 Subject: [PATCH 33/62] feat: generate sumcheck test vecs --- internal/generator/backend/main.go | 111 ++++++---- .../backend/sumcheck/test_vectors/main.go | 8 +- .../gkr/bls12-377/sumcheck/sumcheck_test.go | 1 - .../gkr/bls12-381/sumcheck/sumcheck_test.go | 1 - .../gkr/bls24-315/sumcheck/sumcheck_test.go | 1 - .../gkr/bls24-317/sumcheck/sumcheck_test.go | 1 - internal/gkr/bn254/sumcheck/sumcheck_test.go | 1 - .../gkr/bw6-633/sumcheck/sumcheck_test.go | 1 - .../gkr/bw6-761/sumcheck/sumcheck_test.go | 1 - .../gkr/small_rational/sumcheck/sumcheck.go | 170 +++++++++++++++ .../small_rational/sumcheck/sumcheck_test.go | 149 +++++++++++++ .../test_vector_utils/test_vector_utils.go | 185 ++++++++++++++++ internal/gkr/test_vectors/sumcheck/main.go | 200 ++++++++++++++++++ .../gkr/test_vectors/sumcheck/vectors.json | 56 +++++ .../small_rational/polynomial/multilin.go | 2 +- .../small_rational/polynomial/polynomial.go | 2 +- .../test_vector_utils/test_vector_utils.go | 5 +- 17 files changed, 840 insertions(+), 55 deletions(-) create mode 100644 internal/gkr/small_rational/sumcheck/sumcheck.go create mode 100644 internal/gkr/small_rational/sumcheck/sumcheck_test.go create mode 100644 internal/gkr/small_rational/test_vector_utils/test_vector_utils.go create mode 100644 internal/gkr/test_vectors/sumcheck/main.go create mode 100644 internal/gkr/test_vectors/sumcheck/vectors.json diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 81ebfecfa1..cfe3c62cb8 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -90,11 +90,7 @@ func main() { tiny_field, } - const ( - importCurve = "../imports.go.tmpl" - repoRoot = "../../../" - ) - + const importCurve = "../imports.go.tmpl" var wg sync.WaitGroup for _, d := range data { @@ -134,43 +130,16 @@ func main() { if d.Curve != "tinyfield" { // solver and proof delegator TODO merge with "backend" below entries = []bavard.Entry{{File: filepath.Join(csDir, "gkr.go"), Templates: []string{"gkr.go.tmpl", importCurve}}} - if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { - panic(err) - } + err := bgen.Generate(d, "cs", "./template/representations/", entries...) + assertNoError(err) curvePackageName := strings.ToLower(d.Curve) - cfg := struct { - config.FieldDependency - GkrPackagePath string - }{ - config.FieldDependency{ - ElementType: "fr.Element", - FieldPackageName: "fr", - FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + curvePackageName + "/fr", - }, - "github.com/consensys/gnark/internal/gkr/" + curvePackageName, - } - gkrPackageDirRelPath := filepath.Join(repoRoot+"internal/gkr/", curvePackageName) - - // test vector utils - packagePath := filepath.Join(gkrPackageDirRelPath, "test_vector_utils") - entries = []bavard.Entry{ - {File: filepath.Join(packagePath, "test_vector_utils.go"), Templates: []string{"test_vector_utils.go.tmpl"}}, - } - - if err := bgen.Generate(cfg, "test_vector_utils", "./template/gkr/", entries...); err != nil { - panic(err) - } - - // sumcheck backend - packagePath = filepath.Join(gkrPackageDirRelPath, "sumcheck") - entries = []bavard.Entry{ - {File: filepath.Join(packagePath, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, - {File: filepath.Join(packagePath, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl"}}, - } - if err := bgen.Generate(cfg, "sumcheck", "./template/gkr/", entries...); err != nil { - panic(err) - } + err = generateGkrBackend(config.FieldDependency{ + ElementType: "fr.Element", + FieldPackageName: "fr", + FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + curvePackageName + "/fr", + }, curvePackageName) + assertNoError(err) } entries = []bavard.Entry{ @@ -241,6 +210,26 @@ func main() { } + wg.Add(1) + // GKR test vectors + go func() { + // generate sumcheck for small-rational + err := generateGkrBackend(config.FieldDependency{ + ElementType: "small_rational.SmallRational", + FieldPackagePath: "github.com/consensys/gnark/internal/small_rational", + FieldPackageName: "small_rational", + }, "small_rational") + assertNoError(err) + + // generate test vectors for sumcheck + cmd := exec.Command("go", "run", "./sumcheck/test_vectors") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + assertNoError(cmd.Run()) + + wg.Done() + }() + wg.Wait() // run go fmt on whole directory @@ -261,3 +250,45 @@ type templateData struct { noBackend bool NoGKR bool } + +func generateGkrBackend(fieldDep config.FieldDependency, curvePackageName string) error { + const repoRoot = "../../../" + + gkrPackageDirRelPath := filepath.Join(repoRoot+"internal/gkr/", curvePackageName) + + cfg := struct { + config.FieldDependency + GkrPackagePath string + }{ + fieldDep, + "github.com/consensys/gnark/internal/gkr/" + curvePackageName, + } + + // test vector utils + packagePath := filepath.Join(gkrPackageDirRelPath, "test_vector_utils") + entries := []bavard.Entry{ + {File: filepath.Join(packagePath, "test_vector_utils.go"), Templates: []string{"test_vector_utils.go.tmpl"}}, + } + + if err := bgen.Generate(cfg, "test_vector_utils", "./template/gkr/", entries...); err != nil { + return err + } + + // sumcheck backend + packagePath = filepath.Join(gkrPackageDirRelPath, "sumcheck") + entries = []bavard.Entry{ + {File: filepath.Join(packagePath, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, + {File: filepath.Join(packagePath, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl"}}, + } + if err := bgen.Generate(cfg, "sumcheck", "./template/gkr/", entries...); err != nil { + return err + } + + return nil +} + +func assertNoError(err error) { + if err != nil { + panic(err) + } +} diff --git a/internal/generator/backend/sumcheck/test_vectors/main.go b/internal/generator/backend/sumcheck/test_vectors/main.go index 798f5a4f3f..8a5c3f867e 100644 --- a/internal/generator/backend/sumcheck/test_vectors/main.go +++ b/internal/generator/backend/sumcheck/test_vectors/main.go @@ -4,10 +4,10 @@ import ( "encoding/json" "fmt" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" "hash" "math/bits" "os" diff --git a/internal/gkr/bls12-377/sumcheck/sumcheck_test.go b/internal/gkr/bls12-377/sumcheck/sumcheck_test.go index 00d6ffdf28..a9d152c10a 100644 --- a/internal/gkr/bls12-377/sumcheck/sumcheck_test.go +++ b/internal/gkr/bls12-377/sumcheck/sumcheck_test.go @@ -120,7 +120,6 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) polys := [][]uint64{ {1, 2, 3, 4}, // 1 + 2X₁ + X₂ diff --git a/internal/gkr/bls12-381/sumcheck/sumcheck_test.go b/internal/gkr/bls12-381/sumcheck/sumcheck_test.go index 40664ee4eb..4d98d79437 100644 --- a/internal/gkr/bls12-381/sumcheck/sumcheck_test.go +++ b/internal/gkr/bls12-381/sumcheck/sumcheck_test.go @@ -120,7 +120,6 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) polys := [][]uint64{ {1, 2, 3, 4}, // 1 + 2X₁ + X₂ diff --git a/internal/gkr/bls24-315/sumcheck/sumcheck_test.go b/internal/gkr/bls24-315/sumcheck/sumcheck_test.go index f1a86c12f4..f41552f57c 100644 --- a/internal/gkr/bls24-315/sumcheck/sumcheck_test.go +++ b/internal/gkr/bls24-315/sumcheck/sumcheck_test.go @@ -120,7 +120,6 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) polys := [][]uint64{ {1, 2, 3, 4}, // 1 + 2X₁ + X₂ diff --git a/internal/gkr/bls24-317/sumcheck/sumcheck_test.go b/internal/gkr/bls24-317/sumcheck/sumcheck_test.go index 0efca63df7..7053f04844 100644 --- a/internal/gkr/bls24-317/sumcheck/sumcheck_test.go +++ b/internal/gkr/bls24-317/sumcheck/sumcheck_test.go @@ -120,7 +120,6 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) polys := [][]uint64{ {1, 2, 3, 4}, // 1 + 2X₁ + X₂ diff --git a/internal/gkr/bn254/sumcheck/sumcheck_test.go b/internal/gkr/bn254/sumcheck/sumcheck_test.go index cd7259736e..8053589b35 100644 --- a/internal/gkr/bn254/sumcheck/sumcheck_test.go +++ b/internal/gkr/bn254/sumcheck/sumcheck_test.go @@ -120,7 +120,6 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) polys := [][]uint64{ {1, 2, 3, 4}, // 1 + 2X₁ + X₂ diff --git a/internal/gkr/bw6-633/sumcheck/sumcheck_test.go b/internal/gkr/bw6-633/sumcheck/sumcheck_test.go index 403839293f..4c740ab0ec 100644 --- a/internal/gkr/bw6-633/sumcheck/sumcheck_test.go +++ b/internal/gkr/bw6-633/sumcheck/sumcheck_test.go @@ -120,7 +120,6 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) polys := [][]uint64{ {1, 2, 3, 4}, // 1 + 2X₁ + X₂ diff --git a/internal/gkr/bw6-761/sumcheck/sumcheck_test.go b/internal/gkr/bw6-761/sumcheck/sumcheck_test.go index 2f95dc376e..d6f520fc19 100644 --- a/internal/gkr/bw6-761/sumcheck/sumcheck_test.go +++ b/internal/gkr/bw6-761/sumcheck/sumcheck_test.go @@ -120,7 +120,6 @@ func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash } func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - //printMsws(36) polys := [][]uint64{ {1, 2, 3, 4}, // 1 + 2X₁ + X₂ diff --git a/internal/gkr/small_rational/sumcheck/sumcheck.go b/internal/gkr/small_rational/sumcheck/sumcheck.go new file mode 100644 index 0000000000..e491815a87 --- /dev/null +++ b/internal/gkr/small_rational/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []small_rational.SmallRational) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a small_rational.SmallRational) small_rational.SmallRational // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []small_rational.SmallRational, remainingChallengeNames *[]string) (small_rational.SmallRational, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return small_rational.SmallRational{}, err + } + } + var res small_rational.SmallRational + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff small_rational.SmallRational + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]small_rational.SmallRational, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff small_rational.SmallRational + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]small_rational.SmallRational, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/small_rational/sumcheck/sumcheck_test.go b/internal/gkr/small_rational/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..c2166b7c12 --- /dev/null +++ b/internal/gkr/small_rational/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/small_rational/test_vector_utils" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []small_rational.SmallRational{sum} +} + +func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum small_rational.SmallRational +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs small_rational.SmallRational) small_rational.SmallRational { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/small_rational/test_vector_utils/test_vector_utils.go b/internal/gkr/small_rational/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..3102a2133d --- /dev/null +++ b/internal/gkr/small_rational/test_vector_utils/test_vector_utils.go @@ -0,0 +1,185 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "hash" + "reflect" +) + +func ToElement(i int64) *small_rational.SmallRational { + var res small_rational.SmallRational + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res small_rational.SmallRational + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return small_rational.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return small_rational.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []small_rational.SmallRational + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return small_rational.Bytes +} + +func (h *ListHash) BlockSize() int { + return small_rational.Bytes +} + +func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *small_rational.SmallRational) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(small_rational.SmallRational) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/test_vectors/sumcheck/main.go b/internal/gkr/test_vectors/sumcheck/main.go new file mode 100644 index 0000000000..765da2275a --- /dev/null +++ b/internal/gkr/test_vectors/sumcheck/main.go @@ -0,0 +1,200 @@ +package main + +import ( + "encoding/json" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" + "hash" + "math/bits" + "os" + "path/filepath" +) + +func runMultilin(testCaseInfo *TestCaseInfo) error { + + var poly polynomial.MultiLin + if v, err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); err == nil { + poly = v + } else { + return err + } + + var hsh hash.Hash + var err error + if hsh, err = test_vector_utils.HashFromDescription(testCaseInfo.Hash); err != nil { + return err + } + + proof, err := sumcheck.Prove( + &singleMultilinClaim{poly}, fiatshamir.WithHash(hsh)) + if err != nil { + return err + } + testCaseInfo.Proof = toPrintableProof(proof) + + // Verification + if v, _err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); _err == nil { + poly = v + } else { + return _err + } + var claimedSum small_rational.SmallRational + if _, err = claimedSum.SetInterface(testCaseInfo.ClaimedSum); err != nil { + return err + } + + if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err != nil { + return fmt.Errorf("proof rejected: %v", err) + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func run(testCaseInfo *TestCaseInfo) error { + switch testCaseInfo.Type { + case "multilin": + return runMultilin(testCaseInfo) + default: + return fmt.Errorf("type \"%s\" unrecognized", testCaseInfo.Type) + } +} + +func runAll(relPath string) error { + var filename string + var err error + if filename, err = filepath.Abs(relPath); err != nil { + return err + } + + var bytes []byte + + if bytes, err = os.ReadFile(filename); err != nil { + return err + } + + var testCasesInfo TestCasesInfo + if err = json.Unmarshal(bytes, &testCasesInfo); err != nil { + return err + } + + failed := false + for name, testCase := range testCasesInfo { + if err = run(testCase); err != nil { + fmt.Println(name, ":", err) + failed = true + } + } + + if failed { + return fmt.Errorf("test case failed") + } + + if bytes, err = json.MarshalIndent(testCasesInfo, "", "\t"); err != nil { + return err + } + + return os.WriteFile(filename, bytes, 0) +} + +func main() { + // read the test vectors file, generate the proof, make sure it verifies, + // and add the proof to the same file + if err := runAll("sumcheck/test_vectors/vectors.json"); err != nil { + fmt.Println(err) + os.Exit(-1) + } +} + +type TestCasesInfo map[string]*TestCaseInfo + +type TestCaseInfo struct { + Type string `json:"type"` + Hash test_vector_utils.HashDescription `json:"hash"` + Values []interface{} `json:"values"` + Description string `json:"description"` + Proof PrintableProof `json:"proof"` + ClaimedSum interface{} `json:"claimedSum"` +} + +type PrintableProof struct { + PartialSumPolys [][]interface{} `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` +} + +func toPrintableProof(proof sumcheck.Proof) (printable PrintableProof) { + if proof.FinalEvalProof != nil { + panic("null expected") + } + printable.FinalEvalProof = struct{}{} + printable.PartialSumPolys = test_vector_utils.ElementSliceSliceToInterfaceSliceSlice(proof.PartialSumPolys) + return +} + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval([]small_rational.SmallRational) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []small_rational.SmallRational{sum} +} + +func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum small_rational.SmallRational +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, _ small_rational.SmallRational, purportedValue small_rational.SmallRational, _ interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(small_rational.SmallRational) small_rational.SmallRational { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} diff --git a/internal/gkr/test_vectors/sumcheck/vectors.json b/internal/gkr/test_vectors/sumcheck/vectors.json new file mode 100644 index 0000000000..64b8e3fb2d --- /dev/null +++ b/internal/gkr/test_vectors/sumcheck/vectors.json @@ -0,0 +1,56 @@ +{ + "linear_univariate_single_claim": { + "type": "multilin", + "hash": { + "type": "const", + "val": -1 + }, + "values": [ + 1, + 3 + ], + "description": "X ↦ 2X + 1", + "proof": { + "partialSumPolys": [ + [ + 3 + ] + ], + "finalEvalProof": {} + }, + "claimedSum": 4 + }, + "trilinear_single_claim": { + "type": "multilin", + "hash": { + "type": "const", + "val": -1 + }, + "values": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8 + ], + "description": "X₁, X₂, X₃ ↦ 1 + 4X₁ + 2X₂ + X₃", + "proof": { + "partialSumPolys": [ + [ + 26 + ], + [ + -1 + ], + [ + -4 + ] + ], + "finalEvalProof": {} + }, + "claimedSum": 36 + } +} \ No newline at end of file diff --git a/internal/small_rational/polynomial/multilin.go b/internal/small_rational/polynomial/multilin.go index 7002cdc811..6bb2d8916e 100644 --- a/internal/small_rational/polynomial/multilin.go +++ b/internal/small_rational/polynomial/multilin.go @@ -4,8 +4,8 @@ package polynomial import ( - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/small_rational" "math/bits" ) diff --git a/internal/small_rational/polynomial/polynomial.go b/internal/small_rational/polynomial/polynomial.go index ae50b2c07d..61d287ccdd 100644 --- a/internal/small_rational/polynomial/polynomial.go +++ b/internal/small_rational/polynomial/polynomial.go @@ -4,8 +4,8 @@ package polynomial import ( - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/small_rational" "strconv" "strings" "sync" diff --git a/internal/small_rational/test_vector_utils/test_vector_utils.go b/internal/small_rational/test_vector_utils/test_vector_utils.go index 9e91fe7c67..9459281e09 100644 --- a/internal/small_rational/test_vector_utils/test_vector_utils.go +++ b/internal/small_rational/test_vector_utils/test_vector_utils.go @@ -7,8 +7,9 @@ package test_vector_utils import ( "fmt" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "hash" "reflect" ) From 7edcf03f5ea2bc2396c520f4f67f972f5c8b1aa6 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 1 Apr 2025 19:52:24 -0500 Subject: [PATCH 34/62] fix: codegen error --- internal/generator/backend/main.go | 3 ++- internal/small_rational/polynomial/pool.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index cfe3c62cb8..24f255710e 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "os" "os/exec" "path/filepath" @@ -221,7 +222,7 @@ func main() { }, "small_rational") assertNoError(err) - // generate test vectors for sumcheck + fmt.Println("generating test vectors for sumcheck") cmd := exec.Command("go", "run", "./sumcheck/test_vectors") cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr diff --git a/internal/small_rational/polynomial/pool.go b/internal/small_rational/polynomial/pool.go index 333029f17d..bc855ef5d4 100644 --- a/internal/small_rational/polynomial/pool.go +++ b/internal/small_rational/polynomial/pool.go @@ -4,7 +4,7 @@ package polynomial import ( - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark/internal/small_rational" ) // Do as little as possible to instantiate the interface From f01b8b680100f045bcb1f862c37ac85af71a2096 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 1 Apr 2025 20:33:43 -0500 Subject: [PATCH 35/62] perf: run test vec gen in same process --- internal/generator/backend/main.go | 6 ++-- .../{main.go => sumcheck-gen-vectors.go} | 30 +++++++++++-------- 2 files changed, 19 insertions(+), 17 deletions(-) rename internal/gkr/test_vectors/sumcheck/{main.go => sumcheck-gen-vectors.go} (95%) diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 24f255710e..49b67c0522 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + sumcheckTestVectors "github.com/consensys/gnark/internal/gkr/test_vectors/sumcheck" "os" "os/exec" "path/filepath" @@ -223,10 +224,7 @@ func main() { assertNoError(err) fmt.Println("generating test vectors for sumcheck") - cmd := exec.Command("go", "run", "./sumcheck/test_vectors") - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - assertNoError(cmd.Run()) + assertNoError(sumcheckTestVectors.Generate()) wg.Done() }() diff --git a/internal/gkr/test_vectors/sumcheck/main.go b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go similarity index 95% rename from internal/gkr/test_vectors/sumcheck/main.go rename to internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go index 765da2275a..7917a3b60c 100644 --- a/internal/gkr/test_vectors/sumcheck/main.go +++ b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go @@ -1,9 +1,10 @@ -package main +package sumcheck import ( "encoding/json" "fmt" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" "github.com/consensys/gnark/internal/small_rational" "github.com/consensys/gnark/internal/small_rational/polynomial" "github.com/consensys/gnark/internal/small_rational/test_vector_utils" @@ -11,6 +12,7 @@ import ( "math/bits" "os" "path/filepath" + "runtime/pprof" ) func runMultilin(testCaseInfo *TestCaseInfo) error { @@ -22,8 +24,11 @@ func runMultilin(testCaseInfo *TestCaseInfo) error { return err } - var hsh hash.Hash - var err error + var ( + hsh hash.Hash + err error + ) + if hsh, err = test_vector_utils.HashFromDescription(testCaseInfo.Hash); err != nil { return err } @@ -54,6 +59,10 @@ func runMultilin(testCaseInfo *TestCaseInfo) error { if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err == nil { return fmt.Errorf("bad proof accepted") } + + pprof.StopCPUProfile() + //return f.Close() + return nil } @@ -66,7 +75,11 @@ func run(testCaseInfo *TestCaseInfo) error { } } -func runAll(relPath string) error { +func Generate() error { + // read the test vectors file, generate the proof, make sure it verifies, + // and add the proof to the same file + const relPath = "sumcheck/test_vectors/vectors.json" + var filename string var err error if filename, err = filepath.Abs(relPath); err != nil { @@ -103,15 +116,6 @@ func runAll(relPath string) error { return os.WriteFile(filename, bytes, 0) } -func main() { - // read the test vectors file, generate the proof, make sure it verifies, - // and add the proof to the same file - if err := runAll("sumcheck/test_vectors/vectors.json"); err != nil { - fmt.Println(err) - os.Exit(-1) - } -} - type TestCasesInfo map[string]*TestCaseInfo type TestCaseInfo struct { From 16b20bcbeb45baaf44e86cb90502543f74b519cf Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 2 Apr 2025 12:22:08 -0500 Subject: [PATCH 36/62] generate gkr --- gkr.go | 867 ++++++++++++++++++ gkr_test.go | 829 +++++++++++++++++ internal/generator/backend/gkr/generate.go | 29 - .../mimc_five_levels.json | 0 .../single_identity_gate.json | 0 .../single_input_two_identity_gates.json | 0 .../single_input_two_outs.json | 0 .../single_mimc_gate.json | 0 .../single_mul_gate.json | 0 ..._identity_gates_composed_single_input.json | 0 .../two_inputs_select-input-3_gate.json | 0 internal/generator/backend/main.go | 78 +- .../backend/template/gkr/gkr.go.tmpl | 22 +- .../backend/template/gkr/gkr.test.go.tmpl | 6 +- .../template/gkr/gkr.test.vectors.go.tmpl | 4 +- .../template/gkr/sumcheck.test.go.tmpl | 2 +- internal/gkr/bls12-377/gkr.go | 865 +++++++++++++++++ internal/gkr/bls12-377/gkr_test.go | 829 +++++++++++++++++ internal/gkr/bls12-377/registry.go | 320 +++++++ internal/gkr/bls12-381/gkr.go | 865 +++++++++++++++++ internal/gkr/bls12-381/gkr_test.go | 829 +++++++++++++++++ internal/gkr/bls12-381/registry.go | 320 +++++++ internal/gkr/bls24-315/gkr.go | 865 +++++++++++++++++ internal/gkr/bls24-315/gkr_test.go | 829 +++++++++++++++++ internal/gkr/bls24-315/registry.go | 320 +++++++ internal/gkr/bls24-317/gkr.go | 865 +++++++++++++++++ internal/gkr/bls24-317/gkr_test.go | 829 +++++++++++++++++ internal/gkr/bls24-317/registry.go | 320 +++++++ internal/gkr/bn254/gkr.go | 865 +++++++++++++++++ internal/gkr/bn254/gkr_test.go | 829 +++++++++++++++++ internal/gkr/bn254/registry.go | 320 +++++++ internal/gkr/bw6-633/gkr.go | 865 +++++++++++++++++ internal/gkr/bw6-633/gkr_test.go | 829 +++++++++++++++++ internal/gkr/bw6-633/registry.go | 320 +++++++ internal/gkr/bw6-761/gkr.go | 865 +++++++++++++++++ internal/gkr/bw6-761/gkr_test.go | 829 +++++++++++++++++ internal/gkr/bw6-761/registry.go | 320 +++++++ internal/gkr/gkr.go | 867 ++++++++++++++++++ internal/gkr/gkr_test.go | 829 +++++++++++++++++ internal/gkr/registry.go | 374 ++++++++ internal/gkr/small_rational/gkr.go | 865 +++++++++++++++++ internal/gkr/small_rational/gkr_test.go | 829 +++++++++++++++++ internal/gkr/small_rational/registry.go | 374 ++++++++ internal/gkr/sumcheck/sumcheck.go | 170 ++++ internal/gkr/sumcheck/sumcheck_test.go | 149 +++ .../test_vector_utils/test_vector_utils.go | 185 ++++ registry.go | 374 ++++++++ sumcheck/sumcheck.go | 170 ++++ sumcheck/sumcheck_test.go | 149 +++ test_vector_utils/test_vector_utils.go | 185 ++++ 50 files changed, 21378 insertions(+), 77 deletions(-) create mode 100644 gkr.go create mode 100644 gkr_test.go delete mode 100644 internal/generator/backend/gkr/generate.go rename internal/generator/backend/gkr/test_vectors/{resources => circuits}/mimc_five_levels.json (100%) rename internal/generator/backend/gkr/test_vectors/{resources => circuits}/single_identity_gate.json (100%) rename internal/generator/backend/gkr/test_vectors/{resources => circuits}/single_input_two_identity_gates.json (100%) rename internal/generator/backend/gkr/test_vectors/{resources => circuits}/single_input_two_outs.json (100%) rename internal/generator/backend/gkr/test_vectors/{resources => circuits}/single_mimc_gate.json (100%) rename internal/generator/backend/gkr/test_vectors/{resources => circuits}/single_mul_gate.json (100%) rename internal/generator/backend/gkr/test_vectors/{resources => circuits}/two_identity_gates_composed_single_input.json (100%) rename internal/generator/backend/gkr/test_vectors/{resources => circuits}/two_inputs_select-input-3_gate.json (100%) create mode 100644 internal/gkr/bls12-377/gkr.go create mode 100644 internal/gkr/bls12-377/gkr_test.go create mode 100644 internal/gkr/bls12-377/registry.go create mode 100644 internal/gkr/bls12-381/gkr.go create mode 100644 internal/gkr/bls12-381/gkr_test.go create mode 100644 internal/gkr/bls12-381/registry.go create mode 100644 internal/gkr/bls24-315/gkr.go create mode 100644 internal/gkr/bls24-315/gkr_test.go create mode 100644 internal/gkr/bls24-315/registry.go create mode 100644 internal/gkr/bls24-317/gkr.go create mode 100644 internal/gkr/bls24-317/gkr_test.go create mode 100644 internal/gkr/bls24-317/registry.go create mode 100644 internal/gkr/bn254/gkr.go create mode 100644 internal/gkr/bn254/gkr_test.go create mode 100644 internal/gkr/bn254/registry.go create mode 100644 internal/gkr/bw6-633/gkr.go create mode 100644 internal/gkr/bw6-633/gkr_test.go create mode 100644 internal/gkr/bw6-633/registry.go create mode 100644 internal/gkr/bw6-761/gkr.go create mode 100644 internal/gkr/bw6-761/gkr_test.go create mode 100644 internal/gkr/bw6-761/registry.go create mode 100644 internal/gkr/gkr.go create mode 100644 internal/gkr/gkr_test.go create mode 100644 internal/gkr/registry.go create mode 100644 internal/gkr/small_rational/gkr.go create mode 100644 internal/gkr/small_rational/gkr_test.go create mode 100644 internal/gkr/small_rational/registry.go create mode 100644 internal/gkr/sumcheck/sumcheck.go create mode 100644 internal/gkr/sumcheck/sumcheck_test.go create mode 100644 internal/gkr/test_vector_utils/test_vector_utils.go create mode 100644 registry.go create mode 100644 sumcheck/sumcheck.go create mode 100644 sumcheck/sumcheck_test.go create mode 100644 test_vector_utils/test_vector_utils.go diff --git a/gkr.go b/gkr.go new file mode 100644 index 0000000000..70913dd297 --- /dev/null +++ b/gkr.go @@ -0,0 +1,867 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/parallel" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark//sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...small_rational.SmallRational) small_rational.SmallRational + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]small_rational.SmallRational + claimedEvaluations []small_rational.SmallRational + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a small_rational.SmallRational) small_rational.SmallRational { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]small_rational.SmallRational) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation small_rational.SmallRational + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]small_rational.SmallRational // x in the paper + claimedEvaluations []small_rational.SmallRational // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational.SmallRational) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]small_rational.SmallRational, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step small_rational.SmallRational + + res := make([]small_rational.SmallRational, degGJ) + operands := make([]small_rational.SmallRational, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element small_rational.SmallRational) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) interface{} { + + //defer the proof, return list of claims + evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]small_rational.SmallRational, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []small_rational.SmallRational, evaluation small_rational.SmallRational) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_rational.SmallRational, error) { + res := make([]small_rational.SmallRational, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []small_rational.SmallRational{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]small_rational.SmallRational) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]small_rational.SmallRational) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// TopologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func TopologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := TopologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]small_rational.SmallRational, nbInstances) + } + } + + parallel.Execute(nbInstances, func(start, end int) { + ins := make([]small_rational.SmallRational, maxNbIns) + for i := start; i < end; i++ { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + }) + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]small_rational.SmallRational) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/gkr_test.go b/gkr_test.go new file mode 100644 index 0000000000..31bd52133a --- /dev/null +++ b/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/mimc" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []small_rational.SmallRational{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []small_rational.SmallRational{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []small_rational.SmallRational{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]small_rational.SmallRational) { + return func(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []small_rational.SmallRational{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []small_rational.SmallRational{three}, five) + manager.add(wire, []small_rational.SmallRational{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six small_rational.SmallRational + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]small_rational.SmallRational)) { + fullAssignments := make([][]small_rational.SmallRational, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]small_rational.SmallRational, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []small_rational.SmallRational) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := TopologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := TopologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := TopologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { + var sum small_rational.SmallRational + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []small_rational.SmallRational(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []small_rational.SmallRational + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/generator/backend/gkr/generate.go b/internal/generator/backend/gkr/generate.go deleted file mode 100644 index 3b679276d4..0000000000 --- a/internal/generator/backend/gkr/generate.go +++ /dev/null @@ -1,29 +0,0 @@ -package gkr - -import ( - "path/filepath" - - "github.com/consensys/bavard" -) - -type Config struct { - GenerateTests bool - RetainTestCaseRawInfo bool - CanUseFFT bool - OutsideGkrPackage bool - TestVectorsRelativePath string -} - -func Generate(config Config, baseDir string, bgen *bavard.BatchGenerator) error { - entries := []bavard.Entry{ - {File: filepath.Join(baseDir, "gkr.go"), Templates: []string{"gkr.go.tmpl"}}, - {File: filepath.Join(baseDir, "registry.go"), Templates: []string{"registry.go.tmpl"}}, - } - - if config.GenerateTests { - entries = append(entries, - bavard.Entry{File: filepath.Join(baseDir, "gkr_test.go"), Templates: []string{"gkr.test.go.tmpl", "gkr.test.vectors.go.tmpl"}}) - } - - return bgen.Generate(config, "gkr", "./gkr/template/", entries...) -} diff --git a/internal/generator/backend/gkr/test_vectors/resources/mimc_five_levels.json b/internal/generator/backend/gkr/test_vectors/circuits/mimc_five_levels.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/resources/mimc_five_levels.json rename to internal/generator/backend/gkr/test_vectors/circuits/mimc_five_levels.json diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_identity_gate.json b/internal/generator/backend/gkr/test_vectors/circuits/single_identity_gate.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/resources/single_identity_gate.json rename to internal/generator/backend/gkr/test_vectors/circuits/single_identity_gate.json diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_input_two_identity_gates.json b/internal/generator/backend/gkr/test_vectors/circuits/single_input_two_identity_gates.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/resources/single_input_two_identity_gates.json rename to internal/generator/backend/gkr/test_vectors/circuits/single_input_two_identity_gates.json diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_input_two_outs.json b/internal/generator/backend/gkr/test_vectors/circuits/single_input_two_outs.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/resources/single_input_two_outs.json rename to internal/generator/backend/gkr/test_vectors/circuits/single_input_two_outs.json diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_mimc_gate.json b/internal/generator/backend/gkr/test_vectors/circuits/single_mimc_gate.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/resources/single_mimc_gate.json rename to internal/generator/backend/gkr/test_vectors/circuits/single_mimc_gate.json diff --git a/internal/generator/backend/gkr/test_vectors/resources/single_mul_gate.json b/internal/generator/backend/gkr/test_vectors/circuits/single_mul_gate.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/resources/single_mul_gate.json rename to internal/generator/backend/gkr/test_vectors/circuits/single_mul_gate.json diff --git a/internal/generator/backend/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json b/internal/generator/backend/gkr/test_vectors/circuits/two_identity_gates_composed_single_input.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json rename to internal/generator/backend/gkr/test_vectors/circuits/two_identity_gates_composed_single_input.json diff --git a/internal/generator/backend/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json b/internal/generator/backend/gkr/test_vectors/circuits/two_inputs_select-input-3_gate.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json rename to internal/generator/backend/gkr/test_vectors/circuits/two_inputs_select-input-3_gate.json diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 49b67c0522..30854f97e0 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -136,12 +136,18 @@ func main() { assertNoError(err) curvePackageName := strings.ToLower(d.Curve) - err = generateGkrBackend(config.FieldDependency{ - ElementType: "fr.Element", - FieldPackageName: "fr", - FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + curvePackageName + "/fr", - }, curvePackageName) - assertNoError(err) + + cfg := gkrConfig{ + FieldDependency: config.FieldDependency{ + ElementType: "fr.Element", + FieldPackageName: "fr", + FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + curvePackageName + "/fr", + }, + GkrPackageRelativePath: "internal/gkr/" + curvePackageName, + CanUseFFT: true, + } + + assertNoError(generateGkrBackend(cfg)) } entries = []bavard.Entry{ @@ -216,15 +222,20 @@ func main() { // GKR test vectors go func() { // generate sumcheck for small-rational - err := generateGkrBackend(config.FieldDependency{ - ElementType: "small_rational.SmallRational", - FieldPackagePath: "github.com/consensys/gnark/internal/small_rational", - FieldPackageName: "small_rational", - }, "small_rational") + err := generateGkrBackend(gkrConfig{ + FieldDependency: config.FieldDependency{ + ElementType: "small_rational.SmallRational", + FieldPackagePath: "github.com/consensys/gnark/internal/small_rational", + FieldPackageName: "small_rational", + }, + GkrPackageRelativePath: "internal/gkr/small_rational", + CanUseFFT: false, + }) assertNoError(err) fmt.Println("generating test vectors for sumcheck") - assertNoError(sumcheckTestVectors.Generate()) + assertNoError(sumcheckTestVectors.Generate()) // TODO CRITICAL This must be an independent process so that it's compiled before being run] + // TODO it also needs to run after everything else is done wg.Done() }() @@ -250,23 +261,14 @@ type templateData struct { NoGKR bool } -func generateGkrBackend(fieldDep config.FieldDependency, curvePackageName string) error { +func generateGkrBackend(cfg gkrConfig) error { const repoRoot = "../../../" - - gkrPackageDirRelPath := filepath.Join(repoRoot+"internal/gkr/", curvePackageName) - - cfg := struct { - config.FieldDependency - GkrPackagePath string - }{ - fieldDep, - "github.com/consensys/gnark/internal/gkr/" + curvePackageName, - } + packageOutPath := filepath.Join(repoRoot, cfg.GkrPackageRelativePath) // test vector utils - packagePath := filepath.Join(gkrPackageDirRelPath, "test_vector_utils") + packageDir := filepath.Join(packageOutPath, "test_vector_utils") entries := []bavard.Entry{ - {File: filepath.Join(packagePath, "test_vector_utils.go"), Templates: []string{"test_vector_utils.go.tmpl"}}, + {File: filepath.Join(packageDir, "test_vector_utils.go"), Templates: []string{"test_vector_utils.go.tmpl"}}, } if err := bgen.Generate(cfg, "test_vector_utils", "./template/gkr/", entries...); err != nil { @@ -274,18 +276,38 @@ func generateGkrBackend(fieldDep config.FieldDependency, curvePackageName string } // sumcheck backend - packagePath = filepath.Join(gkrPackageDirRelPath, "sumcheck") + packageDir = filepath.Join(packageOutPath, "sumcheck") entries = []bavard.Entry{ - {File: filepath.Join(packagePath, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, - {File: filepath.Join(packagePath, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl"}}, + {File: filepath.Join(packageDir, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, + {File: filepath.Join(packageDir, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl"}}, } if err := bgen.Generate(cfg, "sumcheck", "./template/gkr/", entries...); err != nil { return err } + // gkr backend + packageDir = packageOutPath + entries = []bavard.Entry{ + {File: filepath.Join(packageDir, "gkr.go"), Templates: []string{"gkr.go.tmpl"}}, + {File: filepath.Join(packageDir, "registry.go"), Templates: []string{"registry.go.tmpl"}}, + {File: filepath.Join(packageDir, "gkr_test.go"), Templates: []string{"gkr.test.go.tmpl", "gkr.test.vectors.go.tmpl"}}, + } + + if err := bgen.Generate(cfg, "gkr", "./template/gkr/", entries...); err != nil { + return err + } + return nil } +type gkrConfig struct { + config.FieldDependency + GkrPackageRelativePath string + CanUseFFT bool + OutsideGkrPackage bool + GenerateTestVectors bool +} + func assertNoError(err error) { if err != nil { panic(err) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index c27daa9b59..886feeb3ca 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -3,9 +3,8 @@ import ( "fmt" "{{.FieldPackagePath}}" "{{.FieldPackagePath}}/polynomial" - "{{.FieldPackagePath}}/sumcheck" + "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/sumcheck" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" "github.com/consensys/gnark-crypto/utils" "math/big" "strconv" @@ -808,19 +807,18 @@ func (a WireAssignment) Complete(c Circuit) WireAssignment { } } - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]{{.ElementType}}, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + // TODO: Parallelize, if needed + ins := make([]{{.ElementType}}, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) } } - }) + } return a } diff --git a/internal/generator/backend/template/gkr/gkr.test.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.go.tmpl index 378cb813e0..1000786d34 100644 --- a/internal/generator/backend/template/gkr/gkr.test.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.go.tmpl @@ -19,7 +19,6 @@ import ( "time" ) -{{$GenerateLargeTests := .GenerateTests}} {{/* this is redundant. soon to be removed if a use case for it doesn't come back */}} {{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} func TestNoGateTwoInstances(t *testing.T) { @@ -81,7 +80,6 @@ func TestShallowMimcTwoInstances(t *testing.T) { testMimc(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) } -{{- if $GenerateLargeTests}} func TestMimcTwoInstances(t *testing.T) { testMimc(t, 93, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) } @@ -96,8 +94,6 @@ func generateTestMimc(numRounds int) func(*testing.T, ...[]{{.ElementType}}) { } } -{{- end}} - func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { circuit := Circuit{ Wire{ Gate: GetGate(Identity), @@ -402,7 +398,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "{{.TestVectorsRelativePath}}" + testDirPath := "" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { diff --git a/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl index 0025b0164a..b1f8bbdf9b 100644 --- a/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl @@ -132,7 +132,7 @@ type TestCase struct { Proof {{$Proof}} FullAssignment {{$WireAssignment}} InOutAssignment {{$WireAssignment}} - {{if .RetainTestCaseRawInfo}}Info TestCaseInfo{{end}} + {{if .GenerateTestVectors}}Info TestCaseInfo // we are generating the test vectors, so we need to keep the circuit instance info to ADD the proof to it and resave it{{end}} } type TestCaseInfo struct { @@ -231,7 +231,7 @@ func newTestCase(path string) (*TestCase, error) { Proof: proof, Hash: _hash, Circuit: circuit, - {{if .RetainTestCaseRawInfo }}Info: info,{{end}} + {{if .GenerateTestVectors }}Info: info,{{end}} } testCases[path] = tCase diff --git a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl index e599869be2..f85214d1cd 100644 --- a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl @@ -3,7 +3,7 @@ import ( "{{.FieldPackagePath}}" "{{.FieldPackagePath}}/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "{{.GkrPackagePath}}/test_vector_utils" + "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/test_vector_utils" "github.com/stretchr/testify/assert" "hash" "math/bits" diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go new file mode 100644 index 0000000000..725ba5fbcd --- /dev/null +++ b/internal/gkr/bls12-377/gkr.go @@ -0,0 +1,865 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls12-377/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + res := make([]fr.Element, degGJ) + operands := make([]fr.Element, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bls12-377/gkr_test.go b/internal/gkr/bls12-377/gkr_test.go new file mode 100644 index 0000000000..acc38d35af --- /dev/null +++ b/internal/gkr/bls12-377/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bls12-377/registry.go b/internal/gkr/bls12-377/registry.go new file mode 100644 index 0000000000..251a0a5dfd --- /dev/null +++ b/internal/gkr/bls12-377/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go new file mode 100644 index 0000000000..c0387ff7bd --- /dev/null +++ b/internal/gkr/bls12-381/gkr.go @@ -0,0 +1,865 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls12-381/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + res := make([]fr.Element, degGJ) + operands := make([]fr.Element, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bls12-381/gkr_test.go b/internal/gkr/bls12-381/gkr_test.go new file mode 100644 index 0000000000..d8964182f6 --- /dev/null +++ b/internal/gkr/bls12-381/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bls12-381/registry.go b/internal/gkr/bls12-381/registry.go new file mode 100644 index 0000000000..484b1939f0 --- /dev/null +++ b/internal/gkr/bls12-381/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go new file mode 100644 index 0000000000..22809d20f0 --- /dev/null +++ b/internal/gkr/bls24-315/gkr.go @@ -0,0 +1,865 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls24-315/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + res := make([]fr.Element, degGJ) + operands := make([]fr.Element, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bls24-315/gkr_test.go b/internal/gkr/bls24-315/gkr_test.go new file mode 100644 index 0000000000..04dd26a153 --- /dev/null +++ b/internal/gkr/bls24-315/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bls24-315/registry.go b/internal/gkr/bls24-315/registry.go new file mode 100644 index 0000000000..f5cae19de7 --- /dev/null +++ b/internal/gkr/bls24-315/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go new file mode 100644 index 0000000000..5b26065286 --- /dev/null +++ b/internal/gkr/bls24-317/gkr.go @@ -0,0 +1,865 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls24-317/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + res := make([]fr.Element, degGJ) + operands := make([]fr.Element, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bls24-317/gkr_test.go b/internal/gkr/bls24-317/gkr_test.go new file mode 100644 index 0000000000..c647ee01ea --- /dev/null +++ b/internal/gkr/bls24-317/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bls24-317/registry.go b/internal/gkr/bls24-317/registry.go new file mode 100644 index 0000000000..d0c68beea6 --- /dev/null +++ b/internal/gkr/bls24-317/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go new file mode 100644 index 0000000000..971a3ac342 --- /dev/null +++ b/internal/gkr/bn254/gkr.go @@ -0,0 +1,865 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bn254/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + res := make([]fr.Element, degGJ) + operands := make([]fr.Element, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bn254/gkr_test.go b/internal/gkr/bn254/gkr_test.go new file mode 100644 index 0000000000..9ac49f5cc0 --- /dev/null +++ b/internal/gkr/bn254/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bn254/registry.go b/internal/gkr/bn254/registry.go new file mode 100644 index 0000000000..be935ba7b5 --- /dev/null +++ b/internal/gkr/bn254/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go new file mode 100644 index 0000000000..932070198f --- /dev/null +++ b/internal/gkr/bw6-633/gkr.go @@ -0,0 +1,865 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bw6-633/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + res := make([]fr.Element, degGJ) + operands := make([]fr.Element, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bw6-633/gkr_test.go b/internal/gkr/bw6-633/gkr_test.go new file mode 100644 index 0000000000..ae8adff951 --- /dev/null +++ b/internal/gkr/bw6-633/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bw6-633/registry.go b/internal/gkr/bw6-633/registry.go new file mode 100644 index 0000000000..d8bda624f1 --- /dev/null +++ b/internal/gkr/bw6-633/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go new file mode 100644 index 0000000000..e369b7c52b --- /dev/null +++ b/internal/gkr/bw6-761/gkr.go @@ -0,0 +1,865 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bw6-761/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step fr.Element + + res := make([]fr.Element, degGJ) + operands := make([]fr.Element, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]fr.Element) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bw6-761/gkr_test.go b/internal/gkr/bw6-761/gkr_test.go new file mode 100644 index 0000000000..a7cdf45cd5 --- /dev/null +++ b/internal/gkr/bw6-761/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bw6-761/registry.go b/internal/gkr/bw6-761/registry.go new file mode 100644 index 0000000000..ed7bb6819a --- /dev/null +++ b/internal/gkr/bw6-761/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go new file mode 100644 index 0000000000..70913dd297 --- /dev/null +++ b/internal/gkr/gkr.go @@ -0,0 +1,867 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/parallel" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark//sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...small_rational.SmallRational) small_rational.SmallRational + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]small_rational.SmallRational + claimedEvaluations []small_rational.SmallRational + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a small_rational.SmallRational) small_rational.SmallRational { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]small_rational.SmallRational) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation small_rational.SmallRational + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]small_rational.SmallRational // x in the paper + claimedEvaluations []small_rational.SmallRational // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational.SmallRational) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]small_rational.SmallRational, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step small_rational.SmallRational + + res := make([]small_rational.SmallRational, degGJ) + operands := make([]small_rational.SmallRational, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element small_rational.SmallRational) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) interface{} { + + //defer the proof, return list of claims + evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]small_rational.SmallRational, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []small_rational.SmallRational, evaluation small_rational.SmallRational) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_rational.SmallRational, error) { + res := make([]small_rational.SmallRational, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []small_rational.SmallRational{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]small_rational.SmallRational) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]small_rational.SmallRational) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// TopologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func TopologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := TopologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]small_rational.SmallRational, nbInstances) + } + } + + parallel.Execute(nbInstances, func(start, end int) { + ins := make([]small_rational.SmallRational, maxNbIns) + for i := start; i < end; i++ { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + }) + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]small_rational.SmallRational) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/gkr_test.go b/internal/gkr/gkr_test.go new file mode 100644 index 0000000000..31bd52133a --- /dev/null +++ b/internal/gkr/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/mimc" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []small_rational.SmallRational{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []small_rational.SmallRational{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []small_rational.SmallRational{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]small_rational.SmallRational) { + return func(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []small_rational.SmallRational{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []small_rational.SmallRational{three}, five) + manager.add(wire, []small_rational.SmallRational{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six small_rational.SmallRational + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]small_rational.SmallRational)) { + fullAssignments := make([][]small_rational.SmallRational, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]small_rational.SmallRational, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []small_rational.SmallRational) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := TopologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := TopologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := TopologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { + var sum small_rational.SmallRational + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []small_rational.SmallRational(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []small_rational.SmallRational + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/registry.go b/internal/gkr/registry.go new file mode 100644 index 0000000000..b48f179c20 --- /dev/null +++ b/internal/gkr/registry.go @@ -0,0 +1,374 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(small_rational.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]small_rational.SmallRational, nbIn) + consts := make(small_rational.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + x := make(small_rational.Vector, degreeBound) + x.MustSetRandom() + for i := range x { + fIn[0] = x[i] + for j := range consts { + fIn[j+1].Mul(&x[i], &consts[j]) + } + p[i] = f(fIn...) + } + + // obtain p's coefficients + p, err := interpolate(x, p) + if err != nil { + panic(err) + } + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) +// Note that the runtime is O(len(X)³) +func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { + if len(X) != len(Y) { + return nil, errors.New("X and Y must have the same length") + } + + // solve the system of equations by Gaussian elimination + augmentedRows := make([][]small_rational.SmallRational, len(X)) // the last column is the Y values + for i := range augmentedRows { + augmentedRows[i] = make([]small_rational.SmallRational, len(X)+1) + augmentedRows[i][0].SetOne() + augmentedRows[i][1].Set(&X[i]) + for j := 2; j < len(augmentedRows[i])-1; j++ { + augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) + } + augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) + } + + // make the upper triangle + for i := range len(augmentedRows) - 1 { + // use row i to eliminate the ith element in all rows below + var negInv small_rational.SmallRational + if augmentedRows[i][i].IsZero() { + return nil, errors.New("singular matrix") + } + negInv.Inverse(&augmentedRows[i][i]) + negInv.Neg(&negInv) + for j := i + 1; j < len(augmentedRows); j++ { + var c small_rational.SmallRational + c.Mul(&augmentedRows[j][i], &negInv) + // augmentedRows[j][i].SetZero() omitted + for k := i + 1; k < len(augmentedRows[i]); k++ { + var t small_rational.SmallRational + t.Mul(&augmentedRows[i][k], &c) + augmentedRows[j][k].Add(&augmentedRows[j][k], &t) + } + } + } + + // back substitution + res := make(polynomial.Polynomial, len(X)) + for i := len(augmentedRows) - 1; i >= 0; i-- { + res[i] = augmentedRows[i][len(augmentedRows[i])-1] + for j := i + 1; j < len(augmentedRows[i])-1; j++ { + var t small_rational.SmallRational + t.Mul(&res[j], &augmentedRows[i][j]) + res[i].Sub(&res[i], &t) + } + res[i].Div(&res[i], &augmentedRows[i][i]) + } + + return res, nil +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...small_rational.SmallRational) small_rational.SmallRational { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go new file mode 100644 index 0000000000..9119e58363 --- /dev/null +++ b/internal/gkr/small_rational/gkr.go @@ -0,0 +1,865 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...small_rational.SmallRational) small_rational.SmallRational + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of f + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]small_rational.SmallRational + claimedEvaluations []small_rational.SmallRational + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a small_rational.SmallRational) small_rational.SmallRational { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]small_rational.SmallRational) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation small_rational.SmallRational + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]small_rational.SmallRational // x in the paper + claimedEvaluations []small_rational.SmallRational // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational.SmallRational) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + // newEq.Eq(c.evaluationPoints[k]) + // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() + + // e.Add(e, polynomial.Polynomial(m)) +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + nbGateIn := len(c.inputPreprocessors) + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + s := make([]polynomial.MultiLin, nbGateIn+1) + s[0] = c.eq + copy(s[1:], c.inputPreprocessors) + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + nbInner := len(s) // wrt output, which has high nbOuter and low nbInner + nbOuter := len(s[0]) / 2 + + gJ := make([]small_rational.SmallRational, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { + var step small_rational.SmallRational + + res := make([]small_rational.SmallRational, degGJ) + operands := make([]small_rational.SmallRational, degGJ*nbInner) + + for i := start; i < end; i++ { + + block := nbOuter + i + for j := 0; j < nbInner; j++ { + step.Set(&s[j][i]) + operands[j].Set(&s[j][block]) + step.Sub(&operands[j], &step) + for d := 1; d < degGJ; d++ { + operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + } + } + + _s := 0 + _e := nbInner + for d := 0; d < degGJ; d++ { + summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) + summand.Mul(&summand, &operands[_s]) + res[d].Add(&res[d], &summand) + _s, _e = _e, _e+nbInner + } + } + mu.Lock() + for i := 0; i < len(gJ); i++ { + gJ[i].Add(&gJ[i], &res[i]) + } + mu.Unlock() + } + + const minBlockSize = 64 + + if nbOuter < minBlockSize { + // no parallelization + computeAll(0, nbOuter) + } else { + c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + } + + // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though + + return gJ +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element small_rational.SmallRational) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + c.eq.Fold(element) + } else { + wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) + for i := 0; i < len(c.inputPreprocessors); i++ { + wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) interface{} { + + //defer the proof, return list of claims + evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]small_rational.SmallRational, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []small_rational.SmallRational, evaluation small_rational.SmallRational) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_rational.SmallRational, error) { + res := make([]small_rational.SmallRational, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []small_rational.SmallRational{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]small_rational.SmallRational) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]small_rational.SmallRational) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// TopologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func TopologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := TopologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]small_rational.SmallRational, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]small_rational.SmallRational, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + finalEvalProof := p[i].FinalEvalProof.([]small_rational.SmallRational) + frToBigInts(outs[offset:], finalEvalProof) + offset += len(finalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/small_rational/gkr_test.go b/internal/gkr/small_rational/gkr_test.go new file mode 100644 index 0000000000..31bd52133a --- /dev/null +++ b/internal/gkr/small_rational/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/mimc" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []small_rational.SmallRational{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []small_rational.SmallRational{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []small_rational.SmallRational{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]small_rational.SmallRational) { + return func(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []small_rational.SmallRational{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []small_rational.SmallRational{three}, five) + manager.add(wire, []small_rational.SmallRational{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six small_rational.SmallRational + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]small_rational.SmallRational)) { + fullAssignments := make([][]small_rational.SmallRational, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]small_rational.SmallRational, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []small_rational.SmallRational) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := TopologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := TopologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := TopologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { + var sum small_rational.SmallRational + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []small_rational.SmallRational(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []small_rational.SmallRational + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/small_rational/registry.go b/internal/gkr/small_rational/registry.go new file mode 100644 index 0000000000..b48f179c20 --- /dev/null +++ b/internal/gkr/small_rational/registry.go @@ -0,0 +1,374 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(small_rational.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]small_rational.SmallRational, nbIn) + consts := make(small_rational.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + x := make(small_rational.Vector, degreeBound) + x.MustSetRandom() + for i := range x { + fIn[0] = x[i] + for j := range consts { + fIn[j+1].Mul(&x[i], &consts[j]) + } + p[i] = f(fIn...) + } + + // obtain p's coefficients + p, err := interpolate(x, p) + if err != nil { + panic(err) + } + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) +// Note that the runtime is O(len(X)³) +func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { + if len(X) != len(Y) { + return nil, errors.New("X and Y must have the same length") + } + + // solve the system of equations by Gaussian elimination + augmentedRows := make([][]small_rational.SmallRational, len(X)) // the last column is the Y values + for i := range augmentedRows { + augmentedRows[i] = make([]small_rational.SmallRational, len(X)+1) + augmentedRows[i][0].SetOne() + augmentedRows[i][1].Set(&X[i]) + for j := 2; j < len(augmentedRows[i])-1; j++ { + augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) + } + augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) + } + + // make the upper triangle + for i := range len(augmentedRows) - 1 { + // use row i to eliminate the ith element in all rows below + var negInv small_rational.SmallRational + if augmentedRows[i][i].IsZero() { + return nil, errors.New("singular matrix") + } + negInv.Inverse(&augmentedRows[i][i]) + negInv.Neg(&negInv) + for j := i + 1; j < len(augmentedRows); j++ { + var c small_rational.SmallRational + c.Mul(&augmentedRows[j][i], &negInv) + // augmentedRows[j][i].SetZero() omitted + for k := i + 1; k < len(augmentedRows[i]); k++ { + var t small_rational.SmallRational + t.Mul(&augmentedRows[i][k], &c) + augmentedRows[j][k].Add(&augmentedRows[j][k], &t) + } + } + } + + // back substitution + res := make(polynomial.Polynomial, len(X)) + for i := len(augmentedRows) - 1; i >= 0; i-- { + res[i] = augmentedRows[i][len(augmentedRows[i])-1] + for j := i + 1; j < len(augmentedRows[i])-1; j++ { + var t small_rational.SmallRational + t.Mul(&res[j], &augmentedRows[i][j]) + res[i].Sub(&res[i], &t) + } + res[i].Div(&res[i], &augmentedRows[i][i]) + } + + return res, nil +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...small_rational.SmallRational) small_rational.SmallRational { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/sumcheck/sumcheck.go b/internal/gkr/sumcheck/sumcheck.go new file mode 100644 index 0000000000..e491815a87 --- /dev/null +++ b/internal/gkr/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []small_rational.SmallRational) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a small_rational.SmallRational) small_rational.SmallRational // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []small_rational.SmallRational, remainingChallengeNames *[]string) (small_rational.SmallRational, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return small_rational.SmallRational{}, err + } + } + var res small_rational.SmallRational + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff small_rational.SmallRational + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]small_rational.SmallRational, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff small_rational.SmallRational + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]small_rational.SmallRational, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/sumcheck/sumcheck_test.go b/internal/gkr/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..85230fdb9d --- /dev/null +++ b/internal/gkr/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark//test_vector_utils" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []small_rational.SmallRational{sum} +} + +func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum small_rational.SmallRational +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs small_rational.SmallRational) small_rational.SmallRational { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/test_vector_utils/test_vector_utils.go b/internal/gkr/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..3102a2133d --- /dev/null +++ b/internal/gkr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,185 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "hash" + "reflect" +) + +func ToElement(i int64) *small_rational.SmallRational { + var res small_rational.SmallRational + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res small_rational.SmallRational + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return small_rational.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return small_rational.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []small_rational.SmallRational + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return small_rational.Bytes +} + +func (h *ListHash) BlockSize() int { + return small_rational.Bytes +} + +func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *small_rational.SmallRational) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(small_rational.SmallRational) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/registry.go b/registry.go new file mode 100644 index 0000000000..b48f179c20 --- /dev/null +++ b/registry.go @@ -0,0 +1,374 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(small_rational.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]small_rational.SmallRational, nbIn) + consts := make(small_rational.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + x := make(small_rational.Vector, degreeBound) + x.MustSetRandom() + for i := range x { + fIn[0] = x[i] + for j := range consts { + fIn[j+1].Mul(&x[i], &consts[j]) + } + p[i] = f(fIn...) + } + + // obtain p's coefficients + p, err := interpolate(x, p) + if err != nil { + panic(err) + } + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) +// Note that the runtime is O(len(X)³) +func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { + if len(X) != len(Y) { + return nil, errors.New("X and Y must have the same length") + } + + // solve the system of equations by Gaussian elimination + augmentedRows := make([][]small_rational.SmallRational, len(X)) // the last column is the Y values + for i := range augmentedRows { + augmentedRows[i] = make([]small_rational.SmallRational, len(X)+1) + augmentedRows[i][0].SetOne() + augmentedRows[i][1].Set(&X[i]) + for j := 2; j < len(augmentedRows[i])-1; j++ { + augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) + } + augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) + } + + // make the upper triangle + for i := range len(augmentedRows) - 1 { + // use row i to eliminate the ith element in all rows below + var negInv small_rational.SmallRational + if augmentedRows[i][i].IsZero() { + return nil, errors.New("singular matrix") + } + negInv.Inverse(&augmentedRows[i][i]) + negInv.Neg(&negInv) + for j := i + 1; j < len(augmentedRows); j++ { + var c small_rational.SmallRational + c.Mul(&augmentedRows[j][i], &negInv) + // augmentedRows[j][i].SetZero() omitted + for k := i + 1; k < len(augmentedRows[i]); k++ { + var t small_rational.SmallRational + t.Mul(&augmentedRows[i][k], &c) + augmentedRows[j][k].Add(&augmentedRows[j][k], &t) + } + } + } + + // back substitution + res := make(polynomial.Polynomial, len(X)) + for i := len(augmentedRows) - 1; i >= 0; i-- { + res[i] = augmentedRows[i][len(augmentedRows[i])-1] + for j := i + 1; j < len(augmentedRows[i])-1; j++ { + var t small_rational.SmallRational + t.Mul(&res[j], &augmentedRows[i][j]) + res[i].Sub(&res[i], &t) + } + res[i].Div(&res[i], &augmentedRows[i][i]) + } + + return res, nil +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...small_rational.SmallRational) small_rational.SmallRational { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/sumcheck/sumcheck.go b/sumcheck/sumcheck.go new file mode 100644 index 0000000000..e491815a87 --- /dev/null +++ b/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []small_rational.SmallRational) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a small_rational.SmallRational) small_rational.SmallRational // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []small_rational.SmallRational, remainingChallengeNames *[]string) (small_rational.SmallRational, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return small_rational.SmallRational{}, err + } + } + var res small_rational.SmallRational + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff small_rational.SmallRational + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]small_rational.SmallRational, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff small_rational.SmallRational + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]small_rational.SmallRational, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/sumcheck/sumcheck_test.go b/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..85230fdb9d --- /dev/null +++ b/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark//test_vector_utils" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []small_rational.SmallRational{sum} +} + +func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum small_rational.SmallRational +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs small_rational.SmallRational) small_rational.SmallRational { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/test_vector_utils/test_vector_utils.go b/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..3102a2133d --- /dev/null +++ b/test_vector_utils/test_vector_utils.go @@ -0,0 +1,185 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "hash" + "reflect" +) + +func ToElement(i int64) *small_rational.SmallRational { + var res small_rational.SmallRational + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res small_rational.SmallRational + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return small_rational.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return small_rational.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []small_rational.SmallRational + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return small_rational.Bytes +} + +func (h *ListHash) BlockSize() int { + return small_rational.Bytes +} + +func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *small_rational.SmallRational) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(small_rational.SmallRational) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} From 43ea601b0d28c824352e2fa7fe44ba927d30ce3d Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 2 Apr 2025 12:36:39 -0500 Subject: [PATCH 37/62] generate gkr test vec generator --- internal/generator/backend/main.go | 19 +- .../gkr/test_vectors/gkr/gkr-gen-vectors.go | 349 ++++++++++++++++++ internal/gkr/test_vectors/main.go | 13 + 3 files changed, 377 insertions(+), 4 deletions(-) create mode 100644 internal/gkr/test_vectors/gkr/gkr-gen-vectors.go create mode 100644 internal/gkr/test_vectors/main.go diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 30854f97e0..5053586ea0 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -221,8 +221,8 @@ func main() { wg.Add(1) // GKR test vectors go func() { - // generate sumcheck for small-rational - err := generateGkrBackend(gkrConfig{ + // generate gkr and sumcheck for small-rational + cfg := gkrConfig{ FieldDependency: config.FieldDependency{ ElementType: "small_rational.SmallRational", FieldPackagePath: "github.com/consensys/gnark/internal/small_rational", @@ -230,8 +230,19 @@ func main() { }, GkrPackageRelativePath: "internal/gkr/small_rational", CanUseFFT: false, - }) - assertNoError(err) + } + assertNoError(generateGkrBackend(cfg)) + + // generate gkr test vector generator + cfg.GenerateTestVectors = true + cfg.OutsideGkrPackage = true + + assertNoError(bgen.Generate(cfg, "gkr", "./template/gkr/", + bavard.Entry{ + File: "../../gkr/test_vectors/gkr/gkr-gen-vectors.go", + Templates: []string{"gkr.test.vectors.gen.go.tmpl", "gkr.test.vectors.go.tmpl"}, + }, + )) fmt.Println("generating test vectors for sumcheck") assertNoError(sumcheckTestVectors.Generate()) // TODO CRITICAL This must be an independent process so that it's compiled before being run] diff --git a/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go b/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go new file mode 100644 index 0000000000..598a2da702 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go @@ -0,0 +1,349 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "hash" + "os" + "path/filepath" + "reflect" + + "github.com/consensys/bavard" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" +) + +func main() { + if err := GenerateVectors(); err != nil { + fmt.Println(err.Error()) + os.Exit(-1) + } +} + +func GenerateVectors() error { + testDirPath, err := filepath.Abs("gkr/test_vectors") + if err != nil { + return err + } + + fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) + + dirEntries, err := os.ReadDir(testDirPath) + if err != nil { + return err + } + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + if !bavard.ShouldGenerate(path) { + continue + } + fmt.Println("\tprocessing", dirEntry.Name()) + if err = run(path); err != nil { + return err + } + } + } + } + + return nil +} + +func run(absPath string) error { + testCase, err := newTestCase(absPath) + if err != nil { + return err + } + + transcriptSetting := fiatshamir.WithHash(testCase.Hash) + + var proof gkr.Proof + proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) + if err != nil { + return err + } + + if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { + return err + } + var outBytes []byte + if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { + if err = os.WriteFile(absPath, outBytes, 0); err != nil { + return err + } + } else { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) + if err != nil { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + if err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { + res := make(PrintableProof, len(proof)) + + for i := range proof { + + partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) + for k, partialK := range proof[i].PartialSumPolys { + partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) + } + + res[i] = PrintableSumcheckProof{ + FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), + PartialSumPolys: partialSumPolys, + } + } + return res, nil +} + +type WireInfo struct { + Gate gkr.GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]gkr.Circuit) + +func getCircuit(path string) (gkr.Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit gkr.Circuit) { + circuit = make(gkr.Circuit, len(c)) + for i := range c { + circuit[i].Gate = gkr.GetGate(c[i].Gate) + circuit[i].Inputs = make([]*gkr.Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { + var sum small_rational.SmallRational + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC gkr.GateName = "mimc" + SelectInput3 gkr.GateName = "select-input-3" +) + +func init() { + if err := gkr.RegisterGate(MiMC, mimcRound, 2, gkr.WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := gkr.RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { + return input[2] + }, 3, gkr.WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (gkr.Proof, error) { + proof := make(gkr.Proof, len(printable)) + for i := range printable { + finalEvalProof := []small_rational.SmallRational(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit gkr.Circuit + Hash hash.Hash + Proof gkr.Proof + FullAssignment gkr.WireAssignment + InOutAssignment gkr.WireAssignment + Info TestCaseInfo // we are generating the test vectors, so we need to keep the circuit instance info to ADD the proof to it and resave it +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit gkr.Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof gkr.Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(gkr.WireAssignment) + inOutAssignment := make(gkr.WireAssignment) + + sorted := gkr.TopologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []small_rational.SmallRational + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + info.Output = make([][]interface{}, 0, outI) + + for _, w := range sorted { + if w.IsOutput() { + + info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + Info: info, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} diff --git a/internal/gkr/test_vectors/main.go b/internal/gkr/test_vectors/main.go new file mode 100644 index 0000000000..9031b75587 --- /dev/null +++ b/internal/gkr/test_vectors/main.go @@ -0,0 +1,13 @@ +package main + +import "github.com/consensys/gnark/internal/gkr/test_vectors/sumcheck" + +func main() { + assertNoError(sumcheck.Generate()) +} + +func assertNoError(err error) { + if err != nil { + panic(err) + } +} From bc5306667ca90ac4fb122cc579755c24a1554031 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 3 Apr 2025 19:46:09 -0500 Subject: [PATCH 38/62] feat nomimc option --- internal/generator/backend/main.go | 4 +- .../backend/template/gkr/gkr.test.go.tmpl | 43 ++++++------ internal/gkr/bls12-377/gkr_test.go | 5 +- internal/gkr/bls12-381/gkr_test.go | 5 +- internal/gkr/bls24-315/gkr_test.go | 5 +- internal/gkr/bls24-317/gkr_test.go | 5 +- internal/gkr/bn254/gkr_test.go | 5 +- internal/gkr/bw6-633/gkr_test.go | 5 +- internal/gkr/bw6-761/gkr_test.go | 5 +- internal/gkr/small_rational/gkr_test.go | 68 +++++-------------- 10 files changed, 58 insertions(+), 92 deletions(-) diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 5053586ea0..5ddfad66b4 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "github.com/consensys/gnark-crypto/field/generator/config" sumcheckTestVectors "github.com/consensys/gnark/internal/gkr/test_vectors/sumcheck" "os" "os/exec" @@ -11,7 +12,6 @@ import ( "github.com/consensys/bavard" "github.com/consensys/gnark-crypto/field/generator" - "github.com/consensys/gnark-crypto/field/generator/config" ) const copyrightHolder = "Consensys Software Inc." @@ -230,6 +230,7 @@ func main() { }, GkrPackageRelativePath: "internal/gkr/small_rational", CanUseFFT: false, + NoMiMC: true, } assertNoError(generateGkrBackend(cfg)) @@ -317,6 +318,7 @@ type gkrConfig struct { CanUseFFT bool OutsideGkrPackage bool GenerateTestVectors bool + NoMiMC bool // if the MiMC hash is not implemented for the field } func assertNoError(err error) { diff --git a/internal/generator/backend/template/gkr/gkr.test.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.go.tmpl index 1000786d34..59c6bd1c3b 100644 --- a/internal/generator/backend/template/gkr/gkr.test.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.go.tmpl @@ -1,10 +1,13 @@ import ( "{{.FieldPackagePath}}" - "{{.FieldPackagePath}}/mimc" + {{- if not .NoMiMC }} + "{{.FieldPackagePath}}/mimc" + "time" + {{- end }} "{{.FieldPackagePath}}/polynomial" - "{{.FieldPackagePath}}/sumcheck" - "{{.FieldPackagePath}}/test_vector_utils" + "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/sumcheck" + "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/test_vector_utils" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/stretchr/testify/assert" @@ -16,7 +19,6 @@ import ( "path/filepath" "encoding/json" "reflect" - "time" ) {{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} @@ -439,6 +441,7 @@ func proofEquals(expected Proof, seen Proof) error { return nil } +{{- if not .NoMiMC }} func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) @@ -471,6 +474,8 @@ func BenchmarkGkrMimc17(b *testing.B) { benchmarkGkrMiMC(b, 1<<17, 91) } +{{- end }} + func TestTopSortTrivial(t *testing.T) { c := make(Circuit, 2) c[0].Inputs = []*Wire{&c[1]} @@ -510,7 +515,7 @@ func TestTopSortWide(t *testing.T) { {{template "gkrTestVectors" .}} func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + testGate := func(name GateName, f func(...{{.ElementType}}) {{.ElementType}}, nbIn, degree int) { t.Run(string(name), func(t *testing.T) { name = name + "-register-gate-test" @@ -526,27 +531,27 @@ func TestRegisterGateDegreeDetection(t *testing.T) { }) } - testGate("select", func(x ...fr.Element) fr.Element { + testGate("select", func(x ...{{.ElementType}}) {{.ElementType}} { return x[0] }, 3, 1) - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element + testGate("add2", func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} res.Add(&x[0], &x[1]) res.Add(&res, &x[2]) return res }, 3, 1) - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element + testGate("mul2", func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} res.Mul(&x[0], &x[1]) return res }, 2, 2) testGate("mimc", mimcRound, 2, 7) - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element + testGate("sub2PlusOne", func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} res. SetOne(). Add(&res, &x[0]). @@ -558,8 +563,8 @@ func TestRegisterGateDegreeDetection(t *testing.T) { t.Run("zero", func(t *testing.T) { const gateName GateName = "zero-register-gate-test" expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element + zeroGate := func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} return res } assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) @@ -571,19 +576,19 @@ func TestRegisterGateDegreeDetection(t *testing.T) { func TestIsAdditive(t *testing.T) { // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { + f := func(x ...{{.ElementType}}) {{.ElementType}} { if len(x) != 2 { panic("bivariate input needed") } - var res fr.Element + var res {{.ElementType}} res.Add(&x[0], &x[1]) res.Mul(&res, &x[0]) return res } // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element + g := func(x ...{{.ElementType}}) {{.ElementType}} { + var res, y3 {{.ElementType}} res.Square(&x[0]) y3.Mul(&x[1], &three) res.Add(&res, &y3) @@ -592,7 +597,7 @@ func TestIsAdditive(t *testing.T) { // h: x -> 2x // but it edits it input - h := func(x ...fr.Element) fr.Element { + h := func(x ...{{.ElementType}}) {{.ElementType}} { x[0].Double(&x[0]) return x[0] } diff --git a/internal/gkr/bls12-377/gkr_test.go b/internal/gkr/bls12-377/gkr_test.go index acc38d35af..0e204dd71d 100644 --- a/internal/gkr/bls12-377/gkr_test.go +++ b/internal/gkr/bls12-377/gkr_test.go @@ -11,10 +11,10 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/test_vector_utils" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls12-377/sumcheck" + "github.com/consensys/gnark/internal/gkr/bls12-377/test_vector_utils" "github.com/stretchr/testify/assert" "hash" "os" @@ -442,7 +442,6 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } - func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bls12-381/gkr_test.go b/internal/gkr/bls12-381/gkr_test.go index d8964182f6..8cd3506e88 100644 --- a/internal/gkr/bls12-381/gkr_test.go +++ b/internal/gkr/bls12-381/gkr_test.go @@ -11,10 +11,10 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/test_vector_utils" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls12-381/sumcheck" + "github.com/consensys/gnark/internal/gkr/bls12-381/test_vector_utils" "github.com/stretchr/testify/assert" "hash" "os" @@ -442,7 +442,6 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } - func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bls24-315/gkr_test.go b/internal/gkr/bls24-315/gkr_test.go index 04dd26a153..1f90259342 100644 --- a/internal/gkr/bls24-315/gkr_test.go +++ b/internal/gkr/bls24-315/gkr_test.go @@ -11,10 +11,10 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/test_vector_utils" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls24-315/sumcheck" + "github.com/consensys/gnark/internal/gkr/bls24-315/test_vector_utils" "github.com/stretchr/testify/assert" "hash" "os" @@ -442,7 +442,6 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } - func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bls24-317/gkr_test.go b/internal/gkr/bls24-317/gkr_test.go index c647ee01ea..440774cd2b 100644 --- a/internal/gkr/bls24-317/gkr_test.go +++ b/internal/gkr/bls24-317/gkr_test.go @@ -11,10 +11,10 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/test_vector_utils" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls24-317/sumcheck" + "github.com/consensys/gnark/internal/gkr/bls24-317/test_vector_utils" "github.com/stretchr/testify/assert" "hash" "os" @@ -442,7 +442,6 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } - func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bn254/gkr_test.go b/internal/gkr/bn254/gkr_test.go index 9ac49f5cc0..69c3f02bb6 100644 --- a/internal/gkr/bn254/gkr_test.go +++ b/internal/gkr/bn254/gkr_test.go @@ -11,10 +11,10 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/test_vector_utils" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bn254/sumcheck" + "github.com/consensys/gnark/internal/gkr/bn254/test_vector_utils" "github.com/stretchr/testify/assert" "hash" "os" @@ -442,7 +442,6 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } - func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bw6-633/gkr_test.go b/internal/gkr/bw6-633/gkr_test.go index ae8adff951..b732924a2b 100644 --- a/internal/gkr/bw6-633/gkr_test.go +++ b/internal/gkr/bw6-633/gkr_test.go @@ -11,10 +11,10 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/test_vector_utils" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bw6-633/sumcheck" + "github.com/consensys/gnark/internal/gkr/bw6-633/test_vector_utils" "github.com/stretchr/testify/assert" "hash" "os" @@ -442,7 +442,6 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } - func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bw6-761/gkr_test.go b/internal/gkr/bw6-761/gkr_test.go index a7cdf45cd5..ab54c82bf4 100644 --- a/internal/gkr/bw6-761/gkr_test.go +++ b/internal/gkr/bw6-761/gkr_test.go @@ -11,10 +11,10 @@ import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/sumcheck" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/test_vector_utils" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bw6-761/sumcheck" + "github.com/consensys/gnark/internal/gkr/bw6-761/test_vector_utils" "github.com/stretchr/testify/assert" "hash" "os" @@ -442,7 +442,6 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } - func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/small_rational/gkr_test.go b/internal/gkr/small_rational/gkr_test.go index 31bd52133a..c55fd76683 100644 --- a/internal/gkr/small_rational/gkr_test.go +++ b/internal/gkr/small_rational/gkr_test.go @@ -10,11 +10,10 @@ import ( "fmt" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" + "github.com/consensys/gnark/internal/gkr/small_rational/test_vector_utils" "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/mimc" "github.com/consensys/gnark/internal/small_rational/polynomial" - "github.com/consensys/gnark/internal/small_rational/sumcheck" - "github.com/consensys/gnark/internal/small_rational/test_vector_utils" "github.com/stretchr/testify/assert" "hash" "os" @@ -22,7 +21,6 @@ import ( "reflect" "strconv" "testing" - "time" ) func TestNoGateTwoInstances(t *testing.T) { @@ -443,38 +441,6 @@ func proofEquals(expected Proof, seen Proof) error { return nil } -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - func TestTopSortTrivial(t *testing.T) { c := make(Circuit, 2) c[0].Inputs = []*Wire{&c[1]} @@ -732,7 +698,7 @@ func newTestCase(path string) (*TestCase, error) { } func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + testGate := func(name GateName, f func(...small_rational.SmallRational) small_rational.SmallRational, nbIn, degree int) { t.Run(string(name), func(t *testing.T) { name = name + "-register-gate-test" @@ -748,27 +714,27 @@ func TestRegisterGateDegreeDetection(t *testing.T) { }) } - testGate("select", func(x ...fr.Element) fr.Element { + testGate("select", func(x ...small_rational.SmallRational) small_rational.SmallRational { return x[0] }, 3, 1) - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element + testGate("add2", func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational res.Add(&x[0], &x[1]) res.Add(&res, &x[2]) return res }, 3, 1) - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element + testGate("mul2", func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational res.Mul(&x[0], &x[1]) return res }, 2, 2) testGate("mimc", mimcRound, 2, 7) - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element + testGate("sub2PlusOne", func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational res. SetOne(). Add(&res, &x[0]). @@ -780,8 +746,8 @@ func TestRegisterGateDegreeDetection(t *testing.T) { t.Run("zero", func(t *testing.T) { const gateName GateName = "zero-register-gate-test" expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element + zeroGate := func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational return res } assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) @@ -793,19 +759,19 @@ func TestRegisterGateDegreeDetection(t *testing.T) { func TestIsAdditive(t *testing.T) { // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { + f := func(x ...small_rational.SmallRational) small_rational.SmallRational { if len(x) != 2 { panic("bivariate input needed") } - var res fr.Element + var res small_rational.SmallRational res.Add(&x[0], &x[1]) res.Mul(&res, &x[0]) return res } // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element + g := func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res, y3 small_rational.SmallRational res.Square(&x[0]) y3.Mul(&x[1], &three) res.Add(&res, &y3) @@ -814,7 +780,7 @@ func TestIsAdditive(t *testing.T) { // h: x -> 2x // but it edits it input - h := func(x ...fr.Element) fr.Element { + h := func(x ...small_rational.SmallRational) small_rational.SmallRational { x[0].Double(&x[0]) return x[0] } From 4168046fd80ac8f575ef6a41b8eb8927e08c8914 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 3 Apr 2025 20:03:23 -0500 Subject: [PATCH 39/62] fix: no tests for small rational gkr --- internal/generator/backend/main.go | 21 +- .../backend/template/gkr/gkr.test.go.tmpl | 41 +- internal/gkr/bls12-377/gkr_test.go | 1 + internal/gkr/bls12-381/gkr_test.go | 1 + internal/gkr/bls24-315/gkr_test.go | 1 + internal/gkr/bls24-317/gkr_test.go | 1 + internal/gkr/bn254/gkr_test.go | 1 + internal/gkr/bw6-633/gkr_test.go | 1 + internal/gkr/bw6-761/gkr_test.go | 1 + internal/gkr/small_rational/gkr_test.go | 795 ------------------ 10 files changed, 39 insertions(+), 825 deletions(-) delete mode 100644 internal/gkr/small_rational/gkr_test.go diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 5ddfad66b4..d6cecfa91a 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -230,7 +230,7 @@ func main() { }, GkrPackageRelativePath: "internal/gkr/small_rational", CanUseFFT: false, - NoMiMC: true, + NoGkrTests: true, } assertNoError(generateGkrBackend(cfg)) @@ -293,6 +293,7 @@ func generateGkrBackend(cfg gkrConfig) error { {File: filepath.Join(packageDir, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, {File: filepath.Join(packageDir, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl"}}, } + if err := bgen.Generate(cfg, "sumcheck", "./template/gkr/", entries...); err != nil { return err } @@ -302,7 +303,12 @@ func generateGkrBackend(cfg gkrConfig) error { entries = []bavard.Entry{ {File: filepath.Join(packageDir, "gkr.go"), Templates: []string{"gkr.go.tmpl"}}, {File: filepath.Join(packageDir, "registry.go"), Templates: []string{"registry.go.tmpl"}}, - {File: filepath.Join(packageDir, "gkr_test.go"), Templates: []string{"gkr.test.go.tmpl", "gkr.test.vectors.go.tmpl"}}, + } + + if !cfg.NoGkrTests { + entries = append(entries, bavard.Entry{ + File: filepath.Join(packageDir, "gkr_test.go"), Templates: []string{"gkr.test.go.tmpl", "gkr.test.vectors.go.tmpl"}, + }) } if err := bgen.Generate(cfg, "gkr", "./template/gkr/", entries...); err != nil { @@ -314,11 +320,12 @@ func generateGkrBackend(cfg gkrConfig) error { type gkrConfig struct { config.FieldDependency - GkrPackageRelativePath string - CanUseFFT bool - OutsideGkrPackage bool - GenerateTestVectors bool - NoMiMC bool // if the MiMC hash is not implemented for the field + GkrPackageRelativePath string // the GKR package, relative to the repo root + TestVectorsRelativePath string // the test vectors, relative to the current package + CanUseFFT bool + OutsideGkrPackage bool + GenerateTestVectors bool + NoGkrTests bool } func assertNoError(err error) { diff --git a/internal/generator/backend/template/gkr/gkr.test.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.go.tmpl index 59c6bd1c3b..d79c465519 100644 --- a/internal/generator/backend/template/gkr/gkr.test.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.go.tmpl @@ -1,10 +1,7 @@ import ( "{{.FieldPackagePath}}" - {{- if not .NoMiMC }} - "{{.FieldPackagePath}}/mimc" - "time" - {{- end }} + "{{.FieldPackagePath}}/mimc" "{{.FieldPackagePath}}/polynomial" "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/sumcheck" "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/test_vector_utils" @@ -19,6 +16,7 @@ import ( "path/filepath" "encoding/json" "reflect" + "time" ) {{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} @@ -400,7 +398,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "" + testDirPath := "{{.TestVectorsRelativePath}}" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { @@ -441,7 +439,6 @@ func proofEquals(expected Proof, seen Proof) error { return nil } -{{- if not .NoMiMC }} func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) @@ -474,8 +471,6 @@ func BenchmarkGkrMimc17(b *testing.B) { benchmarkGkrMiMC(b, 1<<17, 91) } -{{- end }} - func TestTopSortTrivial(t *testing.T) { c := make(Circuit, 2) c[0].Inputs = []*Wire{&c[1]} @@ -515,7 +510,7 @@ func TestTopSortWide(t *testing.T) { {{template "gkrTestVectors" .}} func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...{{.ElementType}}) {{.ElementType}}, nbIn, degree int) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { t.Run(string(name), func(t *testing.T) { name = name + "-register-gate-test" @@ -531,27 +526,27 @@ func TestRegisterGateDegreeDetection(t *testing.T) { }) } - testGate("select", func(x ...{{.ElementType}}) {{.ElementType}} { + testGate("select", func(x ...fr.Element) fr.Element { return x[0] }, 3, 1) - testGate("add2", func(x ...{{.ElementType}}) {{.ElementType}} { - var res {{.ElementType}} + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element res.Add(&x[0], &x[1]) res.Add(&res, &x[2]) return res }, 3, 1) - testGate("mul2", func(x ...{{.ElementType}}) {{.ElementType}} { - var res {{.ElementType}} + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element res.Mul(&x[0], &x[1]) return res }, 2, 2) testGate("mimc", mimcRound, 2, 7) - testGate("sub2PlusOne", func(x ...{{.ElementType}}) {{.ElementType}} { - var res {{.ElementType}} + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element res. SetOne(). Add(&res, &x[0]). @@ -563,8 +558,8 @@ func TestRegisterGateDegreeDetection(t *testing.T) { t.Run("zero", func(t *testing.T) { const gateName GateName = "zero-register-gate-test" expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...{{.ElementType}}) {{.ElementType}} { - var res {{.ElementType}} + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element return res } assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) @@ -576,19 +571,19 @@ func TestRegisterGateDegreeDetection(t *testing.T) { func TestIsAdditive(t *testing.T) { // f: x,y -> x² + xy - f := func(x ...{{.ElementType}}) {{.ElementType}} { + f := func(x ...fr.Element) fr.Element { if len(x) != 2 { panic("bivariate input needed") } - var res {{.ElementType}} + var res fr.Element res.Add(&x[0], &x[1]) res.Mul(&res, &x[0]) return res } // g: x,y -> x² + 3y - g := func(x ...{{.ElementType}}) {{.ElementType}} { - var res, y3 {{.ElementType}} + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element res.Square(&x[0]) y3.Mul(&x[1], &three) res.Add(&res, &y3) @@ -597,7 +592,7 @@ func TestIsAdditive(t *testing.T) { // h: x -> 2x // but it edits it input - h := func(x ...{{.ElementType}}) {{.ElementType}} { + h := func(x ...fr.Element) fr.Element { x[0].Double(&x[0]) return x[0] } diff --git a/internal/gkr/bls12-377/gkr_test.go b/internal/gkr/bls12-377/gkr_test.go index 0e204dd71d..1b214822be 100644 --- a/internal/gkr/bls12-377/gkr_test.go +++ b/internal/gkr/bls12-377/gkr_test.go @@ -442,6 +442,7 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } + func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bls12-381/gkr_test.go b/internal/gkr/bls12-381/gkr_test.go index 8cd3506e88..3f6716a5a3 100644 --- a/internal/gkr/bls12-381/gkr_test.go +++ b/internal/gkr/bls12-381/gkr_test.go @@ -442,6 +442,7 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } + func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bls24-315/gkr_test.go b/internal/gkr/bls24-315/gkr_test.go index 1f90259342..ecdd478dde 100644 --- a/internal/gkr/bls24-315/gkr_test.go +++ b/internal/gkr/bls24-315/gkr_test.go @@ -442,6 +442,7 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } + func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bls24-317/gkr_test.go b/internal/gkr/bls24-317/gkr_test.go index 440774cd2b..91b49d4f89 100644 --- a/internal/gkr/bls24-317/gkr_test.go +++ b/internal/gkr/bls24-317/gkr_test.go @@ -442,6 +442,7 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } + func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bn254/gkr_test.go b/internal/gkr/bn254/gkr_test.go index 69c3f02bb6..1cc04d21dd 100644 --- a/internal/gkr/bn254/gkr_test.go +++ b/internal/gkr/bn254/gkr_test.go @@ -442,6 +442,7 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } + func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bw6-633/gkr_test.go b/internal/gkr/bw6-633/gkr_test.go index b732924a2b..5127ffc3d7 100644 --- a/internal/gkr/bw6-633/gkr_test.go +++ b/internal/gkr/bw6-633/gkr_test.go @@ -442,6 +442,7 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } + func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/bw6-761/gkr_test.go b/internal/gkr/bw6-761/gkr_test.go index ab54c82bf4..93f16005b2 100644 --- a/internal/gkr/bw6-761/gkr_test.go +++ b/internal/gkr/bw6-761/gkr_test.go @@ -442,6 +442,7 @@ func proofEquals(expected Proof, seen Proof) error { } return nil } + func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { fmt.Println("creating circuit structure") c := mimcCircuit(mimcDepth) diff --git a/internal/gkr/small_rational/gkr_test.go b/internal/gkr/small_rational/gkr_test.go deleted file mode 100644 index c55fd76683..0000000000 --- a/internal/gkr/small_rational/gkr_test.go +++ /dev/null @@ -1,795 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" - "github.com/consensys/gnark/internal/gkr/small_rational/test_vector_utils" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []small_rational.SmallRational{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []small_rational.SmallRational{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []small_rational.SmallRational{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]small_rational.SmallRational) { - return func(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []small_rational.SmallRational{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []small_rational.SmallRational{three}, five) - manager.add(wire, []small_rational.SmallRational{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six small_rational.SmallRational - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]small_rational.SmallRational)) { - fullAssignments := make([][]small_rational.SmallRational, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]small_rational.SmallRational, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []small_rational.SmallRational) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := TopologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := TopologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := TopologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { - var sum small_rational.SmallRational - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []small_rational.SmallRational(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []small_rational.SmallRational - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...small_rational.SmallRational) small_rational.SmallRational, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...small_rational.SmallRational) small_rational.SmallRational { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...small_rational.SmallRational) small_rational.SmallRational { - if len(x) != 2 { - panic("bivariate input needed") - } - var res small_rational.SmallRational - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res, y3 small_rational.SmallRational - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...small_rational.SmallRational) small_rational.SmallRational { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} From c2b83fc2cd069be73f5e2fe0fc6d61b15e921ceb Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 3 Apr 2025 20:15:45 -0500 Subject: [PATCH 40/62] fix: gkr vec gen imports --- internal/generator/backend/main.go | 10 +++--- .../template/gkr/gkr.test.vectors.gen.go.tmpl | 35 +++++++------------ .../gkr/test_vectors/gkr/gkr-gen-vectors.go | 22 ++++-------- internal/gkr/test_vectors/main.go | 8 +++-- .../sumcheck/sumcheck-gen-vectors.go | 2 +- 5 files changed, 32 insertions(+), 45 deletions(-) diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index d6cecfa91a..fe0a843a93 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -3,7 +3,6 @@ package main import ( "fmt" "github.com/consensys/gnark-crypto/field/generator/config" - sumcheckTestVectors "github.com/consensys/gnark/internal/gkr/test_vectors/sumcheck" "os" "os/exec" "path/filepath" @@ -245,10 +244,11 @@ func main() { }, )) - fmt.Println("generating test vectors for sumcheck") - assertNoError(sumcheckTestVectors.Generate()) // TODO CRITICAL This must be an independent process so that it's compiled before being run] - // TODO it also needs to run after everything else is done - + fmt.Println("generating test vectors for gkr and sumcheck") + /*cmd := exec.Command("go", "run", "../../gkr/test_vectors") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + assertNoError(cmd.Run())*/ wg.Done() }() diff --git a/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl index 832188f3d3..0386c6635d 100644 --- a/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl @@ -1,28 +1,19 @@ import ( - "encoding/json" - "fmt" - "hash" - "os" - "path/filepath" - "reflect" - - "github.com/consensys/bavard" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" - + "encoding/json" + "fmt" + "github.com/consensys/bavard" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/small_rational" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" + "hash" + "os" + "path/filepath" + "reflect" ) -func main() { - if err := GenerateVectors(); err != nil { - fmt.Println(err.Error()) - os.Exit(-1) - } -} - func GenerateVectors() error { testDirPath, err := filepath.Abs("gkr/test_vectors") if err != nil { diff --git a/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go b/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go index 598a2da702..1b0baff158 100644 --- a/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go +++ b/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go @@ -8,27 +8,19 @@ package gkr import ( "encoding/json" "fmt" + "github.com/consensys/bavard" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/small_rational" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" "hash" "os" "path/filepath" "reflect" - - "github.com/consensys/bavard" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" ) -func main() { - if err := GenerateVectors(); err != nil { - fmt.Println(err.Error()) - os.Exit(-1) - } -} - func GenerateVectors() error { testDirPath, err := filepath.Abs("gkr/test_vectors") if err != nil { diff --git a/internal/gkr/test_vectors/main.go b/internal/gkr/test_vectors/main.go index 9031b75587..a551f0713c 100644 --- a/internal/gkr/test_vectors/main.go +++ b/internal/gkr/test_vectors/main.go @@ -1,9 +1,13 @@ package main -import "github.com/consensys/gnark/internal/gkr/test_vectors/sumcheck" +import ( + "github.com/consensys/gnark/internal/gkr/test_vectors/gkr" + "github.com/consensys/gnark/internal/gkr/test_vectors/sumcheck" +) func main() { - assertNoError(sumcheck.Generate()) + assertNoError(sumcheck.GenerateVectors()) + assertNoError(gkr.GenerateVectors()) } func assertNoError(err error) { diff --git a/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go index 7917a3b60c..52c2d0b89d 100644 --- a/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go +++ b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go @@ -75,7 +75,7 @@ func run(testCaseInfo *TestCaseInfo) error { } } -func Generate() error { +func GenerateVectors() error { // read the test vectors file, generate the proof, make sure it verifies, // and add the proof to the same file const relPath = "sumcheck/test_vectors/vectors.json" From efbed6aef1f0b43d5d1ef307ba772d8912e7b970 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 3 Apr 2025 20:28:18 -0500 Subject: [PATCH 41/62] fix: gkr test vec gen --- .../backend/gkr/test_vectors/main.go | 349 ------------------ internal/generator/backend/main.go | 4 +- .../backend/sumcheck/test_vectors/main.go | 199 ---------- .../sumcheck/test_vectors/vectors.json | 56 --- .../template/gkr/gkr.test.vectors.gen.go.tmpl | 2 +- .../gkr}/circuits/mimc_five_levels.json | 0 .../gkr}/circuits/single_identity_gate.json | 0 .../single_input_two_identity_gates.json | 0 .../gkr}/circuits/single_input_two_outs.json | 0 .../gkr}/circuits/single_mimc_gate.json | 0 .../gkr}/circuits/single_mul_gate.json | 0 ..._identity_gates_composed_single_input.json | 0 .../two_inputs_select-input-3_gate.json | 0 .../gkr/test_vectors/gkr/gkr-gen-vectors.go | 2 +- .../gkr}/mimc_five_levels_two_instances._json | 2 +- .../single_identity_gate_two_instances.json | 2 +- ...nput_two_identity_gates_two_instances.json | 2 +- .../single_input_two_outs_two_instances.json | 2 +- .../gkr}/single_mimc_gate_four_instances.json | 2 +- .../gkr}/single_mimc_gate_two_instances.json | 2 +- .../gkr}/single_mul_gate_two_instances.json | 2 +- ...s_composed_single_input_two_instances.json | 2 +- ...uts_select-input-3_gate_two_instances.json | 2 +- 23 files changed, 13 insertions(+), 617 deletions(-) delete mode 100644 internal/generator/backend/gkr/test_vectors/main.go delete mode 100644 internal/generator/backend/sumcheck/test_vectors/main.go delete mode 100644 internal/generator/backend/sumcheck/test_vectors/vectors.json rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/circuits/mimc_five_levels.json (100%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/circuits/single_identity_gate.json (100%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/circuits/single_input_two_identity_gates.json (100%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/circuits/single_input_two_outs.json (100%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/circuits/single_mimc_gate.json (100%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/circuits/single_mul_gate.json (100%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/circuits/two_identity_gates_composed_single_input.json (100%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/circuits/two_inputs_select-input-3_gate.json (100%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/mimc_five_levels_two_instances._json (83%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/single_identity_gate_two_instances.json (85%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/single_input_two_identity_gates_two_instances.json (87%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/single_input_two_outs_two_instances.json (89%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/single_mimc_gate_four_instances.json (93%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/single_mimc_gate_two_instances.json (91%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/single_mul_gate_two_instances.json (89%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/two_identity_gates_composed_single_input_two_instances.json (84%) rename internal/{generator/backend/gkr/test_vectors => gkr/test_vectors/gkr}/two_inputs_select-input-3_gate_two_instances.json (86%) diff --git a/internal/generator/backend/gkr/test_vectors/main.go b/internal/generator/backend/gkr/test_vectors/main.go deleted file mode 100644 index 0bb86739af..0000000000 --- a/internal/generator/backend/gkr/test_vectors/main.go +++ /dev/null @@ -1,349 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package main - -import ( - "encoding/json" - "fmt" - "hash" - "os" - "path/filepath" - "reflect" - - "github.com/consensys/bavard" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" - "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" -) - -func main() { - if err := GenerateVectors(); err != nil { - fmt.Println(err.Error()) - os.Exit(-1) - } -} - -func GenerateVectors() error { - testDirPath, err := filepath.Abs("gkr/test_vectors") - if err != nil { - return err - } - - fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) - - dirEntries, err := os.ReadDir(testDirPath) - if err != nil { - return err - } - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - if !bavard.ShouldGenerate(path) { - continue - } - fmt.Println("\tprocessing", dirEntry.Name()) - if err = run(path); err != nil { - return err - } - } - } - } - - return nil -} - -func run(absPath string) error { - testCase, err := newTestCase(absPath) - if err != nil { - return err - } - - transcriptSetting := fiatshamir.WithHash(testCase.Hash) - - var proof gkr.Proof - proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) - if err != nil { - return err - } - - if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { - return err - } - var outBytes []byte - if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { - if err = os.WriteFile(absPath, outBytes, 0); err != nil { - return err - } - } else { - return err - } - - testCase, err = newTestCase(absPath) - if err != nil { - return err - } - - err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) - if err != nil { - return err - } - - testCase, err = newTestCase(absPath) - if err != nil { - return err - } - - err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - if err == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { - res := make(PrintableProof, len(proof)) - - for i := range proof { - - partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) - for k, partialK := range proof[i].PartialSumPolys { - partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) - } - - res[i] = PrintableSumcheckProof{ - FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), - PartialSumPolys: partialSumPolys, - } - } - return res, nil -} - -type WireInfo struct { - Gate gkr.GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]gkr.Circuit) - -func getCircuit(path string) (gkr.Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit gkr.Circuit) { - circuit = make(gkr.Circuit, len(c)) - for i := range c { - circuit[i].Gate = gkr.GetGate(c[i].Gate) - circuit[i].Inputs = make([]*gkr.Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { - var sum small_rational.SmallRational - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC gkr.GateName = "mimc" - SelectInput3 gkr.GateName = "select-input-3" -) - -func init() { - if err := gkr.RegisterGate(MiMC, mimcRound, 2, gkr.WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := gkr.RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { - return input[2] - }, 3, gkr.WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (gkr.Proof, error) { - proof := make(gkr.Proof, len(printable)) - for i := range printable { - finalEvalProof := []small_rational.SmallRational(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit gkr.Circuit - Hash hash.Hash - Proof gkr.Proof - FullAssignment gkr.WireAssignment - InOutAssignment gkr.WireAssignment - Info TestCaseInfo -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit gkr.Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof gkr.Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(gkr.WireAssignment) - inOutAssignment := make(gkr.WireAssignment) - - sorted := gkr.TopologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []small_rational.SmallRational - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - info.Output = make([][]interface{}, 0, outI) - - for _, w := range sorted { - if w.IsOutput() { - - info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - Info: info, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index fe0a843a93..49a38f7e9d 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -245,10 +245,10 @@ func main() { )) fmt.Println("generating test vectors for gkr and sumcheck") - /*cmd := exec.Command("go", "run", "../../gkr/test_vectors") + cmd := exec.Command("go", "run", "../../gkr/test_vectors") cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - assertNoError(cmd.Run())*/ + assertNoError(cmd.Run()) wg.Done() }() diff --git a/internal/generator/backend/sumcheck/test_vectors/main.go b/internal/generator/backend/sumcheck/test_vectors/main.go deleted file mode 100644 index 8a5c3f867e..0000000000 --- a/internal/generator/backend/sumcheck/test_vectors/main.go +++ /dev/null @@ -1,199 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "github.com/consensys/gnark/internal/small_rational/test_vector_utils" - "hash" - "math/bits" - "os" - "path/filepath" -) - -func runMultilin(testCaseInfo *TestCaseInfo) error { - - var poly polynomial.MultiLin - if v, err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); err == nil { - poly = v - } else { - return err - } - - var hsh hash.Hash - var err error - if hsh, err = test_vector_utils.HashFromDescription(testCaseInfo.Hash); err != nil { - return err - } - - proof, err := sumcheck.Prove( - &singleMultilinClaim{poly}, fiatshamir.WithHash(hsh)) - if err != nil { - return err - } - testCaseInfo.Proof = toPrintableProof(proof) - - // Verification - if v, _err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); _err == nil { - poly = v - } else { - return _err - } - var claimedSum small_rational.SmallRational - if _, err = claimedSum.SetInterface(testCaseInfo.ClaimedSum); err != nil { - return err - } - - if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err != nil { - return fmt.Errorf("proof rejected: %v", err) - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func run(testCaseInfo *TestCaseInfo) error { - switch testCaseInfo.Type { - case "multilin": - return runMultilin(testCaseInfo) - default: - return fmt.Errorf("type \"%s\" unrecognized", testCaseInfo.Type) - } -} - -func runAll(relPath string) error { - var filename string - var err error - if filename, err = filepath.Abs(relPath); err != nil { - return err - } - - var bytes []byte - - if bytes, err = os.ReadFile(filename); err != nil { - return err - } - - var testCasesInfo TestCasesInfo - if err = json.Unmarshal(bytes, &testCasesInfo); err != nil { - return err - } - - failed := false - for name, testCase := range testCasesInfo { - if err = run(testCase); err != nil { - fmt.Println(name, ":", err) - failed = true - } - } - - if failed { - return fmt.Errorf("test case failed") - } - - if bytes, err = json.MarshalIndent(testCasesInfo, "", "\t"); err != nil { - return err - } - - return os.WriteFile(filename, bytes, 0) -} - -func main() { - if err := runAll("sumcheck/test_vectors/vectors.json"); err != nil { - fmt.Println(err) - os.Exit(-1) - } -} - -type TestCasesInfo map[string]*TestCaseInfo - -type TestCaseInfo struct { - Type string `json:"type"` - Hash test_vector_utils.HashDescription `json:"hash"` - Values []interface{} `json:"values"` - Description string `json:"description"` - Proof PrintableProof `json:"proof"` - ClaimedSum interface{} `json:"claimedSum"` -} - -type PrintableProof struct { - PartialSumPolys [][]interface{} `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` -} - -func toPrintableProof(proof sumcheck.Proof) (printable PrintableProof) { - if proof.FinalEvalProof != nil { - panic("null expected") - } - printable.FinalEvalProof = struct{}{} - printable.PartialSumPolys = test_vector_utils.ElementSliceSliceToInterfaceSliceSlice(proof.PartialSumPolys) - return -} - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval([]small_rational.SmallRational) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []small_rational.SmallRational{sum} -} - -func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum small_rational.SmallRational -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, _ small_rational.SmallRational, purportedValue small_rational.SmallRational, _ interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(small_rational.SmallRational) small_rational.SmallRational { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} diff --git a/internal/generator/backend/sumcheck/test_vectors/vectors.json b/internal/generator/backend/sumcheck/test_vectors/vectors.json deleted file mode 100644 index 64b8e3fb2d..0000000000 --- a/internal/generator/backend/sumcheck/test_vectors/vectors.json +++ /dev/null @@ -1,56 +0,0 @@ -{ - "linear_univariate_single_claim": { - "type": "multilin", - "hash": { - "type": "const", - "val": -1 - }, - "values": [ - 1, - 3 - ], - "description": "X ↦ 2X + 1", - "proof": { - "partialSumPolys": [ - [ - 3 - ] - ], - "finalEvalProof": {} - }, - "claimedSum": 4 - }, - "trilinear_single_claim": { - "type": "multilin", - "hash": { - "type": "const", - "val": -1 - }, - "values": [ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8 - ], - "description": "X₁, X₂, X₃ ↦ 1 + 4X₁ + 2X₂ + X₃", - "proof": { - "partialSumPolys": [ - [ - 26 - ], - [ - -1 - ], - [ - -4 - ] - ], - "finalEvalProof": {} - }, - "claimedSum": 36 - } -} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl index 0386c6635d..a747c5ed04 100644 --- a/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl @@ -15,7 +15,7 @@ import ( ) func GenerateVectors() error { - testDirPath, err := filepath.Abs("gkr/test_vectors") + testDirPath, err := filepath.Abs("../../gkr/test_vectors/gkr") if err != nil { return err } diff --git a/internal/generator/backend/gkr/test_vectors/circuits/mimc_five_levels.json b/internal/gkr/test_vectors/gkr/circuits/mimc_five_levels.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/circuits/mimc_five_levels.json rename to internal/gkr/test_vectors/gkr/circuits/mimc_five_levels.json diff --git a/internal/generator/backend/gkr/test_vectors/circuits/single_identity_gate.json b/internal/gkr/test_vectors/gkr/circuits/single_identity_gate.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/circuits/single_identity_gate.json rename to internal/gkr/test_vectors/gkr/circuits/single_identity_gate.json diff --git a/internal/generator/backend/gkr/test_vectors/circuits/single_input_two_identity_gates.json b/internal/gkr/test_vectors/gkr/circuits/single_input_two_identity_gates.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/circuits/single_input_two_identity_gates.json rename to internal/gkr/test_vectors/gkr/circuits/single_input_two_identity_gates.json diff --git a/internal/generator/backend/gkr/test_vectors/circuits/single_input_two_outs.json b/internal/gkr/test_vectors/gkr/circuits/single_input_two_outs.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/circuits/single_input_two_outs.json rename to internal/gkr/test_vectors/gkr/circuits/single_input_two_outs.json diff --git a/internal/generator/backend/gkr/test_vectors/circuits/single_mimc_gate.json b/internal/gkr/test_vectors/gkr/circuits/single_mimc_gate.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/circuits/single_mimc_gate.json rename to internal/gkr/test_vectors/gkr/circuits/single_mimc_gate.json diff --git a/internal/generator/backend/gkr/test_vectors/circuits/single_mul_gate.json b/internal/gkr/test_vectors/gkr/circuits/single_mul_gate.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/circuits/single_mul_gate.json rename to internal/gkr/test_vectors/gkr/circuits/single_mul_gate.json diff --git a/internal/generator/backend/gkr/test_vectors/circuits/two_identity_gates_composed_single_input.json b/internal/gkr/test_vectors/gkr/circuits/two_identity_gates_composed_single_input.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/circuits/two_identity_gates_composed_single_input.json rename to internal/gkr/test_vectors/gkr/circuits/two_identity_gates_composed_single_input.json diff --git a/internal/generator/backend/gkr/test_vectors/circuits/two_inputs_select-input-3_gate.json b/internal/gkr/test_vectors/gkr/circuits/two_inputs_select-input-3_gate.json similarity index 100% rename from internal/generator/backend/gkr/test_vectors/circuits/two_inputs_select-input-3_gate.json rename to internal/gkr/test_vectors/gkr/circuits/two_inputs_select-input-3_gate.json diff --git a/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go b/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go index 1b0baff158..c1dcb5b6e0 100644 --- a/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go +++ b/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go @@ -22,7 +22,7 @@ import ( ) func GenerateVectors() error { - testDirPath, err := filepath.Abs("gkr/test_vectors") + testDirPath, err := filepath.Abs("../../gkr/test_vectors/gkr") if err != nil { return err } diff --git a/internal/generator/backend/gkr/test_vectors/mimc_five_levels_two_instances._json b/internal/gkr/test_vectors/gkr/mimc_five_levels_two_instances._json similarity index 83% rename from internal/generator/backend/gkr/test_vectors/mimc_five_levels_two_instances._json rename to internal/gkr/test_vectors/gkr/mimc_five_levels_two_instances._json index 446d23fdb2..e980cfb0cb 100644 --- a/internal/generator/backend/gkr/test_vectors/mimc_five_levels_two_instances._json +++ b/internal/gkr/test_vectors/gkr/mimc_five_levels_two_instances._json @@ -1,6 +1,6 @@ { "hash": {"type": "const", "val": -1}, - "circuit": "resources/mimc_five_levels.json", + "circuit": "circuits/mimc_five_levels.json", "input": [[1, 3], [1, 3], [1, 3], [1, 3], [1, 3], [1, 3]], "output": [[4, 3]], "proof": [[{"partialSumPolys":[[3,4]],"finalEvalProof":[3]}],[{"partialSumPolys":null,"finalEvalProof":null}]] diff --git a/internal/generator/backend/gkr/test_vectors/single_identity_gate_two_instances.json b/internal/gkr/test_vectors/gkr/single_identity_gate_two_instances.json similarity index 85% rename from internal/generator/backend/gkr/test_vectors/single_identity_gate_two_instances.json rename to internal/gkr/test_vectors/gkr/single_identity_gate_two_instances.json index ce326d0a63..ba28e35961 100644 --- a/internal/generator/backend/gkr/test_vectors/single_identity_gate_two_instances.json +++ b/internal/gkr/test_vectors/gkr/single_identity_gate_two_instances.json @@ -3,7 +3,7 @@ "type": "const", "val": -1 }, - "circuit": "resources/single_identity_gate.json", + "circuit": "circuits/single_identity_gate.json", "input": [ [ 4, diff --git a/internal/generator/backend/gkr/test_vectors/single_input_two_identity_gates_two_instances.json b/internal/gkr/test_vectors/gkr/single_input_two_identity_gates_two_instances.json similarity index 87% rename from internal/generator/backend/gkr/test_vectors/single_input_two_identity_gates_two_instances.json rename to internal/gkr/test_vectors/gkr/single_input_two_identity_gates_two_instances.json index 2c95f044f2..1451b332c2 100644 --- a/internal/generator/backend/gkr/test_vectors/single_input_two_identity_gates_two_instances.json +++ b/internal/gkr/test_vectors/gkr/single_input_two_identity_gates_two_instances.json @@ -3,7 +3,7 @@ "type": "const", "val": -1 }, - "circuit": "resources/single_input_two_identity_gates.json", + "circuit": "circuits/single_input_two_identity_gates.json", "input": [ [ 2, diff --git a/internal/generator/backend/gkr/test_vectors/single_input_two_outs_two_instances.json b/internal/gkr/test_vectors/gkr/single_input_two_outs_two_instances.json similarity index 89% rename from internal/generator/backend/gkr/test_vectors/single_input_two_outs_two_instances.json rename to internal/gkr/test_vectors/gkr/single_input_two_outs_two_instances.json index d348303d0e..897aea7ee5 100644 --- a/internal/generator/backend/gkr/test_vectors/single_input_two_outs_two_instances.json +++ b/internal/gkr/test_vectors/gkr/single_input_two_outs_two_instances.json @@ -3,7 +3,7 @@ "type": "const", "val": -1 }, - "circuit": "resources/single_input_two_outs.json", + "circuit": "circuits/single_input_two_outs.json", "input": [ [ 1, diff --git a/internal/generator/backend/gkr/test_vectors/single_mimc_gate_four_instances.json b/internal/gkr/test_vectors/gkr/single_mimc_gate_four_instances.json similarity index 93% rename from internal/generator/backend/gkr/test_vectors/single_mimc_gate_four_instances.json rename to internal/gkr/test_vectors/gkr/single_mimc_gate_four_instances.json index ff275c9cb4..a724ba5a7b 100644 --- a/internal/generator/backend/gkr/test_vectors/single_mimc_gate_four_instances.json +++ b/internal/gkr/test_vectors/gkr/single_mimc_gate_four_instances.json @@ -3,7 +3,7 @@ "type": "const", "val": -1 }, - "circuit": "resources/single_mimc_gate.json", + "circuit": "circuits/single_mimc_gate.json", "input": [ [ 1, diff --git a/internal/generator/backend/gkr/test_vectors/single_mimc_gate_two_instances.json b/internal/gkr/test_vectors/gkr/single_mimc_gate_two_instances.json similarity index 91% rename from internal/generator/backend/gkr/test_vectors/single_mimc_gate_two_instances.json rename to internal/gkr/test_vectors/gkr/single_mimc_gate_two_instances.json index 369297dbd5..901db48692 100644 --- a/internal/generator/backend/gkr/test_vectors/single_mimc_gate_two_instances.json +++ b/internal/gkr/test_vectors/gkr/single_mimc_gate_two_instances.json @@ -3,7 +3,7 @@ "type": "const", "val": -1 }, - "circuit": "resources/single_mimc_gate.json", + "circuit": "circuits/single_mimc_gate.json", "input": [ [ 1, diff --git a/internal/generator/backend/gkr/test_vectors/single_mul_gate_two_instances.json b/internal/gkr/test_vectors/gkr/single_mul_gate_two_instances.json similarity index 89% rename from internal/generator/backend/gkr/test_vectors/single_mul_gate_two_instances.json rename to internal/gkr/test_vectors/gkr/single_mul_gate_two_instances.json index 75c1d59c3d..b85a6df42c 100644 --- a/internal/generator/backend/gkr/test_vectors/single_mul_gate_two_instances.json +++ b/internal/gkr/test_vectors/gkr/single_mul_gate_two_instances.json @@ -3,7 +3,7 @@ "type": "const", "val": -1 }, - "circuit": "resources/single_mul_gate.json", + "circuit": "circuits/single_mul_gate.json", "input": [ [ 4, diff --git a/internal/generator/backend/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/internal/gkr/test_vectors/gkr/two_identity_gates_composed_single_input_two_instances.json similarity index 84% rename from internal/generator/backend/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json rename to internal/gkr/test_vectors/gkr/two_identity_gates_composed_single_input_two_instances.json index 10e5f1ff3c..69a2038a75 100644 --- a/internal/generator/backend/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json +++ b/internal/gkr/test_vectors/gkr/two_identity_gates_composed_single_input_two_instances.json @@ -3,7 +3,7 @@ "type": "const", "val": -1 }, - "circuit": "resources/two_identity_gates_composed_single_input.json", + "circuit": "circuits/two_identity_gates_composed_single_input.json", "input": [ [ 2, diff --git a/internal/generator/backend/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json b/internal/gkr/test_vectors/gkr/two_inputs_select-input-3_gate_two_instances.json similarity index 86% rename from internal/generator/backend/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json rename to internal/gkr/test_vectors/gkr/two_inputs_select-input-3_gate_two_instances.json index 19e127df71..2dca0746a2 100644 --- a/internal/generator/backend/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json +++ b/internal/gkr/test_vectors/gkr/two_inputs_select-input-3_gate_two_instances.json @@ -3,7 +3,7 @@ "type": "const", "val": -1 }, - "circuit": "resources/two_inputs_select-input-3_gate.json", + "circuit": "circuits/two_inputs_select-input-3_gate.json", "input": [ [ 0, From 6b4b26ccfc2f0956b0c02dc7c8eeaa87fad9181f Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 3 Apr 2025 20:32:26 -0500 Subject: [PATCH 42/62] fix: gkr test vector tests --- internal/generator/backend/main.go | 11 +++++------ .../generator/backend/template/gkr/gkr.test.go.tmpl | 2 +- internal/gkr/bls12-377/gkr_test.go | 2 +- internal/gkr/bls12-381/gkr_test.go | 2 +- internal/gkr/bls24-315/gkr_test.go | 2 +- internal/gkr/bls24-317/gkr_test.go | 2 +- internal/gkr/bn254/gkr_test.go | 2 +- internal/gkr/bw6-633/gkr_test.go | 2 +- internal/gkr/bw6-761/gkr_test.go | 2 +- .../gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go | 2 +- 10 files changed, 14 insertions(+), 15 deletions(-) diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 49a38f7e9d..6d805d22c5 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -320,12 +320,11 @@ func generateGkrBackend(cfg gkrConfig) error { type gkrConfig struct { config.FieldDependency - GkrPackageRelativePath string // the GKR package, relative to the repo root - TestVectorsRelativePath string // the test vectors, relative to the current package - CanUseFFT bool - OutsideGkrPackage bool - GenerateTestVectors bool - NoGkrTests bool + GkrPackageRelativePath string // the GKR package, relative to the repo root + CanUseFFT bool + OutsideGkrPackage bool + GenerateTestVectors bool + NoGkrTests bool } func assertNoError(err error) { diff --git a/internal/generator/backend/template/gkr/gkr.test.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.go.tmpl index d79c465519..6d7a58b0aa 100644 --- a/internal/generator/backend/template/gkr/gkr.test.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.go.tmpl @@ -398,7 +398,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "{{.TestVectorsRelativePath}}" + const testDirPath = "../test_vectors/gkr" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { diff --git a/internal/gkr/bls12-377/gkr_test.go b/internal/gkr/bls12-377/gkr_test.go index 1b214822be..5b63fd1c80 100644 --- a/internal/gkr/bls12-377/gkr_test.go +++ b/internal/gkr/bls12-377/gkr_test.go @@ -402,7 +402,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "" + const testDirPath = "../test_vectors/gkr" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { diff --git a/internal/gkr/bls12-381/gkr_test.go b/internal/gkr/bls12-381/gkr_test.go index 3f6716a5a3..8c932961dc 100644 --- a/internal/gkr/bls12-381/gkr_test.go +++ b/internal/gkr/bls12-381/gkr_test.go @@ -402,7 +402,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "" + const testDirPath = "../test_vectors/gkr" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { diff --git a/internal/gkr/bls24-315/gkr_test.go b/internal/gkr/bls24-315/gkr_test.go index ecdd478dde..350d807e56 100644 --- a/internal/gkr/bls24-315/gkr_test.go +++ b/internal/gkr/bls24-315/gkr_test.go @@ -402,7 +402,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "" + const testDirPath = "../test_vectors/gkr" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { diff --git a/internal/gkr/bls24-317/gkr_test.go b/internal/gkr/bls24-317/gkr_test.go index 91b49d4f89..e44c8ccab1 100644 --- a/internal/gkr/bls24-317/gkr_test.go +++ b/internal/gkr/bls24-317/gkr_test.go @@ -402,7 +402,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "" + const testDirPath = "../test_vectors/gkr" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { diff --git a/internal/gkr/bn254/gkr_test.go b/internal/gkr/bn254/gkr_test.go index 1cc04d21dd..09fd0a9be5 100644 --- a/internal/gkr/bn254/gkr_test.go +++ b/internal/gkr/bn254/gkr_test.go @@ -402,7 +402,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "" + const testDirPath = "../test_vectors/gkr" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { diff --git a/internal/gkr/bw6-633/gkr_test.go b/internal/gkr/bw6-633/gkr_test.go index 5127ffc3d7..0b018df326 100644 --- a/internal/gkr/bw6-633/gkr_test.go +++ b/internal/gkr/bw6-633/gkr_test.go @@ -402,7 +402,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "" + const testDirPath = "../test_vectors/gkr" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { diff --git a/internal/gkr/bw6-761/gkr_test.go b/internal/gkr/bw6-761/gkr_test.go index 93f16005b2..2aa45ac3c0 100644 --- a/internal/gkr/bw6-761/gkr_test.go +++ b/internal/gkr/bw6-761/gkr_test.go @@ -402,7 +402,7 @@ func generateTestVerifier(path string) func(t *testing.T) { func TestGkrVectors(t *testing.T) { - testDirPath := "" + const testDirPath = "../test_vectors/gkr" dirEntries, err := os.ReadDir(testDirPath) assert.NoError(t, err) for _, dirEntry := range dirEntries { diff --git a/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go index 52c2d0b89d..a264fd57d0 100644 --- a/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go +++ b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go @@ -78,7 +78,7 @@ func run(testCaseInfo *TestCaseInfo) error { func GenerateVectors() error { // read the test vectors file, generate the proof, make sure it verifies, // and add the proof to the same file - const relPath = "sumcheck/test_vectors/vectors.json" + const relPath = "../../gkr/test_vectors/sumcheck/vectors.json" var filename string var err error From 338df836b8861654c93c578026d7d92f8eaf9aae Mon Sep 17 00:00:00 2001 From: Tabaie Date: Thu, 3 Apr 2025 20:35:20 -0500 Subject: [PATCH 43/62] fix gkr imports in constraint --- constraint/bls12-377/gkr.go | 2 +- constraint/bls12-381/gkr.go | 2 +- constraint/bls24-315/gkr.go | 2 +- constraint/bls24-317/gkr.go | 2 +- constraint/bn254/gkr.go | 2 +- constraint/bw6-633/gkr.go | 2 +- constraint/bw6-761/gkr.go | 2 +- internal/generator/backend/template/imports.go.tmpl | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/constraint/bls12-377/gkr.go b/constraint/bls12-377/gkr.go index 744f22525c..63c04dce94 100644 --- a/constraint/bls12-377/gkr.go +++ b/constraint/bls12-377/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bls12-377" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" diff --git a/constraint/bls12-381/gkr.go b/constraint/bls12-381/gkr.go index b3b22b9a95..2a516d1805 100644 --- a/constraint/bls12-381/gkr.go +++ b/constraint/bls12-381/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bls12-381" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" diff --git a/constraint/bls24-315/gkr.go b/constraint/bls24-315/gkr.go index ba328c8bb4..a2b7297257 100644 --- a/constraint/bls24-315/gkr.go +++ b/constraint/bls24-315/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bls24-315" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" diff --git a/constraint/bls24-317/gkr.go b/constraint/bls24-317/gkr.go index be02e3455c..9c269ff686 100644 --- a/constraint/bls24-317/gkr.go +++ b/constraint/bls24-317/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bls24-317" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" diff --git a/constraint/bn254/gkr.go b/constraint/bn254/gkr.go index 21731b8ac9..2af2d3f035 100644 --- a/constraint/bn254/gkr.go +++ b/constraint/bn254/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bn254" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" diff --git a/constraint/bw6-633/gkr.go b/constraint/bw6-633/gkr.go index 125da817df..fb693f4851 100644 --- a/constraint/bw6-633/gkr.go +++ b/constraint/bw6-633/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bw6-633" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" diff --git a/constraint/bw6-761/gkr.go b/constraint/bw6-761/gkr.go index f40856cc36..72c07a3774 100644 --- a/constraint/bw6-761/gkr.go +++ b/constraint/bw6-761/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bw6-761" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" diff --git a/internal/generator/backend/template/imports.go.tmpl b/internal/generator/backend/template/imports.go.tmpl index c1cac6c90a..ce8ee23954 100644 --- a/internal/generator/backend/template/imports.go.tmpl +++ b/internal/generator/backend/template/imports.go.tmpl @@ -68,7 +68,7 @@ {{- end}} {{- define "import_gkr"}} - "github.com/consensys/gnark-crypto/ecc/{{ toLower .Curve }}/fr/gkr" + gkr "github.com/consensys/gnark/internal/gkr/{{ toLower .Curve }}" {{- end}} {{- define "import_hash_to_field" }} From 4bca28854ebe23a441f46ad0d614041b013a4423 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Fri, 4 Apr 2025 13:11:21 -0500 Subject: [PATCH 44/62] fix: remove wrongly generated files --- gkr.go | 867 -------------------------- gkr_test.go | 829 ------------------------ internal/gkr/gkr.go | 867 -------------------------- internal/gkr/gkr_test.go | 829 ------------------------ std/gkr/bn254_wrapper_api.go | 2 +- std/gkr/internal/bn254_wrapper_api.go | 206 ------ std/gkr/testing.go | 14 +- 7 files changed, 8 insertions(+), 3606 deletions(-) delete mode 100644 gkr.go delete mode 100644 gkr_test.go delete mode 100644 internal/gkr/gkr.go delete mode 100644 internal/gkr/gkr_test.go delete mode 100644 std/gkr/internal/bn254_wrapper_api.go diff --git a/gkr.go b/gkr.go deleted file mode 100644 index 70913dd297..0000000000 --- a/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark//sumcheck" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...small_rational.SmallRational) small_rational.SmallRational - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]small_rational.SmallRational - claimedEvaluations []small_rational.SmallRational - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a small_rational.SmallRational) small_rational.SmallRational { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]small_rational.SmallRational) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation small_rational.SmallRational - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]small_rational.SmallRational // x in the paper - claimedEvaluations []small_rational.SmallRational // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational.SmallRational) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]small_rational.SmallRational, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step small_rational.SmallRational - - res := make([]small_rational.SmallRational, degGJ) - operands := make([]small_rational.SmallRational, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element small_rational.SmallRational) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) interface{} { - - //defer the proof, return list of claims - evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]small_rational.SmallRational, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []small_rational.SmallRational, evaluation small_rational.SmallRational) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_rational.SmallRational, error) { - res := make([]small_rational.SmallRational, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []small_rational.SmallRational - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []small_rational.SmallRational{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]small_rational.SmallRational) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []small_rational.SmallRational - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]small_rational.SmallRational) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// TopologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func TopologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := TopologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]small_rational.SmallRational, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]small_rational.SmallRational, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]small_rational.SmallRational) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/gkr_test.go b/gkr_test.go deleted file mode 100644 index 31bd52133a..0000000000 --- a/gkr_test.go +++ /dev/null @@ -1,829 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/mimc" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "github.com/consensys/gnark/internal/small_rational/sumcheck" - "github.com/consensys/gnark/internal/small_rational/test_vector_utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []small_rational.SmallRational{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []small_rational.SmallRational{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []small_rational.SmallRational{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]small_rational.SmallRational) { - return func(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []small_rational.SmallRational{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []small_rational.SmallRational{three}, five) - manager.add(wire, []small_rational.SmallRational{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six small_rational.SmallRational - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]small_rational.SmallRational)) { - fullAssignments := make([][]small_rational.SmallRational, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]small_rational.SmallRational, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []small_rational.SmallRational) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := TopologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := TopologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := TopologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { - var sum small_rational.SmallRational - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []small_rational.SmallRational(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []small_rational.SmallRational - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/internal/gkr/gkr.go b/internal/gkr/gkr.go deleted file mode 100644 index 70913dd297..0000000000 --- a/internal/gkr/gkr.go +++ /dev/null @@ -1,867 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/internal/parallel" - "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark//sumcheck" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "math/big" - "strconv" - "sync" -) - -// The goal is to prove/verify evaluations of many instances of the same circuit - -// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. -type GateFunction func(...small_rational.SmallRational) small_rational.SmallRational - -// A Gate is a low-degree multivariate polynomial -type Gate struct { - Evaluate GateFunction // Evaluate the polynomial function defining the gate - nbIn int // number of inputs - degree int // total degree of f - solvableVar int // if there is a solvable variable, its index, -1 otherwise -} - -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 -func (g *Gate) Degree() int { - return g.degree -} - -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. -func (g *Gate) SolvableVar() int { - return g.solvableVar -} - -// NbIn returns the number of inputs to the gate (its fan-in) -func (g *Gate) NbIn() int { - return g.nbIn -} - -type Wire struct { - Gate *Gate - Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire - nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) -} - -type Circuit []Wire - -func (w Wire) IsInput() bool { - return len(w.Inputs) == 0 -} - -func (w Wire) IsOutput() bool { - return w.nbUniqueOutputs == 0 -} - -func (w Wire) NbClaims() int { - if w.IsOutput() { - return 1 - } - return w.nbUniqueOutputs -} - -func (w Wire) noProof() bool { - return w.IsInput() && w.NbClaims() == 1 -} - -func (c Circuit) maxGateDegree() int { - res := 1 - for i := range c { - if !c[i].IsInput() { - res = max(res, c[i].Gate.Degree()) - } - } - return res -} - -// WireAssignment is assignment of values to the same wire across many instances of the circuit -type WireAssignment map[*Wire]polynomial.MultiLin - -type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) - -type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]small_rational.SmallRational - claimedEvaluations []small_rational.SmallRational - manager *claimsManager // WARNING: Circular references -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { - return len(e.evaluationPoints) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { - return len(e.evaluationPoints[0]) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a small_rational.SmallRational) small_rational.SmallRational { - evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) - return evalsAsPoly.Eval(&a) -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { - return 1 + e.wire.Gate.Degree() -} - -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]small_rational.SmallRational) - - // the eq terms - numClaims := len(e.evaluationPoints) - evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) - for i := numClaims - 2; i >= 0; i-- { - evaluation.Mul(&evaluation, &combinationCoeff) - eq := polynomial.EvalEq(e.evaluationPoints[i], r) - evaluation.Add(&evaluation, &eq) - } - - // the g(...) term - var gateEvaluation small_rational.SmallRational - if e.wire.IsInput() { - gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { - inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) - - proofI := 0 - for inI, in := range e.wire.Inputs { - indexInProof, found := indexesInProof[in] - if !found { - indexInProof = proofI - indexesInProof[in] = indexInProof - - // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) - proofI++ - } - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] - } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) - } - gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) - } - - evaluation.Mul(&evaluation, &gateEvaluation) - - if evaluation.Equal(&purportedValue) { - return nil - } - return errors.New("incompatible evaluations") -} - -type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]small_rational.SmallRational // x in the paper - claimedEvaluations []small_rational.SmallRational // y in the paper - manager *claimsManager - - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations - - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) -} - -func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational.SmallRational) polynomial.Polynomial { - varsNum := c.VarsNum() - eqLength := 1 << varsNum - claimsNum := c.ClaimsNum() - // initialize the eq tables - c.eq = c.manager.memPool.Make(eqLength) - - c.eq[0].SetOne() - c.eq.Eq(c.evaluationPoints[0]) - - newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) - aI := combinationCoeff - - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points - newEq[0].Set(&aI) - - c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - - if k+1 < claimsNum { - aI.Mul(&aI, &combinationCoeff) - } - } - - c.manager.memPool.Dump(newEq) - - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - - return c.computeGJ() -} - -// eqAcc sets m to an eq table at q and then adds it to e -func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { - n := len(q) - - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) - for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ - // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ - const threshold = 1 << 6 - k := 1 << i - if k < threshold { - for j := 0; j < k; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - } else { - c.manager.workers.Submit(k, func(start, end int) { - for j := start; j < end; j++ { - j0 := j << (n - i) // bᵢ₊₁ = 0 - j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) - } - }, 1024).Wait() - } - - } - c.manager.workers.Submit(len(e), func(start, end int) { - for i := start; i < end; i++ { - e[i].Add(&e[i], &m[i]) - } - }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) -} - -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. -func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) - nbGateIn := len(c.inputPreprocessors) - - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) - - // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 - - gJ := make([]small_rational.SmallRational, degGJ) - var mu sync.Mutex - computeAll := func(start, end int) { - var step small_rational.SmallRational - - res := make([]small_rational.SmallRational, degGJ) - operands := make([]small_rational.SmallRational, degGJ*nbInner) - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) - for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) - } - } - - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner - } - } - mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) - } - mu.Unlock() - } - - const minBlockSize = 64 - - if nbOuter < minBlockSize { - // no parallelization - computeAll(0, nbOuter) - } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() - } - - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - - return gJ -} - -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j -func (c *eqTimesGateEvalSumcheckClaims) Next(element small_rational.SmallRational) polynomial.Polynomial { - const minBlockSize = 512 - n := len(c.eq) / 2 - if n < minBlockSize { - // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) - } - c.eq.Fold(element) - } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) - } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() - for _, wg := range wgs { - wg.Wait() - } - } - - return c.computeGJ() -} - -func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { - return len(c.evaluationPoints[0]) -} - -func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { - return len(c.claimedEvaluations) -} - -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) interface{} { - - //defer the proof, return list of claims - evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) - noMoreClaimsAllowed[c.wire] = struct{}{} - - for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) - } - c.manager.memPool.Dump(puI) - } - - c.manager.memPool.Dump(c.claimedEvaluations, c.eq) - - return evaluations -} - -type claimsManager struct { - claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims - assignment WireAssignment - memPool *polynomial.Pool - workers *utils.WorkerPool -} - -func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { - claims.assignment = assignment - claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) - claims.memPool = o.pool - claims.workers = o.workers - - for i := range c { - wire := &c[i] - - claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ - wire: wire, - evaluationPoints: make([][]small_rational.SmallRational, 0, wire.NbClaims()), - claimedEvaluations: claims.memPool.Make(wire.NbClaims()), - manager: &claims, - } - } - return -} - -func (m *claimsManager) add(wire *Wire, evaluationPoint []small_rational.SmallRational, evaluation small_rational.SmallRational) { - claim := m.claimsMap[wire] - i := len(claim.evaluationPoints) - claim.claimedEvaluations[i] = evaluation - claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) -} - -func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { - return m.claimsMap[wire] -} - -func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { - lazy := m.claimsMap[wire] - res := &eqTimesGateEvalSumcheckClaims{ - wire: wire, - evaluationPoints: lazy.evaluationPoints, - claimedEvaluations: lazy.claimedEvaluations, - manager: m, - } - - if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} - } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) - - for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied - } - } - return res -} - -func (m *claimsManager) deleteClaim(wire *Wire) { - delete(m.claimsMap, wire) -} - -type settings struct { - pool *polynomial.Pool - sorted []*Wire - transcript *fiatshamir.Transcript - transcriptPrefix string - nbVars int - workers *utils.WorkerPool -} - -type Option func(*settings) - -func WithPool(pool *polynomial.Pool) Option { - return func(options *settings) { - options.pool = pool - } -} - -func WithSortedCircuit(sorted []*Wire) Option { - return func(options *settings) { - options.sorted = sorted - } -} - -func WithWorkers(workers *utils.WorkerPool) Option { - return func(options *settings) { - options.workers = workers - } -} - -// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement -func (c Circuit) MemoryRequirements(nbInstances int) []int { - res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} - - if res[0] > res[1] { // make sure it's sorted - res[0], res[1] = res[1], res[0] - if res[1] > res[2] { - res[1], res[2] = res[2], res[1] - } - } - - return res -} - -func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { - var o settings - var err error - for _, option := range options { - option(&o) - } - - o.nbVars = assignment.NumVars() - nbInstances := assignment.NumInstances() - if 1< 1 { //combine the claims - size++ - } - size += logNbInstances // full run of sumcheck on logNbInstances variables - } - - nums := make([]string, max(len(sorted), logNbInstances)) - for i := range nums { - nums[i] = strconv.Itoa(i) - } - - challenges := make([]string, size) - - // output wire claims - firstChallengePrefix := prefix + "fC." - for j := 0; j < logNbInstances; j++ { - challenges[j] = firstChallengePrefix + nums[j] - } - j := logNbInstances - for i := len(sorted) - 1; i >= 0; i-- { - if sorted[i].noProof() { - continue - } - wirePrefix := prefix + "w" + nums[i] + "." - - if sorted[i].NbClaims() > 1 { - challenges[j] = wirePrefix + "comb" - j++ - } - - partialSumPrefix := wirePrefix + "pSP." - for k := 0; k < logNbInstances; k++ { - challenges[j] = partialSumPrefix + nums[k] - j++ - } - } - return challenges -} - -func getFirstChallengeNames(logNbInstances int, prefix string) []string { - res := make([]string, logNbInstances) - firstChallengePrefix := prefix + "fC." - for i := 0; i < logNbInstances; i++ { - res[i] = firstChallengePrefix + strconv.Itoa(i) - } - return res -} - -func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_rational.SmallRational, error) { - res := make([]small_rational.SmallRational, len(names)) - for i, name := range names { - if bytes, err := transcript.ComputeChallenge(name); err == nil { - res[i].SetBytes(bytes) - } else { - return nil, err - } - } - return res, nil -} - -// Prove consistency of the claimed assignment -func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return nil, err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - proof := make(Proof, len(c)) - // firstChallenge called rho in the paper - var firstChallenge []small_rational.SmallRational - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return nil, err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - claim := claims.getClaim(wire) - if wire.noProof() { // input wires with one claim only - proof[i] = sumcheck.Proof{ - PartialSumPolys: []polynomial.Polynomial{}, - FinalEvalProof: []small_rational.SmallRational{}, - } - } else { - if proof[i], err = sumcheck.Prove( - claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err != nil { - return proof, err - } - - finalEvalProof := proof[i].FinalEvalProof.([]small_rational.SmallRational) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } - // the verifier checks a single claim about input wires itself - claims.deleteClaim(wire) - } - - return proof, nil -} - -// Verify the consistency of the claimed output with the claimed input -// Unlike in Prove, the assignment argument need not be complete -func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { - o, err := setup(c, assignment, transcriptSettings, options...) - if err != nil { - return err - } - defer o.workers.Stop() - - claims := newClaimsManager(c, assignment, o) - - var firstChallenge []small_rational.SmallRational - firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) - if err != nil { - return err - } - - wirePrefix := o.transcriptPrefix + "w" - var baseChallenge [][]byte - for i := len(c) - 1; i >= 0; i-- { - wire := o.sorted[i] - - if wire.IsOutput() { - claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) - } - - proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]small_rational.SmallRational) - claim := claims.getLazyClaim(wire) - if wire.noProof() { // input wires with one claim only - // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { - return errors.New("no proof allowed for input wire with a single claim") - } - - if wire.NbClaims() == 1 { // input wire - // simply evaluate and see if it matches - evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) - if !claim.claimedEvaluations[0].Equal(&evaluation) { - return errors.New("incorrect input wire claim") - } - } - } else if err = sumcheck.Verify( - claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] - } - } else { - return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? - } - claims.deleteClaim(wire) - } - return nil -} - -// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. -func outputsList(c Circuit, indexes map[*Wire]int) [][]int { - idGate := GetGate("identity") - res := make([][]int, len(c)) - for i := range c { - res[i] = make([]int, 0) - c[i].nbUniqueOutputs = 0 - if c[i].IsInput() { - c[i].Gate = idGate - } - } - ins := make(map[int]struct{}, len(c)) - for i := range c { - for k := range ins { // clear map - delete(ins, k) - } - for _, in := range c[i].Inputs { - inI := indexes[in] - res[inI] = append(res[inI], i) - if _, ok := ins[inI]; !ok { - in.nbUniqueOutputs++ - ins[inI] = struct{}{} - } - } - } - return res -} - -type topSortData struct { - outputs [][]int - status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done - index map[*Wire]int - leastReady int -} - -func (d *topSortData) markDone(i int) { - - d.status[i] = -1 - - for _, outI := range d.outputs[i] { - d.status[outI]-- - if d.status[outI] == 0 && outI < d.leastReady { - d.leastReady = outI - } - } - - for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { - d.leastReady++ - } -} - -func indexMap(c Circuit) map[*Wire]int { - res := make(map[*Wire]int, len(c)) - for i := range c { - res[&c[i]] = i - } - return res -} - -func statusList(c Circuit) []int { - res := make([]int, len(c)) - for i := range c { - res[i] = len(c[i].Inputs) - } - return res -} - -// TopologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on -// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. -// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. -// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. -// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input -func TopologicalSort(c Circuit) []*Wire { - var data topSortData - data.index = indexMap(c) - data.outputs = outputsList(c, data.index) - data.status = statusList(c) - sorted := make([]*Wire, len(c)) - - for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { - } - - for i := range c { - sorted[i] = &c[data.leastReady] - data.markDone(data.leastReady) - } - - return sorted -} - -// Complete the circuit evaluation from input values -func (a WireAssignment) Complete(c Circuit) WireAssignment { - - sortedWires := TopologicalSort(c) - nbInstances := a.NumInstances() - maxNbIns := 0 - - for _, w := range sortedWires { - maxNbIns = max(maxNbIns, len(w.Inputs)) - if a[w] == nil { - a[w] = make([]small_rational.SmallRational, nbInstances) - } - } - - parallel.Execute(nbInstances, func(start, end int) { - ins := make([]small_rational.SmallRational, maxNbIns) - for i := start; i < end; i++ { - for _, w := range sortedWires { - if !w.IsInput() { - for inI, in := range w.Inputs { - ins[inI] = a[in][i] - } - a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) - } - } - } - }) - - return a -} - -func (a WireAssignment) NumInstances() int { - for _, aW := range a { - return len(aW) - } - panic("empty assignment") -} - -func (a WireAssignment) NumVars() int { - for _, aW := range a { - return aW.NumVars() - } - panic("empty assignment") -} - -// SerializeToBigInts flattens a proof object into the given slice of big.Ints -// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this -func (p Proof) SerializeToBigInts(outs []*big.Int) { - offset := 0 - for i := range p { - for _, poly := range p[i].PartialSumPolys { - frToBigInts(outs[offset:], poly) - offset += len(poly) - } - if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]small_rational.SmallRational) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) - } - } -} - -func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { - for i := range src { - src[i].BigInt(dst[i]) - } -} diff --git a/internal/gkr/gkr_test.go b/internal/gkr/gkr_test.go deleted file mode 100644 index 31bd52133a..0000000000 --- a/internal/gkr/gkr_test.go +++ /dev/null @@ -1,829 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package gkr - -import ( - "encoding/json" - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark-crypto/utils" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/mimc" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "github.com/consensys/gnark/internal/small_rational/sumcheck" - "github.com/consensys/gnark/internal/small_rational/test_vector_utils" - "github.com/stretchr/testify/assert" - "hash" - "os" - "path/filepath" - "reflect" - "strconv" - "testing" - "time" -) - -func TestNoGateTwoInstances(t *testing.T) { - // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case - testNoGate(t, []small_rational.SmallRational{four, three}) -} - -func TestNoGate(t *testing.T) { - testManyInstances(t, 1, testNoGate) -} - -func TestSingleAddGateTwoInstances(t *testing.T) { - testSingleAddGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) -} - -func TestSingleAddGate(t *testing.T) { - testManyInstances(t, 2, testSingleAddGate) -} - -func TestSingleMulGateTwoInstances(t *testing.T) { - testSingleMulGate(t, []small_rational.SmallRational{four, three}, []small_rational.SmallRational{two, three}) -} - -func TestSingleMulGate(t *testing.T) { - testManyInstances(t, 2, testSingleMulGate) -} - -func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { - - testSingleInputTwoIdentityGates(t, []small_rational.SmallRational{two, three}) -} - -func TestSingleInputTwoIdentityGates(t *testing.T) { - - testManyInstances(t, 2, testSingleInputTwoIdentityGates) -} - -func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { - testSingleInputTwoIdentityGatesComposed(t, []small_rational.SmallRational{two, one}) -} - -func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { - testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) -} - -func TestSingleMimcCipherGateTwoInstances(t *testing.T) { - testSingleMimcCipherGate(t, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestSingleMimcCipherGate(t *testing.T) { - testManyInstances(t, 2, testSingleMimcCipherGate) -} - -func TestATimesBSquaredTwoInstances(t *testing.T) { - testATimesBSquared(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestShallowMimcTwoInstances(t *testing.T) { - testMimc(t, 2, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestMimcTwoInstances(t *testing.T) { - testMimc(t, 93, []small_rational.SmallRational{one, one}, []small_rational.SmallRational{one, two}) -} - -func TestMimc(t *testing.T) { - testManyInstances(t, 2, generateTestMimc(93)) -} - -func generateTestMimc(numRounds int) func(*testing.T, ...[]small_rational.SmallRational) { - return func(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - testMimc(t, numRounds, inputAssignments...) - } -} - -func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { - circuit := Circuit{Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{}, - nbUniqueOutputs: 2, - }} - - wire := &circuit[0] - - assignment := WireAssignment{&circuit[0]: []small_rational.SmallRational{two, three}} - var o settings - pool := polynomial.NewPool(256, 1<<11) - workers := utils.NewWorkerPool() - o.pool = &pool - o.workers = workers - - claimsManagerGen := func() *claimsManager { - manager := newClaimsManager(circuit, assignment, o) - manager.add(wire, []small_rational.SmallRational{three}, five) - manager.add(wire, []small_rational.SmallRational{four}, six) - return &manager - } - - transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) - - proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) - err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) - assert.NoError(t, err) -} - -var one, two, three, four, five, six small_rational.SmallRational - -func init() { - one.SetOne() - two.Double(&one) - three.Add(&two, &one) - four.Double(&two) - five.Add(&three, &two) - six.Double(&three) -} - -var testManyInstancesLogMaxInstances = -1 - -func getLogMaxInstances(t *testing.T) int { - if testManyInstancesLogMaxInstances == -1 { - - s := os.Getenv("GKR_LOG_INSTANCES") - if s == "" { - testManyInstancesLogMaxInstances = 5 - } else { - var err error - testManyInstancesLogMaxInstances, err = strconv.Atoi(s) - if err != nil { - t.Error(err) - } - } - - } - return testManyInstancesLogMaxInstances -} - -func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]small_rational.SmallRational)) { - fullAssignments := make([][]small_rational.SmallRational, numInput) - maxSize := 1 << getLogMaxInstances(t) - - t.Log("Entered test orchestrator, assigning and randomizing inputs") - - for i := range fullAssignments { - fullAssignments[i] = make([]fr.Element, maxSize) - setRandomSlice(fullAssignments[i]) - } - - inputAssignments := make([][]small_rational.SmallRational, numInput) - for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { - for i, fullAssignment := range fullAssignments { - inputAssignments[i] = fullAssignment[:numEvals] - } - - t.Log("Selected inputs for test") - test(t, inputAssignments...) - } -} - -func testNoGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := Circuit{ - { - Inputs: []*Wire{}, - Gate: nil, - }, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]} - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - // Even though a hash is called here, the proof is empty - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") -} - -func testSingleAddGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Add2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMulGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - - c := make(Circuit, 3) - c[2] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[0], &c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - - c[2] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[0], &c[1]}, - } - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - t.Log("Proof complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification complete") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification complete") -} - -func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { - c := make(Circuit, 3) - - c[1] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[0]}, - } - c[2] = Wire{ - Gate: GetGate(Identity), - Inputs: []*Wire{&c[1]}, - } - - assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func mimcCircuit(numRounds int) Circuit { - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate("mimc"), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - return c -} - -func testMimc(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { - //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) - // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 - - c := mimcCircuit(numRounds) - - t.Log("Evaluating all circuit wires") - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - t.Log("Circuit evaluation complete") - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - t.Log("Proof finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - t.Log("Successful verification finished") - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") - t.Log("Unsuccessful verification finished") -} - -func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]small_rational.SmallRational) { - // This imitates the MiMC circuit - - c := make(Circuit, numRounds+2) - - for i := 2; i < len(c); i++ { - c[i] = Wire{ - Gate: GetGate(Mul2), - Inputs: []*Wire{&c[i-1], &c[0]}, - } - } - - assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) - - proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err) - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) - assert.NoError(t, err, "proof rejected") - - err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) - assert.NotNil(t, err, "bad proof accepted") -} - -func setRandomSlice(slice []small_rational.SmallRational) { - for i := range slice { - slice[i].MustSetRandom() - } -} - -func generateTestProver(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err) - assert.NoError(t, proofEquals(testCase.Proof, proof)) - } -} - -func generateTestVerifier(path string) func(t *testing.T) { - return func(t *testing.T) { - testCase, err := newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) - assert.NoError(t, err, "proof rejected") - testCase, err = newTestCase(path) - assert.NoError(t, err) - err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) - assert.NotNil(t, err, "bad proof accepted") - } -} - -func TestGkrVectors(t *testing.T) { - - testDirPath := "" - dirEntries, err := os.ReadDir(testDirPath) - assert.NoError(t, err) - for _, dirEntry := range dirEntries { - if !dirEntry.IsDir() { - - if filepath.Ext(dirEntry.Name()) == ".json" { - path := filepath.Join(testDirPath, dirEntry.Name()) - noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] - - t.Run(noExt+"_prover", generateTestProver(path)) - t.Run(noExt+"_verifier", generateTestVerifier(path)) - - } - } - } -} - -func proofEquals(expected Proof, seen Proof) error { - if len(expected) != len(seen) { - return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) - } - for i, x := range expected { - xSeen := seen[i] - - if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { - return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) - } - } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { - return fmt.Errorf("final evaluation proof mismatch") - } - } - if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { - return err - } - } - return nil -} - -func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { - fmt.Println("creating circuit structure") - c := mimcCircuit(mimcDepth) - - in0 := make([]fr.Element, nbInstances) - in1 := make([]fr.Element, nbInstances) - setRandomSlice(in0) - setRandomSlice(in1) - - fmt.Println("evaluating circuit") - start := time.Now().UnixMicro() - assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) - solved := time.Now().UnixMicro() - start - fmt.Println("solved in", solved, "μs") - - //b.ResetTimer() - fmt.Println("constructing proof") - start = time.Now().UnixMicro() - _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) - proved := time.Now().UnixMicro() - start - fmt.Println("proved in", proved, "μs") - assert.NoError(b, err) -} - -func BenchmarkGkrMimc19(b *testing.B) { - benchmarkGkrMiMC(b, 1<<19, 91) -} - -func BenchmarkGkrMimc17(b *testing.B) { - benchmarkGkrMiMC(b, 1<<17, 91) -} - -func TestTopSortTrivial(t *testing.T) { - c := make(Circuit, 2) - c[0].Inputs = []*Wire{&c[1]} - sorted := TopologicalSort(c) - assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) -} - -func TestTopSortDeep(t *testing.T) { - c := make(Circuit, 4) - c[0].Inputs = []*Wire{&c[2]} - c[1].Inputs = []*Wire{&c[3]} - c[2].Inputs = []*Wire{} - c[3].Inputs = []*Wire{&c[0]} - sorted := TopologicalSort(c) - assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) -} - -func TestTopSortWide(t *testing.T) { - c := make(Circuit, 10) - c[0].Inputs = []*Wire{&c[3], &c[8]} - c[1].Inputs = []*Wire{&c[6]} - c[2].Inputs = []*Wire{&c[4]} - c[3].Inputs = []*Wire{} - c[4].Inputs = []*Wire{} - c[5].Inputs = []*Wire{&c[9]} - c[6].Inputs = []*Wire{&c[9]} - c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} - c[8].Inputs = []*Wire{&c[4], &c[3]} - c[9].Inputs = []*Wire{} - - sorted := TopologicalSort(c) - sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} - - assert.Equal(t, sortedExpected, sorted) -} - -type WireInfo struct { - Gate GateName `json:"gate"` - Inputs []int `json:"inputs"` -} - -type CircuitInfo []WireInfo - -var circuitCache = make(map[string]Circuit) - -func getCircuit(path string) (Circuit, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - if circuit, ok := circuitCache[path]; ok { - return circuit, nil - } - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var circuitInfo CircuitInfo - if err = json.Unmarshal(bytes, &circuitInfo); err == nil { - circuit := circuitInfo.toCircuit() - circuitCache[path] = circuit - return circuit, nil - } else { - return nil, err - } - } else { - return nil, err - } -} - -func (c CircuitInfo) toCircuit() (circuit Circuit) { - circuit = make(Circuit, len(c)) - for i := range c { - circuit[i].Gate = GetGate(c[i].Gate) - circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) - for k, inputCoord := range c[i].Inputs { - input := &circuit[inputCoord] - circuit[i].Inputs[k] = input - } - } - return -} - -func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { - var sum small_rational.SmallRational - - sum. - Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark - res.Square(&sum) // sum^2 - res.Mul(&res, &sum) // sum^3 - res.Square(&res) //sum^6 - res.Mul(&res, &sum) //sum^7 - - return -} - -const ( - MiMC GateName = "mimc" - SelectInput3 GateName = "select-input-3" -) - -func init() { - if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { - panic(err) - } - - if err := RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { - return input[2] - }, 3, WithUnverifiedDegree(1)); err != nil { - panic(err) - } -} - -type PrintableProof []PrintableSumcheckProof - -type PrintableSumcheckProof struct { - FinalEvalProof interface{} `json:"finalEvalProof"` - PartialSumPolys [][]interface{} `json:"partialSumPolys"` -} - -func unmarshalProof(printable PrintableProof) (Proof, error) { - proof := make(Proof, len(printable)) - for i := range printable { - finalEvalProof := []small_rational.SmallRational(nil) - - if printable[i].FinalEvalProof != nil { - finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) - finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) - for k := range finalEvalProof { - if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { - return nil, err - } - } - } - - proof[i] = sumcheck.Proof{ - PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), - FinalEvalProof: finalEvalProof, - } - for k := range printable[i].PartialSumPolys { - var err error - if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { - return nil, err - } - } - } - return proof, nil -} - -type TestCase struct { - Circuit Circuit - Hash hash.Hash - Proof Proof - FullAssignment WireAssignment - InOutAssignment WireAssignment -} - -type TestCaseInfo struct { - Hash test_vector_utils.HashDescription `json:"hash"` - Circuit string `json:"circuit"` - Input [][]interface{} `json:"input"` - Output [][]interface{} `json:"output"` - Proof PrintableProof `json:"proof"` -} - -var testCases = make(map[string]*TestCase) - -func newTestCase(path string) (*TestCase, error) { - path, err := filepath.Abs(path) - if err != nil { - return nil, err - } - dir := filepath.Dir(path) - - tCase, ok := testCases[path] - if !ok { - var bytes []byte - if bytes, err = os.ReadFile(path); err == nil { - var info TestCaseInfo - err = json.Unmarshal(bytes, &info) - if err != nil { - return nil, err - } - - var circuit Circuit - if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { - return nil, err - } - var _hash hash.Hash - if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { - return nil, err - } - var proof Proof - if proof, err = unmarshalProof(info.Proof); err != nil { - return nil, err - } - - fullAssignment := make(WireAssignment) - inOutAssignment := make(WireAssignment) - - sorted := topologicalSort(circuit) - - inI, outI := 0, 0 - for _, w := range sorted { - var assignmentRaw []interface{} - if w.IsInput() { - if inI == len(info.Input) { - return nil, fmt.Errorf("fewer input in vector than in circuit") - } - assignmentRaw = info.Input[inI] - inI++ - } else if w.IsOutput() { - if outI == len(info.Output) { - return nil, fmt.Errorf("fewer output in vector than in circuit") - } - assignmentRaw = info.Output[outI] - outI++ - } - if assignmentRaw != nil { - var wireAssignment []small_rational.SmallRational - if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { - return nil, err - } - - fullAssignment[w] = wireAssignment - inOutAssignment[w] = wireAssignment - } - } - - fullAssignment.Complete(circuit) - - for _, w := range sorted { - if w.IsOutput() { - - if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { - return nil, fmt.Errorf("assignment mismatch: %v", err) - } - - } - } - - tCase = &TestCase{ - FullAssignment: fullAssignment, - InOutAssignment: inOutAssignment, - Proof: proof, - Hash: _hash, - Circuit: circuit, - } - - testCases[path] = tCase - } else { - return nil, err - } - } - - return tCase, nil -} - -func TestRegisterGateDegreeDetection(t *testing.T) { - testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { - t.Run(string(name), func(t *testing.T) { - name = name + "-register-gate-test" - - assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") - - assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") - - assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") - - assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") - }) - } - - testGate("select", func(x ...fr.Element) fr.Element { - return x[0] - }, 3, 1) - - testGate("add2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Add(&x[0], &x[1]) - res.Add(&res, &x[2]) - return res - }, 3, 1) - - testGate("mul2", func(x ...fr.Element) fr.Element { - var res fr.Element - res.Mul(&x[0], &x[1]) - return res - }, 2, 2) - - testGate("mimc", mimcRound, 2, 7) - - testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { - var res fr.Element - res. - SetOne(). - Add(&res, &x[0]). - Sub(&res, &x[1]) - return res - }, 2, 1) - - // zero polynomial must not be accepted - t.Run("zero", func(t *testing.T) { - const gateName GateName = "zero-register-gate-test" - expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) - zeroGate := func(x ...fr.Element) fr.Element { - var res fr.Element - return res - } - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) - - assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) - }) -} - -func TestIsAdditive(t *testing.T) { - - // f: x,y -> x² + xy - f := func(x ...fr.Element) fr.Element { - if len(x) != 2 { - panic("bivariate input needed") - } - var res fr.Element - res.Add(&x[0], &x[1]) - res.Mul(&res, &x[0]) - return res - } - - // g: x,y -> x² + 3y - g := func(x ...fr.Element) fr.Element { - var res, y3 fr.Element - res.Square(&x[0]) - y3.Mul(&x[1], &three) - res.Add(&res, &y3) - return res - } - - // h: x -> 2x - // but it edits it input - h := func(x ...fr.Element) fr.Element { - x[0].Double(&x[0]) - return x[0] - } - - assert.False(t, GateFunction(f).isAdditive(1, 2)) - assert.False(t, GateFunction(f).isAdditive(0, 2)) - - assert.False(t, GateFunction(g).isAdditive(0, 2)) - assert.True(t, GateFunction(g).isAdditive(1, 2)) - - assert.True(t, GateFunction(h).isAdditive(0, 1)) -} diff --git a/std/gkr/bn254_wrapper_api.go b/std/gkr/bn254_wrapper_api.go index 0538b1d3e2..a45109006f 100644 --- a/std/gkr/bn254_wrapper_api.go +++ b/std/gkr/bn254_wrapper_api.go @@ -4,8 +4,8 @@ import ( "errors" "fmt" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/bn254" "github.com/consensys/gnark/internal/utils" ) diff --git a/std/gkr/internal/bn254_wrapper_api.go b/std/gkr/internal/bn254_wrapper_api.go deleted file mode 100644 index cb12a81b86..0000000000 --- a/std/gkr/internal/bn254_wrapper_api.go +++ /dev/null @@ -1,206 +0,0 @@ -package internal - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/internal/utils" - "math/big" -) - -// wrap BN254 scalar field arithmetic in a frontend.API -// bn254WrapperApi uses *fr.Element as its variable type -type bn254WrapperApi struct { - err error -} - -func ToBn254GateFunction(f func(frontend.API, ...frontend.Variable) frontend.Variable) gkr.GateFunction { - var wrapper bn254WrapperApi - - return func(x ...fr.Element) fr.Element { - res := f(&wrapper, utils.Map(x, func(x fr.Element) frontend.Variable { - return &x - })...).(*fr.Element) - if wrapper.err != nil { - panic(wrapper.err) - } - return *res - } -} - -func (w *bn254WrapperApi) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Add(w.cast(i1), w.cast(i2)) - for i := range in { - res.Add(&res, w.cast(in[i])) - } - - return &res -} - -func (w *bn254WrapperApi) MulAcc(a, b, c frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(w.cast(b), w.cast(c)) - res.Add(&res, w.cast(a)) - return &res -} - -func (w *bn254WrapperApi) Neg(i1 frontend.Variable) frontend.Variable { - var res fr.Element - res.Neg(w.cast(i1)) - return &res -} - -func (w *bn254WrapperApi) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Sub(w.cast(i1), w.cast(i2)) - for i := range in { - res.Sub(&res, w.cast(in[i])) - } - return &res -} - -func (w *bn254WrapperApi) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - var res fr.Element - res.Mul(w.cast(i1), w.cast(i2)) - for i := range in { - res.Mul(&res, w.cast(in[i])) - } - return &res -} - -func (w *bn254WrapperApi) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { - return w.Div(i1, i2) -} - -func (w *bn254WrapperApi) Div(i1, i2 frontend.Variable) frontend.Variable { - return w.Mul(i1, w.Inverse(i2)) -} - -func (w *bn254WrapperApi) Inverse(i1 frontend.Variable) frontend.Variable { - w.newError("only polynomial (ring) operations supported") - return nil -} - -func (w *bn254WrapperApi) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) FromBinary(b ...frontend.Variable) frontend.Variable { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) Xor(a, b frontend.Variable) frontend.Variable { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) Or(a, b frontend.Variable) frontend.Variable { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) And(a, b frontend.Variable) frontend.Variable { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) Select(frontend.Variable, frontend.Variable, frontend.Variable) frontend.Variable { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) Lookup2(frontend.Variable, frontend.Variable, frontend.Variable, frontend.Variable, frontend.Variable, frontend.Variable) frontend.Variable { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) IsZero(frontend.Variable) frontend.Variable { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) Cmp(frontend.Variable, frontend.Variable) frontend.Variable { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) AssertIsEqual(i1, i2 frontend.Variable) { - w.newError("only field operations supported") -} - -func (w *bn254WrapperApi) AssertIsDifferent(frontend.Variable, frontend.Variable) { - w.newError("only field operations supported") -} - -func (w *bn254WrapperApi) AssertIsBoolean(frontend.Variable) { - w.newError("only field operations supported") -} - -func (w *bn254WrapperApi) AssertIsCrumb(frontend.Variable) { - w.newError("only field operations supported") -} - -func (w *bn254WrapperApi) AssertIsLessOrEqual(frontend.Variable, frontend.Variable) { - w.newError("only field operations supported") -} - -func (w *bn254WrapperApi) Println(a ...frontend.Variable) { - toPrint := make([]any, len(a)) - for i, v := range a { - var x fr.Element - if _, err := x.SetInterface(v); err != nil { - if s, ok := v.(string); ok { - toPrint[i] = s - continue - } else { - w.newError("not numeric or string") - } - } else { - toPrint[i] = x.String() - } - } - fmt.Println(toPrint...) -} - -func (w *bn254WrapperApi) Compiler() frontend.Compiler { - w.newError("only field operations supported") - return nil -} - -func (w *bn254WrapperApi) NewHint(solver.Hint, int, ...frontend.Variable) ([]frontend.Variable, error) { - err := errors.New("only field operations supported") - w.emitError(err) - return nil, err -} - -func (w *bn254WrapperApi) ConstantValue(frontend.Variable) (*big.Int, bool) { - w.newError("only field operations supported") - return nil, false -} - -func (w *bn254WrapperApi) cast(v frontend.Variable) *fr.Element { - var res fr.Element - if w.err != nil { - return &res - } - if _, err := res.SetInterface(v); err != nil { - w.emitError(err) - } - return &res -} - -func (w *bn254WrapperApi) emitError(err error) { - if w.err == nil { - w.err = err - } -} - -func (w *bn254WrapperApi) newError(msg string) { - w.emitError(errors.New(msg)) -} diff --git a/std/gkr/testing.go b/std/gkr/testing.go index 74a8fc1f5c..dd99608b1f 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -9,21 +9,21 @@ import ( "github.com/consensys/gnark-crypto/ecc" frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - gkrBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr" frBls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - gkrBls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/gkr" frBls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - gkrBls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/gkr" frBls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - gkrBls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/gkr" frBn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" - gkrBn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" frBw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - gkrBw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/gkr" frBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - gkrBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + gkrBls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" + gkrBls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" + gkrBls24315 "github.com/consensys/gnark/internal/gkr/bls24-315" + gkrBls24317 "github.com/consensys/gnark/internal/gkr/bls24-317" + gkrBn254 "github.com/consensys/gnark/internal/gkr/bn254" + gkrBw6633 "github.com/consensys/gnark/internal/gkr/bw6-633" + gkrBw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" ) type solveInTestEngineSettings struct { From b2d06ba82aa64911b3a32baed28ad1f4e30b5aaa Mon Sep 17 00:00:00 2001 From: Tabaie Date: Fri, 4 Apr 2025 13:23:55 -0500 Subject: [PATCH 45/62] feat: fr side of gkr-poseidon2 --- .../poseidon2/{ => gkr-poseidon2}/gkr.go | 17 +- .../poseidon2/{ => gkr-poseidon2}/gkr_test.go | 2 +- .../gkr-poseidon2/internal/bls12-377/gates.go | 218 ++++++++++++++++++ .../gkr-poseidon2/internal/commons.go | 29 +++ 4 files changed, 257 insertions(+), 9 deletions(-) rename std/permutation/poseidon2/{ => gkr-poseidon2}/gkr.go (95%) rename std/permutation/poseidon2/{ => gkr-poseidon2}/gkr_test.go (98%) create mode 100644 std/permutation/poseidon2/gkr-poseidon2/internal/bls12-377/gates.go create mode 100644 std/permutation/poseidon2/gkr-poseidon2/internal/commons.go diff --git a/std/permutation/poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go similarity index 95% rename from std/permutation/poseidon2/gkr.go rename to std/permutation/poseidon2/gkr-poseidon2/gkr.go index 0498ffeee7..ea135fc2c7 100644 --- a/std/permutation/poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -1,8 +1,9 @@ -package poseidon2 +package gkr_poseidon2 import ( "errors" "fmt" + "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2/internal" "hash" "math/big" "sync" @@ -13,13 +14,13 @@ import ( frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" mimcBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" poseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" - gkrPoseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2/gkrgates" "github.com/consensys/gnark/constraint" csBls12377 "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/gkr" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" + gkrPoseidon2Bls12377 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2/internal/bls12-377" ) // extKeyGate applies the external matrix mul, then adds the round key @@ -157,7 +158,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // poseidon2 parameters roundKeysFr := poseidon2Bls12377.GetDefaultParameters().RoundKeys - gateNamer := gkrPoseidon2Bls12377.RoundGateNamer(poseidon2Bls12377.GetDefaultParameters()) + gateNamer := internal.RoundGateNamer[gkr.GateName](poseidon2Bls12377.GetDefaultParameters()) rF := poseidon2Bls12377.GetDefaultParameters().NbFullRounds rP := poseidon2Bls12377.GetDefaultParameters().NbPartialRounds halfRf := rF / 2 @@ -199,7 +200,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // register and apply external matrix multiplication and round key addition // round dependent due to the round key extKeySBox := func(round, varI int, a, b constraint.GkrVariable) constraint.GkrVariable { - gate := gkr.GateName(gateNamer.Linear(varI, round)) + gate := gateNamer.Linear(varI, round) if err = gkr.RegisterGate(gate, extKeyGate(frToInt(&roundKeysFr[round][varI])), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return -1 } @@ -211,7 +212,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // for the second variable // round independent due to the round key intKeySBox2 := func(round int, a, b constraint.GkrVariable) constraint.GkrVariable { - gate := gkr.GateName(gateNamer.Linear(yI, round)) + gate := gateNamer.Linear(yI, round) if err = gkr.RegisterGate(gate, intKeyGate2(frToInt(&roundKeysFr[round][1])), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return -1 } @@ -235,7 +236,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // still using the external matrix, since the linear operation still belongs to a full (canonical) round x1 := extKeySBox(halfRf, xI, x, y) - gate := gkr.GateName(gateNamer.Linear(yI, halfRf)) + gate := gateNamer.Linear(yI, halfRf) if err = gkr.RegisterGate(gate, extGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return nil, -1, err } @@ -246,7 +247,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. for i := halfRf + 1; i < halfRf+rP; i++ { x1 := extKeySBox(i, xI, x, y) // the first row of the internal matrix is the same as that of the external matrix - gate := gkr.GateName(gateNamer.Linear(yI, i)) + gate := gateNamer.Linear(yI, i) if err = gkr.RegisterGate(gate, intKeyGate2(zero), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return nil, -1, err } @@ -266,7 +267,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. } // apply the external matrix one last time to obtain the final value of y - gate := gkr.GateName(gateNamer.Linear(yI, rP+rF)) + gate := gateNamer.Linear(yI, rP+rF) if err = gkr.RegisterGate(gate, extAddGate, 3, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return nil, -1, err } diff --git a/std/permutation/poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go similarity index 98% rename from std/permutation/poseidon2/gkr_test.go rename to std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 96a9c2494c..3fd5c42a33 100644 --- a/std/permutation/poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -1,4 +1,4 @@ -package poseidon2 +package gkr_poseidon2 import ( "fmt" diff --git a/std/permutation/poseidon2/gkr-poseidon2/internal/bls12-377/gates.go b/std/permutation/poseidon2/gkr-poseidon2/internal/bls12-377/gates.go new file mode 100644 index 0000000000..67d45d704f --- /dev/null +++ b/std/permutation/poseidon2/gkr-poseidon2/internal/bls12-377/gates.go @@ -0,0 +1,218 @@ +package bls12_377 + +import ( + "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2/internal" + "sync" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + gkr "github.com/consensys/gnark/internal/gkr/bls12-377" +) + +// The GKR gates needed for proving Poseidon2 permutations + +// extKeySBoxGate applies the external matrix mul, then adds the round key, then applies the sBox +// because of its symmetry, we don't need to define distinct x1 and x2 versions of it +func extKeySBoxGate(roundKey *fr.Element) gkr.GateFunction { + return func(x ...fr.Element) fr.Element { + x[0]. + Double(&x[0]). + Add(&x[0], &x[1]). + Add(&x[0], roundKey) + return sBox2(x[0]) + } +} + +// intKeySBoxGate2 applies the second row of internal matrix mul, then adds the round key, then applies the sBox, returning the second element +func intKeySBoxGate2(roundKey *fr.Element) gkr.GateFunction { + return func(x ...fr.Element) fr.Element { + x[0].Add(&x[0], &x[1]) + x[1]. + Double(&x[1]). + Add(&x[1], &x[0]). + Add(&x[1], roundKey) + + return sBox2(x[1]) + } +} + +// extAddGate (x,y,z) -> Ext . (x,y) + z +func extAddGate(x ...fr.Element) fr.Element { + x[0]. + Double(&x[0]). + Add(&x[0], &x[1]). + Add(&x[0], &x[2]) + return x[0] +} + +// sBox2 is Permutation.sBox for t=2 +func sBox2(x fr.Element) fr.Element { + var y fr.Element + y.Square(&x).Square(&y).Square(&y).Square(&y).Mul(&x, &y) + return y +} + +// extKeyGate applies the external matrix mul, then adds the round key, then applies the sBox +// because of its symmetry, we don't need to define distinct x1 and x2 versions of it +func extKeyGate(roundKey *fr.Element) func(...fr.Element) fr.Element { + return func(x ...fr.Element) fr.Element { + x[0]. + Double(&x[0]). + Add(&x[0], &x[1]). + Add(&x[0], roundKey) + return x[0] + } +} + +// for x1, the partial round gates are identical to full round gates +// for x2, the partial round gates are just a linear combination + +// extGate2 applies the external matrix mul, outputting the second element of the result +func extGate2(x ...fr.Element) fr.Element { + x[1]. + Double(&x[1]). + Add(&x[1], &x[0]) + return x[1] +} + +// intGate2 applies the internal matrix mul, returning the second element +func intGate2(x ...fr.Element) fr.Element { + x[0].Add(&x[0], &x[1]) + x[1]. + Double(&x[1]). + Add(&x[1], &x[0]) + return x[1] +} + +// intKeyGate2 applies the second row of internal matrix mul, then adds the round key +func intKeyGate2(roundKey *fr.Element) gkr.GateFunction { + return func(x ...fr.Element) fr.Element { + x[0].Add(&x[0], &x[1]) + x[1]. + Double(&x[1]). + Add(&x[1], &x[0]). + Add(&x[1], roundKey) + + return x[1] + } +} + +// powGate4 x -> x⁴ +func pow4Gate(x ...fr.Element) fr.Element { + x[0].Square(&x[0]).Square(&x[0]) + return x[0] +} + +// pow4TimesGate x,y -> x⁴ * y +func pow4TimesGate(x ...fr.Element) fr.Element { + x[0].Square(&x[0]).Square(&x[0]).Mul(&x[0], &x[1]) + return x[0] +} + +// pow2Gate x -> x² +func pow2Gate(x ...fr.Element) fr.Element { + x[0].Square(&x[0]) + return x[0] +} + +// pow2TimesGate x,y -> x² * y +func pow2TimesGate(x ...fr.Element) fr.Element { + x[0].Square(&x[0]).Mul(&x[0], &x[1]) + return x[0] +} + +var initOnce sync.Once + +// RegisterGkrGates registers the Poseidon2 compression gates for GKR +func RegisterGkrGates() error { + const ( + x = iota + y + ) + var err error + initOnce.Do( + func() { + p := poseidon2.GetDefaultParameters() + halfRf := p.NbFullRounds / 2 + gateNames := internal.RoundGateNamer[gkr.GateName](p) + + if err = gkr.RegisterGate(internal.Pow2GateName, pow2Gate, 1, gkr.WithUnverifiedDegree(2), gkr.WithNoSolvableVar()); err != nil { + return + } + if err = gkr.RegisterGate(internal.Pow4GateName, pow4Gate, 1, gkr.WithUnverifiedDegree(4), gkr.WithNoSolvableVar()); err != nil { + return + } + if err = gkr.RegisterGate(internal.Pow2TimesGateName, pow2TimesGate, 2, gkr.WithUnverifiedDegree(3), gkr.WithNoSolvableVar()); err != nil { + return + } + if err = gkr.RegisterGate(internal.Pow4TimesGateName, pow4TimesGate, 2, gkr.WithUnverifiedDegree(5), gkr.WithNoSolvableVar()); err != nil { + return + } + + extKeySBox := func(round int, varIndex int) error { + if err := gkr.RegisterGate(gateNames.Integrated(varIndex, round), extKeySBoxGate(&p.RoundKeys[round][varIndex]), 2, gkr.WithUnverifiedDegree(poseidon2.DegreeSBox()), gkr.WithNoSolvableVar()); err != nil { + return err + } + + return gkr.RegisterGate(gateNames.Linear(varIndex, round), extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)) + } + + intKeySBox2 := func(round int) error { + if err := gkr.RegisterGate(gateNames.Linear(y, round), intKeyGate2(&p.RoundKeys[round][1]), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return err + } + return gkr.RegisterGate(gateNames.Integrated(y, round), intKeySBoxGate2(&p.RoundKeys[round][1]), 2, gkr.WithUnverifiedDegree(poseidon2.DegreeSBox()), gkr.WithNoSolvableVar()) + } + + fullRound := func(i int) error { + if err := extKeySBox(i, x); err != nil { + return err + } + return extKeySBox(i, y) + } + + for i := range halfRf { + if err = fullRound(i); err != nil { + return + } + } + + { // i = halfRf: first partial round + if err = extKeySBox(halfRf, x); err != nil { + return + } + if err = gkr.RegisterGate(gateNames.Linear(y, halfRf), extGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return + } + } + + for i := halfRf + 1; i < halfRf+p.NbPartialRounds; i++ { + if err = extKeySBox(i, x); err != nil { // for x1, intKeySBox is identical to extKeySBox + return + } + if err = gkr.RegisterGate(gateNames.Linear(y, i), intGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return + } + } + + { + i := halfRf + p.NbPartialRounds + if err = extKeySBox(i, x); err != nil { + return + } + if err = intKeySBox2(i); err != nil { + return + } + } + + for i := halfRf + p.NbPartialRounds + 1; i < p.NbPartialRounds+p.NbFullRounds; i++ { + if err = fullRound(i); err != nil { + return + } + } + + err = gkr.RegisterGate(gateNames.Linear(y, p.NbPartialRounds+p.NbFullRounds), extAddGate, 3, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)) + }, + ) + return err +} diff --git a/std/permutation/poseidon2/gkr-poseidon2/internal/commons.go b/std/permutation/poseidon2/gkr-poseidon2/internal/commons.go new file mode 100644 index 0000000000..34ad6316e6 --- /dev/null +++ b/std/permutation/poseidon2/gkr-poseidon2/internal/commons.go @@ -0,0 +1,29 @@ +package internal + +import ( + "fmt" +) + +const ( + Pow2GateName = "pow2" + Pow4GateName = "pow4" + Pow2TimesGateName = "pow2Times" + Pow4TimesGateName = "pow4Times" +) + +type roundGateNamer[T ~string] string + +// RoundGateNamer returns an object that returns standardized names for gates in the GKR circuit +func RoundGateNamer[T ~string](p fmt.Stringer) roundGateNamer[T] { + return roundGateNamer[T](p.String()) +} + +// Linear is the name of a gate where a polynomial of total degree 1 is applied to the input +func (n roundGateNamer[T]) Linear(varIndex, round int) T { + return T(fmt.Sprintf("x%d-l-op-round=%d;%s", varIndex, round, n)) +} + +// Integrated is the name of a gate where a polynomial of total degree 1 is applied to the input, followed by an S-box +func (n roundGateNamer[T]) Integrated(varIndex, round int) T { + return T(fmt.Sprintf("x%d-i-op-round=%d;%s", varIndex, round, n)) +} From 877856fae5c244e55c4a1f5625b64c7cc743787c Mon Sep 17 00:00:00 2001 From: Tabaie Date: Fri, 4 Apr 2025 13:31:35 -0500 Subject: [PATCH 46/62] fix small codegen issues --- internal/gkr/registry.go | 374 ------------------------- internal/gkr/sumcheck/sumcheck.go | 170 ----------- internal/gkr/sumcheck/sumcheck_test.go | 149 ---------- registry.go | 374 ------------------------- std/gkr/api_test.go | 2 +- std/gkr/example_test.go | 12 +- sumcheck/sumcheck.go | 170 ----------- sumcheck/sumcheck_test.go | 149 ---------- 8 files changed, 7 insertions(+), 1393 deletions(-) delete mode 100644 internal/gkr/registry.go delete mode 100644 internal/gkr/sumcheck/sumcheck.go delete mode 100644 internal/gkr/sumcheck/sumcheck_test.go delete mode 100644 registry.go delete mode 100644 sumcheck/sumcheck.go delete mode 100644 sumcheck/sumcheck_test.go diff --git a/internal/gkr/registry.go b/internal/gkr/registry.go deleted file mode 100644 index b48f179c20..0000000000 --- a/internal/gkr/registry.go +++ /dev/null @@ -1,374 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(small_rational.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]small_rational.SmallRational, nbIn) - consts := make(small_rational.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - x := make(small_rational.Vector, degreeBound) - x.MustSetRandom() - for i := range x { - fIn[0] = x[i] - for j := range consts { - fIn[j+1].Mul(&x[i], &consts[j]) - } - p[i] = f(fIn...) - } - - // obtain p's coefficients - p, err := interpolate(x, p) - if err != nil { - panic(err) - } - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) -// Note that the runtime is O(len(X)³) -func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { - if len(X) != len(Y) { - return nil, errors.New("X and Y must have the same length") - } - - // solve the system of equations by Gaussian elimination - augmentedRows := make([][]small_rational.SmallRational, len(X)) // the last column is the Y values - for i := range augmentedRows { - augmentedRows[i] = make([]small_rational.SmallRational, len(X)+1) - augmentedRows[i][0].SetOne() - augmentedRows[i][1].Set(&X[i]) - for j := 2; j < len(augmentedRows[i])-1; j++ { - augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) - } - augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) - } - - // make the upper triangle - for i := range len(augmentedRows) - 1 { - // use row i to eliminate the ith element in all rows below - var negInv small_rational.SmallRational - if augmentedRows[i][i].IsZero() { - return nil, errors.New("singular matrix") - } - negInv.Inverse(&augmentedRows[i][i]) - negInv.Neg(&negInv) - for j := i + 1; j < len(augmentedRows); j++ { - var c small_rational.SmallRational - c.Mul(&augmentedRows[j][i], &negInv) - // augmentedRows[j][i].SetZero() omitted - for k := i + 1; k < len(augmentedRows[i]); k++ { - var t small_rational.SmallRational - t.Mul(&augmentedRows[i][k], &c) - augmentedRows[j][k].Add(&augmentedRows[j][k], &t) - } - } - } - - // back substitution - res := make(polynomial.Polynomial, len(X)) - for i := len(augmentedRows) - 1; i >= 0; i-- { - res[i] = augmentedRows[i][len(augmentedRows[i])-1] - for j := i + 1; j < len(augmentedRows[i])-1; j++ { - var t small_rational.SmallRational - t.Mul(&res[j], &augmentedRows[i][j]) - res[i].Sub(&res[i], &t) - } - res[i].Div(&res[i], &augmentedRows[i][i]) - } - - return res, nil -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...small_rational.SmallRational) small_rational.SmallRational { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/internal/gkr/sumcheck/sumcheck.go b/internal/gkr/sumcheck/sumcheck.go deleted file mode 100644 index e491815a87..0000000000 --- a/internal/gkr/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package sumcheck - -import ( - "errors" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []small_rational.SmallRational) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a small_rational.SmallRational) small_rational.SmallRational // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []small_rational.SmallRational, remainingChallengeNames *[]string) (small_rational.SmallRational, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return small_rational.SmallRational{}, err - } - } - var res small_rational.SmallRational - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff small_rational.SmallRational - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]small_rational.SmallRational, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff small_rational.SmallRational - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]small_rational.SmallRational, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/internal/gkr/sumcheck/sumcheck_test.go b/internal/gkr/sumcheck/sumcheck_test.go deleted file mode 100644 index 85230fdb9d..0000000000 --- a/internal/gkr/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package sumcheck - -import ( - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark//test_vector_utils" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []small_rational.SmallRational{sum} -} - -func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum small_rational.SmallRational -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs small_rational.SmallRational) small_rational.SmallRational { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} diff --git a/registry.go b/registry.go deleted file mode 100644 index b48f179c20..0000000000 --- a/registry.go +++ /dev/null @@ -1,374 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package gkr - -import ( - "errors" - "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "slices" - "sync" -) - -type GateName string - -var ( - gates = make(map[GateName]*Gate) - gatesLock sync.Mutex -) - -type registerGateSettings struct { - solvableVar int - noSolvableVarVerification bool - noDegreeVerification bool - degree int -} - -type RegisterGateOption func(*registerGateSettings) - -// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will return an error if it cannot verify that this claim is correct. -func WithSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = solvableVar - } -} - -// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not verify that the given index is correct. -func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noSolvableVarVerification = true - settings.solvableVar = solvableVar - } -} - -// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// RegisterGate will not check the correctness of this claim. -func WithNoSolvableVar() RegisterGateOption { - return func(settings *registerGateSettings) { - settings.solvableVar = -1 - settings.noSolvableVarVerification = true - } -} - -// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. -func WithUnverifiedDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.noDegreeVerification = true - settings.degree = degree - } -} - -// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. -func WithDegree(degree int) RegisterGateOption { - return func(settings *registerGateSettings) { - settings.degree = degree - } -} - -// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f -func (f GateFunction) isAdditive(i, nbIn int) bool { - // fix all variables except the i-th one at random points - // pick random value x1 for the i-th variable - // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) - x := make(small_rational.Vector, nbIn) - x.MustSetRandom() - x0 := x[i] - x[i].SetZero() - in := slices.Clone(x) - y0 := f(in...) - - x[i] = x0 - copy(in, x) - y1 := f(in...) - - x[i].Double(&x[i]) - copy(in, x) - y2 := f(in...) - - y2.Sub(&y2, &y1) - y1.Sub(&y1, &y0) - - if !y2.Equal(&y1) { - return false // not linear - } - - // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) - if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 - return false - } - - // compute the slope with another assignment for the other variables - x.MustSetRandom() - x[i].SetZero() - copy(in, x) - y0 = f(in...) - - x[i] = x0 - copy(in, x) - y1 = f(in...) - - y1.Sub(&y1, &y0) - - return y1.Equal(&y2) -} - -// fitPoly tries to fit a polynomial of degree less than degreeBound to f. -// degreeBound must be a power of 2. -// It returns the polynomial if successful, nil otherwise -func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { - // turn f univariate by defining p(x) as f(x, rx, ..., sx) - // where r, s, ... are random constants - fIn := make([]small_rational.SmallRational, nbIn) - consts := make(small_rational.Vector, nbIn-1) - consts.MustSetRandom() - - p := make(polynomial.Polynomial, degreeBound) - x := make(small_rational.Vector, degreeBound) - x.MustSetRandom() - for i := range x { - fIn[0] = x[i] - for j := range consts { - fIn[j+1].Mul(&x[i], &consts[j]) - } - p[i] = f(fIn...) - } - - // obtain p's coefficients - p, err := interpolate(x, p) - if err != nil { - panic(err) - } - - // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound - fIn[0].MustSetRandom() - for i := range consts { - fIn[i+1].Mul(&fIn[0], &consts[i]) - } - pAt := p.Eval(&fIn[0]) - fAt := f(fIn...) - if !pAt.Equal(&fAt) { - return nil - } - - // trim p - lastNonZero := len(p) - 1 - for lastNonZero >= 0 && p[lastNonZero].IsZero() { - lastNonZero-- - } - return p[:lastNonZero+1] -} - -type errorString string - -func (e errorString) Error() string { - return string(e) -} - -const errZeroFunction = errorString("detected a zero function") - -// FindDegree returns the degree of the gate function, or -1 if it fails. -// Failure could be due to the degree being higher than max or the function not being a polynomial at all. -func (f GateFunction) FindDegree(max, nbIn int) (int, error) { - bound := uint64(max) + 1 - for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { - if p := f.fitPoly(nbIn, degreeBound); p != nil { - if len(p) == 0 { - return -1, errZeroFunction - } - return len(p) - 1, nil - } - } - return -1, fmt.Errorf("could not find a degree: tried up to %d", max) -} - -func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { - if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { - return fmt.Errorf("detected a higher degree than %d", claimedDegree) - } else if len(p) == 0 { - return errZeroFunction - } else if len(p)-1 != claimedDegree { - return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) - } - return nil -} - -// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns -1 if it fails to find one. -// nbIn is the number of inputs to the gate -func (f GateFunction) FindSolvableVar(nbIn int) int { - for i := range nbIn { - if f.isAdditive(i, nbIn) { - return i - } - } - return -1 -} - -// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. -// It returns false if it fails to verify this claim. -// nbIn is the number of inputs to the gate. -func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { - return f.isAdditive(claimedSolvableVar, nbIn) -} - -// RegisterGate creates a gate object and stores it in the gates registry. -// name is a human-readable name for the gate. -// f is the polynomial function defining the gate. -// nbIn is the number of inputs to the gate. -func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { - s := registerGateSettings{degree: -1, solvableVar: -1} - for _, option := range options { - option(&s) - } - - if s.degree == -1 { // find a degree - if s.noDegreeVerification { - panic("invalid settings") - } - const maxAutoDegreeBound = 32 - var err error - if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } else { - if !s.noDegreeVerification { // check that the given degree is correct - if err := f.VerifyDegree(s.degree, nbIn); err != nil { - return fmt.Errorf("for gate %s: %v", name, err) - } - } - } - - if s.solvableVar == -1 { - if !s.noSolvableVarVerification { // find a solvable variable - s.solvableVar = f.FindSolvableVar(nbIn) - } - } else { - // solvable variable given - if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { - return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) - } - } - - gatesLock.Lock() - defer gatesLock.Unlock() - gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} - return nil -} - -func GetGate(name GateName) *Gate { - gatesLock.Lock() - defer gatesLock.Unlock() - return gates[name] -} - -// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) -// Note that the runtime is O(len(X)³) -func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { - if len(X) != len(Y) { - return nil, errors.New("X and Y must have the same length") - } - - // solve the system of equations by Gaussian elimination - augmentedRows := make([][]small_rational.SmallRational, len(X)) // the last column is the Y values - for i := range augmentedRows { - augmentedRows[i] = make([]small_rational.SmallRational, len(X)+1) - augmentedRows[i][0].SetOne() - augmentedRows[i][1].Set(&X[i]) - for j := 2; j < len(augmentedRows[i])-1; j++ { - augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) - } - augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) - } - - // make the upper triangle - for i := range len(augmentedRows) - 1 { - // use row i to eliminate the ith element in all rows below - var negInv small_rational.SmallRational - if augmentedRows[i][i].IsZero() { - return nil, errors.New("singular matrix") - } - negInv.Inverse(&augmentedRows[i][i]) - negInv.Neg(&negInv) - for j := i + 1; j < len(augmentedRows); j++ { - var c small_rational.SmallRational - c.Mul(&augmentedRows[j][i], &negInv) - // augmentedRows[j][i].SetZero() omitted - for k := i + 1; k < len(augmentedRows[i]); k++ { - var t small_rational.SmallRational - t.Mul(&augmentedRows[i][k], &c) - augmentedRows[j][k].Add(&augmentedRows[j][k], &t) - } - } - } - - // back substitution - res := make(polynomial.Polynomial, len(X)) - for i := len(augmentedRows) - 1; i >= 0; i-- { - res[i] = augmentedRows[i][len(augmentedRows[i])-1] - for j := i + 1; j < len(augmentedRows[i])-1; j++ { - var t small_rational.SmallRational - t.Mul(&res[j], &augmentedRows[i][j]) - res[i].Sub(&res[i], &t) - } - res[i].Div(&res[i], &augmentedRows[i][i]) - } - - return res, nil -} - -const ( - Identity GateName = "identity" // Identity gate: x -> x - Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y - Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y - Neg GateName = "neg" // Neg gate: x -> -x - Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y -) - -func init() { - // register some basic gates - - if err := RegisterGate(Identity, func(x ...small_rational.SmallRational) small_rational.SmallRational { - return x[0] - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Add2, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Add(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Sub2, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Sub(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Neg, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Neg(&x[0]) - return res - }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { - panic(err) - } - - if err := RegisterGate(Mul2, func(x ...small_rational.SmallRational) small_rational.SmallRational { - var res small_rational.SmallRational - res.Mul(&x[0], &x[1]) - return res - }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { - panic(err) - } -} diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go index 1d6c14e85a..be2f48492b 100644 --- a/std/gkr/api_test.go +++ b/std/gkr/api_test.go @@ -23,11 +23,11 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + gkr "github.com/consensys/gnark/internal/gkr/bn254" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" test_vector_utils "github.com/consensys/gnark/std/internal/test_vectors_utils" diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go index f0099f23a5..d61adae39e 100644 --- a/std/gkr/example_test.go +++ b/std/gkr/example_test.go @@ -6,10 +6,10 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - gkrBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" gcHash "github.com/consensys/gnark-crypto/hash" bw6761 "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark/frontend" + gkrBw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" "github.com/consensys/gnark/std/gkr" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" @@ -159,7 +159,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { gkrApi := gkr.NewApi() - assertNoError(gkr.RegisterGate("square", func(api frontend.API, input ...frontend.Variable) (res frontend.Variable) { + assertNoError(gkr.RegisterGate("square", func(api gkr.GateAPI, input ...frontend.Variable) (res frontend.Variable) { return api.Mul(input[0], input[0]) }, 1)) @@ -187,7 +187,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { ZZ := gkrApi.NamedGate("square", Z) // 408: ZZ.Square(&p.Z) // define the SNARK version of the custom gates, similarly to the ones in Example - assertNoError(gkr.RegisterGate(c.gateNamePrefix+"s", func(api frontend.API, input ...frontend.Variable) (S frontend.Variable) { + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"s", func(api gkr.GateAPI, input ...frontend.Variable) (S frontend.Variable) { S = api.Add(input[0], input[1]) // 409: S.Add(&p.X, &YY) S = api.Mul(S, S) // 410: S.Square(&S). S = api.Sub(S, input[2], input[3]) // 411: Sub(&S, &XX). @@ -202,7 +202,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { // combine the operations that define the assignment to p.Z // input = [p.Z, p.Y, YY, ZZ] // Z = (p.Z + p.Y)² - YY - ZZ - assertNoError(gkr.RegisterGate(c.gateNamePrefix+"z", func(api frontend.API, input ...frontend.Variable) (Z frontend.Variable) { + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"z", func(api gkr.GateAPI, input ...frontend.Variable) (Z frontend.Variable) { Z = api.Add(input[0], input[1]) // 415: p.Z.Add(&p.Z, &p.Y). Z = api.Mul(Z, Z) // 416: p.Z.Square(&p.Z). Z = api.Sub(Z, input[2], input[3]) // 417: Sub(&p.Z, &YY). @@ -214,7 +214,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { // combine the operations that define the assignment to p.X // input = [XX, S] // p.X = 9XX² - 2S - assertNoError(gkr.RegisterGate(c.gateNamePrefix+"x", func(api frontend.API, input ...frontend.Variable) (X frontend.Variable) { + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"x", func(api gkr.GateAPI, input ...frontend.Variable) (X frontend.Variable) { M := api.Mul(input[0], 3) // 414: M.Double(&XX).Add(&M, &XX) T := api.Mul(M, M) // 419: T.Square(&M) X = api.Sub(T, api.Mul(input[1], 2)) // 420: p.X = T @@ -227,7 +227,7 @@ func (c *exampleCircuit) Define(api frontend.API) error { // combine the operations that define the assignment to p.Y // input = [S, p.X, XX, YYYY] // p.Y = (S - p.X) * 3 * XX - 8 * YYYY - assertNoError(gkr.RegisterGate(c.gateNamePrefix+"y", func(api frontend.API, input ...frontend.Variable) (Y frontend.Variable) { + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"y", func(api gkr.GateAPI, input ...frontend.Variable) (Y frontend.Variable) { Y = api.Sub(input[0], input[1]) // 423: p.Y.Sub(&S, &p.X). Y = api.Mul(Y, input[2], 3) // 414: M.Double(&XX).Add(&M, &XX) // 424:Mul(&p.Y, &M) diff --git a/sumcheck/sumcheck.go b/sumcheck/sumcheck.go deleted file mode 100644 index e491815a87..0000000000 --- a/sumcheck/sumcheck.go +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package sumcheck - -import ( - "errors" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "strconv" -) - -// This does not make use of parallelism and represents polynomials as lists of coefficients -// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. - -// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. -// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) -type Claims interface { - Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []small_rational.SmallRational) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. -type LazyClaims interface { - ClaimsNum() int // ClaimsNum = m - VarsNum() int // VarsNum = n - CombinedSum(a small_rational.SmallRational) small_rational.SmallRational // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ - Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error -} - -// Proof of a multi-sumcheck statement. -type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof -} - -func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { - numChallenges := varsNum - if claimsNum >= 2 { - numChallenges++ - } - challengeNames = make([]string, numChallenges) - if claimsNum >= 2 { - challengeNames[0] = settings.Prefix + "comb" - } - prefix := settings.Prefix + "pSP." - for i := 0; i < varsNum; i++ { - challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) - } - if settings.Transcript == nil { - transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) - settings.Transcript = transcript - } - - for i := range settings.BaseChallenges { - if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { - return - } - } - return -} - -func next(transcript *fiatshamir.Transcript, bindings []small_rational.SmallRational, remainingChallengeNames *[]string) (small_rational.SmallRational, error) { - challengeName := (*remainingChallengeNames)[0] - for i := range bindings { - bytes := bindings[i].Bytes() - if err := transcript.Bind(challengeName, bytes[:]); err != nil { - return small_rational.SmallRational{}, err - } - } - var res small_rational.SmallRational - bytes, err := transcript.ComputeChallenge(challengeName) - res.SetBytes(bytes) - - *remainingChallengeNames = (*remainingChallengeNames)[1:] - - return res, err -} - -// Prove create a non-interactive sumcheck proof -func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { - - var proof Proof - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return proof, err - } - - var combinationCoeff small_rational.SmallRational - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { - return proof, err - } - } - - varsNum := claims.VarsNum() - proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) - proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) - challenges := make([]small_rational.SmallRational, varsNum) - - for j := 0; j+1 < varsNum; j++ { - if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return proof, err - } - proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) - } - - if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { - return proof, err - } - - proof.FinalEvalProof = claims.ProveFinalEval(challenges) - - return proof, nil -} - -func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { - remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) - transcript := transcriptSettings.Transcript - if err != nil { - return err - } - - var combinationCoeff small_rational.SmallRational - - if claims.ClaimsNum() >= 2 { - if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { - return err - } - } - - r := make([]small_rational.SmallRational, claims.VarsNum()) - - // Just so that there is enough room for gJ to be reused - maxDegree := claims.Degree(0) - for j := 1; j < claims.VarsNum(); j++ { - if d := claims.Degree(j); d > maxDegree { - maxDegree = d - } - } - gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() - gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - - for j := 0; j < claims.VarsNum(); j++ { - if len(proof.PartialSumPolys[j]) != claims.Degree(j) { - return errors.New("malformed proof") - } - copy(gJ[1:], proof.PartialSumPolys[j]) - gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) - // gJ is ready - - //Prepare for the next iteration - if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { - return err - } - // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial - gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) - gJR = gJCoeffs.Eval(&r[j]) - } - - return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) -} diff --git a/sumcheck/sumcheck_test.go b/sumcheck/sumcheck_test.go deleted file mode 100644 index 85230fdb9d..0000000000 --- a/sumcheck/sumcheck_test.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2020-2025 Consensys Software Inc. -// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. - -// Code generated by gnark DO NOT EDIT - -package sumcheck - -import ( - "fmt" - fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" - "github.com/consensys/gnark//test_vector_utils" - "github.com/consensys/gnark/internal/small_rational" - "github.com/consensys/gnark/internal/small_rational/polynomial" - "github.com/stretchr/testify/assert" - "hash" - "math/bits" - "strings" - "testing" -) - -type singleMultilinClaim struct { - g polynomial.MultiLin -} - -func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) interface{} { - return nil // verifier can compute the final eval itself -} - -func (c singleMultilinClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func (c singleMultilinClaim) ClaimsNum() int { - return 1 -} - -func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { - sum := g[len(g)/2] - for i := len(g)/2 + 1; i < len(g); i++ { - sum.Add(&sum, &g[i]) - } - return []small_rational.SmallRational{sum} -} - -func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { - return sumForX1One(c.g) -} - -func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { - c.g.Fold(r) - return sumForX1One(c.g) -} - -type singleMultilinLazyClaim struct { - g polynomial.MultiLin - claimedSum small_rational.SmallRational -} - -func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { - val := c.g.Evaluate(r, nil) - if val.Equal(&purportedValue) { - return nil - } - return fmt.Errorf("mismatch") -} - -func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs small_rational.SmallRational) small_rational.SmallRational { - return c.claimedSum -} - -func (c singleMultilinLazyClaim) Degree(i int) int { - return 1 -} - -func (c singleMultilinLazyClaim) ClaimsNum() int { - return 1 -} - -func (c singleMultilinLazyClaim) VarsNum() int { - return bits.TrailingZeros(uint(len(c.g))) -} - -func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { - poly := make(polynomial.MultiLin, len(polyInt)) - for i, n := range polyInt { - poly[i].SetUint64(n) - } - - claim := singleMultilinClaim{g: poly.Clone()} - - proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) - if err != nil { - return err - } - - var sb strings.Builder - for _, p := range proof.PartialSumPolys { - - sb.WriteString("\t{") - for i := 0; i < len(p); i++ { - sb.WriteString(p[i].String()) - if i+1 < len(p) { - sb.WriteString(", ") - } - } - sb.WriteString("}\n") - } - - lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { - return err - } - - proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) - lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} - if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { - return fmt.Errorf("bad proof accepted") - } - return nil -} - -func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { - - polys := [][]uint64{ - {1, 2, 3, 4}, // 1 + 2X₁ + X₂ - {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ - } - - const MaxStep = 4 - const MaxStart = 4 - hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) - - for step := 0; step < MaxStep; step++ { - for startState := 0; startState < MaxStart; startState++ { - if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted - continue - } - hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) - } - } - - for _, poly := range polys { - for _, hashGen := range hashGens { - assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), - "failed with poly %v and hashGen %v", poly, hashGen()) - } - } -} From 4b2b1830a2a35cf56f3f823d0e6926b581aa6c7a Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 7 Apr 2025 11:45:52 -0500 Subject: [PATCH 47/62] docs: better comments for gkr --- internal/gkr/bn254/gkr.go | 107 +++++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 42 deletions(-) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 971a3ac342..3eb6b1c168 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -27,16 +27,16 @@ type GateFunction func(...fr.Element) fr.Element type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of g solvableVar int // if there is a solvable variable, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. func (g *Gate) SolvableVar() int { return g.solvableVar } @@ -88,10 +88,13 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly manager *claimsManager // WARNING: Circular references } @@ -103,6 +106,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { return len(e.evaluationPoints[0]) } +// CombinedSum returns ∑ᵢ aⁱ yᵢ func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) return evalsAsPoly.Eval(&a) @@ -112,10 +116,22 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { return 1 + e.wire.Gate.Degree() } +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { inputEvaluationsNoRedundancy := proof.([]fr.Element) - // the eq terms + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -124,11 +140,11 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb evaluation.Add(&evaluation, &eq) } - // the g(...) term + // the w(...) term var gateEvaluation fr.Element - if e.wire.IsInput() { + if e.wire.IsInput() { // just compute w(r) gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) @@ -142,7 +158,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } + } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { @@ -159,40 +175,49 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb return errors.New("incompatible evaluations") } +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { varsNum := c.VarsNum() eqLength := 1 << varsNum claimsNum := c.ClaimsNum() - // initialize the eq tables + // initialize the eq tables ( E ) c.eq = c.manager.memPool.Make(eqLength) c.eq[0].SetOne() c.eq.Eq(c.evaluationPoints[0]) + // E := eq(x₀, -) newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) aI := combinationCoeff - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { newEq[0].Set(&aI) c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - if k+1 < claimsNum { aI.Mul(&aI, &combinationCoeff) } @@ -200,16 +225,16 @@ func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) pol c.manager.memPool.Dump(newEq) - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - return c.computeGJ() } -// eqAcc sets m to an eq table at q and then adds it to e +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { n := len(q) - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ const threshold = 1 << 6 @@ -219,8 +244,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } } else { c.manager.workers.Submit(k, func(start, end int) { @@ -228,8 +253,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } }, 1024).Wait() } @@ -240,19 +265,19 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E e[i].Add(&e[i], &m[i]) } }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) } -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) nbGateIn := len(c.inputPreprocessors) - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables s := make([]polynomial.MultiLin, nbGateIn+1) s[0] = c.eq copy(s[1:], c.inputPreprocessors) @@ -267,7 +292,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) + operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate for i := start; i < end; i++ { @@ -306,12 +331,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() } - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 From 188a49dfebad264d8a4c9e55cd88eda34590f418 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 7 Apr 2025 11:56:39 -0500 Subject: [PATCH 48/62] docs: generify comments --- .../backend/template/gkr/gkr.go.tmpl | 107 ++++++++++------- internal/gkr/bls12-377/gkr.go | 107 ++++++++++------- internal/gkr/bls12-381/gkr.go | 107 ++++++++++------- internal/gkr/bls24-315/gkr.go | 107 ++++++++++------- internal/gkr/bls24-317/gkr.go | 107 ++++++++++------- internal/gkr/bn254/gkr.go | 2 +- internal/gkr/bw6-633/gkr.go | 107 ++++++++++------- internal/gkr/bw6-761/gkr.go | 107 ++++++++++------- internal/gkr/small_rational/gkr.go | 109 +++++++++++------- 9 files changed, 522 insertions(+), 338 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 886feeb3ca..192ef25091 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -22,16 +22,16 @@ type GateFunction func(...{{.ElementType}}) {{.ElementType}} type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a solvable variable, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. func (g *Gate) SolvableVar() int { return g.solvableVar } @@ -83,10 +83,13 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]{{.ElementType}} - claimedEvaluations []{{.ElementType}} + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]{{.ElementType}} // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []{{.ElementType}} // yᵢ = w(xᵢ), allegedly manager *claimsManager // WARNING: Circular references } @@ -98,6 +101,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { return len(e.evaluationPoints[0]) } +// CombinedSum returns ∑ᵢ aⁱ yᵢ func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a {{.ElementType}}) {{.ElementType}} { evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) return evalsAsPoly.Eval(&a) @@ -107,10 +111,22 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { return 1 + e.wire.Gate.Degree() } +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { inputEvaluationsNoRedundancy := proof.([]{{.ElementType}}) - // the eq terms + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -119,11 +135,11 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}} evaluation.Add(&evaluation, &eq) } - // the g(...) term + // the w(...) term var gateEvaluation {{.ElementType}} - if e.wire.IsInput() { + if e.wire.IsInput() { // just compute w(r) gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]{{.ElementType}}, len(e.wire.Inputs)) indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) @@ -137,7 +153,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}} // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } + } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { @@ -154,40 +170,49 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}} return errors.New("incompatible evaluations") } +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]{{.ElementType}} // x in the paper - claimedEvaluations []{{.ElementType}} // y in the paper + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]{{.ElementType}} // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []{{.ElementType}} // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff {{.ElementType}}) polynomial.Polynomial { varsNum := c.VarsNum() eqLength := 1 << varsNum claimsNum := c.ClaimsNum() - // initialize the eq tables + // initialize the eq tables ( E ) c.eq = c.manager.memPool.Make(eqLength) c.eq[0].SetOne() c.eq.Eq(c.evaluationPoints[0]) + // E := eq(x₀, -) newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) aI := combinationCoeff - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { newEq[0].Set(&aI) c.eqAcc(c.eq, newEq,c.evaluationPoints[k]) - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - if k+1 < claimsNum { aI.Mul(&aI, &combinationCoeff) } @@ -195,16 +220,16 @@ func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff {{.ElementType} c.manager.memPool.Dump(newEq) - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - return c.computeGJ() } -// eqAcc sets m to an eq table at q and then adds it to e +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.ElementType}}) { n := len(q) - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ const threshold = 1 << 6 @@ -214,8 +239,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } } else { c.manager.workers.Submit(k, func(start, end int) { @@ -223,8 +248,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } }, 1024).Wait() } @@ -235,20 +260,20 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.E e[i].Add(&e[i], &m[i]) } }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) } -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) nbGateIn := len(c.inputPreprocessors) - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables s := make([]polynomial.MultiLin, nbGateIn+1) s[0] = c.eq copy(s[1:], c.inputPreprocessors) @@ -263,7 +288,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { var step {{.ElementType}} res := make([]{{.ElementType}}, degGJ) - operands := make([]{{.ElementType}}, degGJ*nbInner) + operands := make([]{{.ElementType}}, degGJ*nbInner) // the eq value, followed by input to the gate for i := start; i < end; i++ { @@ -302,12 +327,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() } - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ func (c *eqTimesGateEvalSumcheckClaims) Next(element {{.ElementType}}) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 725ba5fbcd..abc764edca 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -27,16 +27,16 @@ type GateFunction func(...fr.Element) fr.Element type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a solvable variable, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. func (g *Gate) SolvableVar() int { return g.solvableVar } @@ -88,10 +88,13 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly manager *claimsManager // WARNING: Circular references } @@ -103,6 +106,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { return len(e.evaluationPoints[0]) } +// CombinedSum returns ∑ᵢ aⁱ yᵢ func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) return evalsAsPoly.Eval(&a) @@ -112,10 +116,22 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { return 1 + e.wire.Gate.Degree() } +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { inputEvaluationsNoRedundancy := proof.([]fr.Element) - // the eq terms + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -124,11 +140,11 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb evaluation.Add(&evaluation, &eq) } - // the g(...) term + // the w(...) term var gateEvaluation fr.Element - if e.wire.IsInput() { + if e.wire.IsInput() { // just compute w(r) gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) @@ -142,7 +158,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } + } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { @@ -159,40 +175,49 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb return errors.New("incompatible evaluations") } +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { varsNum := c.VarsNum() eqLength := 1 << varsNum claimsNum := c.ClaimsNum() - // initialize the eq tables + // initialize the eq tables ( E ) c.eq = c.manager.memPool.Make(eqLength) c.eq[0].SetOne() c.eq.Eq(c.evaluationPoints[0]) + // E := eq(x₀, -) newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) aI := combinationCoeff - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { newEq[0].Set(&aI) c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - if k+1 < claimsNum { aI.Mul(&aI, &combinationCoeff) } @@ -200,16 +225,16 @@ func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) pol c.manager.memPool.Dump(newEq) - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - return c.computeGJ() } -// eqAcc sets m to an eq table at q and then adds it to e +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { n := len(q) - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ const threshold = 1 << 6 @@ -219,8 +244,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } } else { c.manager.workers.Submit(k, func(start, end int) { @@ -228,8 +253,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } }, 1024).Wait() } @@ -240,19 +265,19 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E e[i].Add(&e[i], &m[i]) } }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) } -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) nbGateIn := len(c.inputPreprocessors) - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables s := make([]polynomial.MultiLin, nbGateIn+1) s[0] = c.eq copy(s[1:], c.inputPreprocessors) @@ -267,7 +292,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) + operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate for i := start; i < end; i++ { @@ -306,12 +331,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() } - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index c0387ff7bd..696e9299b2 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -27,16 +27,16 @@ type GateFunction func(...fr.Element) fr.Element type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a solvable variable, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. func (g *Gate) SolvableVar() int { return g.solvableVar } @@ -88,10 +88,13 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly manager *claimsManager // WARNING: Circular references } @@ -103,6 +106,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { return len(e.evaluationPoints[0]) } +// CombinedSum returns ∑ᵢ aⁱ yᵢ func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) return evalsAsPoly.Eval(&a) @@ -112,10 +116,22 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { return 1 + e.wire.Gate.Degree() } +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { inputEvaluationsNoRedundancy := proof.([]fr.Element) - // the eq terms + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -124,11 +140,11 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb evaluation.Add(&evaluation, &eq) } - // the g(...) term + // the w(...) term var gateEvaluation fr.Element - if e.wire.IsInput() { + if e.wire.IsInput() { // just compute w(r) gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) @@ -142,7 +158,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } + } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { @@ -159,40 +175,49 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb return errors.New("incompatible evaluations") } +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { varsNum := c.VarsNum() eqLength := 1 << varsNum claimsNum := c.ClaimsNum() - // initialize the eq tables + // initialize the eq tables ( E ) c.eq = c.manager.memPool.Make(eqLength) c.eq[0].SetOne() c.eq.Eq(c.evaluationPoints[0]) + // E := eq(x₀, -) newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) aI := combinationCoeff - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { newEq[0].Set(&aI) c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - if k+1 < claimsNum { aI.Mul(&aI, &combinationCoeff) } @@ -200,16 +225,16 @@ func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) pol c.manager.memPool.Dump(newEq) - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - return c.computeGJ() } -// eqAcc sets m to an eq table at q and then adds it to e +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { n := len(q) - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ const threshold = 1 << 6 @@ -219,8 +244,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } } else { c.manager.workers.Submit(k, func(start, end int) { @@ -228,8 +253,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } }, 1024).Wait() } @@ -240,19 +265,19 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E e[i].Add(&e[i], &m[i]) } }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) } -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) nbGateIn := len(c.inputPreprocessors) - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables s := make([]polynomial.MultiLin, nbGateIn+1) s[0] = c.eq copy(s[1:], c.inputPreprocessors) @@ -267,7 +292,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) + operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate for i := start; i < end; i++ { @@ -306,12 +331,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() } - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 22809d20f0..fc054855af 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -27,16 +27,16 @@ type GateFunction func(...fr.Element) fr.Element type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a solvable variable, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. func (g *Gate) SolvableVar() int { return g.solvableVar } @@ -88,10 +88,13 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly manager *claimsManager // WARNING: Circular references } @@ -103,6 +106,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { return len(e.evaluationPoints[0]) } +// CombinedSum returns ∑ᵢ aⁱ yᵢ func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) return evalsAsPoly.Eval(&a) @@ -112,10 +116,22 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { return 1 + e.wire.Gate.Degree() } +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { inputEvaluationsNoRedundancy := proof.([]fr.Element) - // the eq terms + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -124,11 +140,11 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb evaluation.Add(&evaluation, &eq) } - // the g(...) term + // the w(...) term var gateEvaluation fr.Element - if e.wire.IsInput() { + if e.wire.IsInput() { // just compute w(r) gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) @@ -142,7 +158,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } + } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { @@ -159,40 +175,49 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb return errors.New("incompatible evaluations") } +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { varsNum := c.VarsNum() eqLength := 1 << varsNum claimsNum := c.ClaimsNum() - // initialize the eq tables + // initialize the eq tables ( E ) c.eq = c.manager.memPool.Make(eqLength) c.eq[0].SetOne() c.eq.Eq(c.evaluationPoints[0]) + // E := eq(x₀, -) newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) aI := combinationCoeff - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { newEq[0].Set(&aI) c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - if k+1 < claimsNum { aI.Mul(&aI, &combinationCoeff) } @@ -200,16 +225,16 @@ func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) pol c.manager.memPool.Dump(newEq) - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - return c.computeGJ() } -// eqAcc sets m to an eq table at q and then adds it to e +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { n := len(q) - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ const threshold = 1 << 6 @@ -219,8 +244,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } } else { c.manager.workers.Submit(k, func(start, end int) { @@ -228,8 +253,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } }, 1024).Wait() } @@ -240,19 +265,19 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E e[i].Add(&e[i], &m[i]) } }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) } -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) nbGateIn := len(c.inputPreprocessors) - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables s := make([]polynomial.MultiLin, nbGateIn+1) s[0] = c.eq copy(s[1:], c.inputPreprocessors) @@ -267,7 +292,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) + operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate for i := start; i < end; i++ { @@ -306,12 +331,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() } - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 5b26065286..7dfc5765f1 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -27,16 +27,16 @@ type GateFunction func(...fr.Element) fr.Element type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a solvable variable, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. func (g *Gate) SolvableVar() int { return g.solvableVar } @@ -88,10 +88,13 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly manager *claimsManager // WARNING: Circular references } @@ -103,6 +106,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { return len(e.evaluationPoints[0]) } +// CombinedSum returns ∑ᵢ aⁱ yᵢ func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) return evalsAsPoly.Eval(&a) @@ -112,10 +116,22 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { return 1 + e.wire.Gate.Degree() } +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { inputEvaluationsNoRedundancy := proof.([]fr.Element) - // the eq terms + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -124,11 +140,11 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb evaluation.Add(&evaluation, &eq) } - // the g(...) term + // the w(...) term var gateEvaluation fr.Element - if e.wire.IsInput() { + if e.wire.IsInput() { // just compute w(r) gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) @@ -142,7 +158,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } + } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { @@ -159,40 +175,49 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb return errors.New("incompatible evaluations") } +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { varsNum := c.VarsNum() eqLength := 1 << varsNum claimsNum := c.ClaimsNum() - // initialize the eq tables + // initialize the eq tables ( E ) c.eq = c.manager.memPool.Make(eqLength) c.eq[0].SetOne() c.eq.Eq(c.evaluationPoints[0]) + // E := eq(x₀, -) newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) aI := combinationCoeff - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { newEq[0].Set(&aI) c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - if k+1 < claimsNum { aI.Mul(&aI, &combinationCoeff) } @@ -200,16 +225,16 @@ func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) pol c.manager.memPool.Dump(newEq) - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - return c.computeGJ() } -// eqAcc sets m to an eq table at q and then adds it to e +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { n := len(q) - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ const threshold = 1 << 6 @@ -219,8 +244,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } } else { c.manager.workers.Submit(k, func(start, end int) { @@ -228,8 +253,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } }, 1024).Wait() } @@ -240,19 +265,19 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E e[i].Add(&e[i], &m[i]) } }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) } -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) nbGateIn := len(c.inputPreprocessors) - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables s := make([]polynomial.MultiLin, nbGateIn+1) s[0] = c.eq copy(s[1:], c.inputPreprocessors) @@ -267,7 +292,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) + operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate for i := start; i < end; i++ { @@ -306,12 +331,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() } - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 3eb6b1c168..d20a7fa8f9 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -27,7 +27,7 @@ type GateFunction func(...fr.Element) fr.Element type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of g + degree int // total degree of the polynomial solvableVar int // if there is a solvable variable, its index, -1 otherwise } diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 932070198f..f4f79ac01d 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -27,16 +27,16 @@ type GateFunction func(...fr.Element) fr.Element type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a solvable variable, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. func (g *Gate) SolvableVar() int { return g.solvableVar } @@ -88,10 +88,13 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly manager *claimsManager // WARNING: Circular references } @@ -103,6 +106,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { return len(e.evaluationPoints[0]) } +// CombinedSum returns ∑ᵢ aⁱ yᵢ func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) return evalsAsPoly.Eval(&a) @@ -112,10 +116,22 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { return 1 + e.wire.Gate.Degree() } +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { inputEvaluationsNoRedundancy := proof.([]fr.Element) - // the eq terms + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -124,11 +140,11 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb evaluation.Add(&evaluation, &eq) } - // the g(...) term + // the w(...) term var gateEvaluation fr.Element - if e.wire.IsInput() { + if e.wire.IsInput() { // just compute w(r) gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) @@ -142,7 +158,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } + } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { @@ -159,40 +175,49 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb return errors.New("incompatible evaluations") } +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { varsNum := c.VarsNum() eqLength := 1 << varsNum claimsNum := c.ClaimsNum() - // initialize the eq tables + // initialize the eq tables ( E ) c.eq = c.manager.memPool.Make(eqLength) c.eq[0].SetOne() c.eq.Eq(c.evaluationPoints[0]) + // E := eq(x₀, -) newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) aI := combinationCoeff - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { newEq[0].Set(&aI) c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - if k+1 < claimsNum { aI.Mul(&aI, &combinationCoeff) } @@ -200,16 +225,16 @@ func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) pol c.manager.memPool.Dump(newEq) - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - return c.computeGJ() } -// eqAcc sets m to an eq table at q and then adds it to e +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { n := len(q) - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ const threshold = 1 << 6 @@ -219,8 +244,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } } else { c.manager.workers.Submit(k, func(start, end int) { @@ -228,8 +253,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } }, 1024).Wait() } @@ -240,19 +265,19 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E e[i].Add(&e[i], &m[i]) } }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) } -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) nbGateIn := len(c.inputPreprocessors) - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables s := make([]polynomial.MultiLin, nbGateIn+1) s[0] = c.eq copy(s[1:], c.inputPreprocessors) @@ -267,7 +292,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) + operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate for i := start; i < end; i++ { @@ -306,12 +331,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() } - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index e369b7c52b..e2960cbd43 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -27,16 +27,16 @@ type GateFunction func(...fr.Element) fr.Element type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a solvable variable, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. func (g *Gate) SolvableVar() int { return g.solvableVar } @@ -88,10 +88,13 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]fr.Element - claimedEvaluations []fr.Element + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly manager *claimsManager // WARNING: Circular references } @@ -103,6 +106,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { return len(e.evaluationPoints[0]) } +// CombinedSum returns ∑ᵢ aⁱ yᵢ func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) return evalsAsPoly.Eval(&a) @@ -112,10 +116,22 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { return 1 + e.wire.Gate.Degree() } +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { inputEvaluationsNoRedundancy := proof.([]fr.Element) - // the eq terms + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -124,11 +140,11 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb evaluation.Add(&evaluation, &eq) } - // the g(...) term + // the w(...) term var gateEvaluation fr.Element - if e.wire.IsInput() { + if e.wire.IsInput() { // just compute w(r) gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) @@ -142,7 +158,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } + } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { @@ -159,40 +175,49 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb return errors.New("incompatible evaluations") } +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]fr.Element // x in the paper - claimedEvaluations []fr.Element // y in the paper + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { varsNum := c.VarsNum() eqLength := 1 << varsNum claimsNum := c.ClaimsNum() - // initialize the eq tables + // initialize the eq tables ( E ) c.eq = c.manager.memPool.Make(eqLength) c.eq[0].SetOne() c.eq.Eq(c.evaluationPoints[0]) + // E := eq(x₀, -) newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) aI := combinationCoeff - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { newEq[0].Set(&aI) c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - if k+1 < claimsNum { aI.Mul(&aI, &combinationCoeff) } @@ -200,16 +225,16 @@ func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) pol c.manager.memPool.Dump(newEq) - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - return c.computeGJ() } -// eqAcc sets m to an eq table at q and then adds it to e +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { n := len(q) - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ const threshold = 1 << 6 @@ -219,8 +244,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } } else { c.manager.workers.Submit(k, func(start, end int) { @@ -228,8 +253,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } }, 1024).Wait() } @@ -240,19 +265,19 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E e[i].Add(&e[i], &m[i]) } }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) } -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) nbGateIn := len(c.inputPreprocessors) - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables s := make([]polynomial.MultiLin, nbGateIn+1) s[0] = c.eq copy(s[1:], c.inputPreprocessors) @@ -267,7 +292,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) + operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate for i := start; i < end; i++ { @@ -306,12 +331,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() } - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 9119e58363..c75d09ee1b 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -27,16 +27,16 @@ type GateFunction func(...small_rational.SmallRational) small_rational.SmallRati type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a solvable variable, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } -// SolvableVar returns I such that x_I can always be determined from {x_i} - {x_I} and f(x...). If there is no such variable, it returns -1. +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. func (g *Gate) SolvableVar() int { return g.solvableVar } @@ -88,11 +88,14 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { - wire *Wire - evaluationPoints [][]small_rational.SmallRational - claimedEvaluations []small_rational.SmallRational - manager *claimsManager // WARNING: Circular references + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]small_rational.SmallRational // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []small_rational.SmallRational // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references } func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { @@ -103,6 +106,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { return len(e.evaluationPoints[0]) } +// CombinedSum returns ∑ᵢ aⁱ yᵢ func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a small_rational.SmallRational) small_rational.SmallRational { evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) return evalsAsPoly.Eval(&a) @@ -112,10 +116,22 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { return 1 + e.wire.Gate.Degree() } +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { inputEvaluationsNoRedundancy := proof.([]small_rational.SmallRational) - // the eq terms + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -124,11 +140,11 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.S evaluation.Add(&evaluation, &eq) } - // the g(...) term + // the w(...) term var gateEvaluation small_rational.SmallRational - if e.wire.IsInput() { + if e.wire.IsInput() { // just compute w(r) gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) - } else { + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) @@ -142,7 +158,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.S // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } + } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { @@ -159,40 +175,49 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.S return errors.New("incompatible evaluations") } +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckClaims struct { - wire *Wire - evaluationPoints [][]small_rational.SmallRational // x in the paper - claimedEvaluations []small_rational.SmallRational // y in the paper + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]small_rational.SmallRational // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []small_rational.SmallRational // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) - eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational.SmallRational) polynomial.Polynomial { varsNum := c.VarsNum() eqLength := 1 << varsNum claimsNum := c.ClaimsNum() - // initialize the eq tables + // initialize the eq tables ( E ) c.eq = c.manager.memPool.Make(eqLength) c.eq[0].SetOne() c.eq.Eq(c.evaluationPoints[0]) + // E := eq(x₀, -) newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) aI := combinationCoeff - for k := 1; k < claimsNum; k++ { //TODO: parallelizable? - // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { newEq[0].Set(&aI) c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) - // newEq.Eq(c.evaluationPoints[k]) - // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics - // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) - if k+1 < claimsNum { aI.Mul(&aI, &combinationCoeff) } @@ -200,16 +225,16 @@ func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational. c.manager.memPool.Dump(newEq) - // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree - return c.computeGJ() } -// eqAcc sets m to an eq table at q and then adds it to e +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { n := len(q) - //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ const threshold = 1 << 6 @@ -219,8 +244,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []smal j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } } else { c.manager.workers.Submit(k, func(start, end int) { @@ -228,8 +253,8 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []smal j0 := j << (n - i) // bᵢ₊₁ = 0 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 - m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ - m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) } }, 1024).Wait() } @@ -240,19 +265,19 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []smal e[i].Add(&e[i], &m[i]) } }, 512).Wait() - - // e.Add(e, polynomial.Polynomial(m)) } -// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k -// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). -// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { - degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) nbGateIn := len(c.inputPreprocessors) - // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables s := make([]polynomial.MultiLin, nbGateIn+1) s[0] = c.eq copy(s[1:], c.inputPreprocessors) @@ -267,7 +292,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { var step small_rational.SmallRational res := make([]small_rational.SmallRational, degGJ) - operands := make([]small_rational.SmallRational, degGJ*nbInner) + operands := make([]small_rational.SmallRational, degGJ*nbInner) // the eq value, followed by input to the gate for i := start; i < end; i++ { @@ -306,12 +331,10 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() } - // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though - return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ func (c *eqTimesGateEvalSumcheckClaims) Next(element small_rational.SmallRational) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 From e1f6f7656d151508fe4849b57cda7f4c84a2fbc6 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 7 Apr 2025 11:58:21 -0500 Subject: [PATCH 49/62] refactor: remove unnecessary set --- internal/generator/backend/template/gkr/gkr.go.tmpl | 3 +-- internal/gkr/bls12-377/gkr.go | 3 +-- internal/gkr/bls12-381/gkr.go | 3 +-- internal/gkr/bls24-315/gkr.go | 3 +-- internal/gkr/bls24-317/gkr.go | 3 +-- internal/gkr/bn254/gkr.go | 3 +-- internal/gkr/bw6-633/gkr.go | 3 +-- internal/gkr/bw6-761/gkr.go | 3 +-- internal/gkr/small_rational/gkr.go | 3 +-- 9 files changed, 9 insertions(+), 18 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 192ef25091..0b77db2819 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -294,9 +294,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { block := nbOuter + i for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) + step.Sub(&operands[j], &s[j][i]) for d := 1; d < degGJ; d++ { operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) } diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index abc764edca..9d68efa299 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -298,9 +298,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { block := nbOuter + i for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) + step.Sub(&operands[j], &s[j][i]) for d := 1; d < degGJ; d++ { operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) } diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 696e9299b2..7f601734f0 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -298,9 +298,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { block := nbOuter + i for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) + step.Sub(&operands[j], &s[j][i]) for d := 1; d < degGJ; d++ { operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) } diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index fc054855af..deae88a8b2 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -298,9 +298,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { block := nbOuter + i for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) + step.Sub(&operands[j], &s[j][i]) for d := 1; d < degGJ; d++ { operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) } diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 7dfc5765f1..86367774c8 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -298,9 +298,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { block := nbOuter + i for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) + step.Sub(&operands[j], &s[j][i]) for d := 1; d < degGJ; d++ { operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) } diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index d20a7fa8f9..6e3ffc917c 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -298,9 +298,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { block := nbOuter + i for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) + step.Sub(&operands[j], &s[j][i]) for d := 1; d < degGJ; d++ { operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) } diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index f4f79ac01d..8be28f5048 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -298,9 +298,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { block := nbOuter + i for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) + step.Sub(&operands[j], &s[j][i]) for d := 1; d < degGJ; d++ { operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) } diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index e2960cbd43..6b912ea55a 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -298,9 +298,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { block := nbOuter + i for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) + step.Sub(&operands[j], &s[j][i]) for d := 1; d < degGJ; d++ { operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index c75d09ee1b..2b9c13039d 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -298,9 +298,8 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { block := nbOuter + i for j := 0; j < nbInner; j++ { - step.Set(&s[j][i]) operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &step) + step.Sub(&operands[j], &s[j][i]) for d := 1; d < degGJ; d++ { operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) } From 66c730486141b464bd26e852152e062f14a5c826 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 7 Apr 2025 13:20:03 -0500 Subject: [PATCH 50/62] docs computeGJ --- .../backend/template/gkr/gkr.go.tmpl | 64 +++++++++++-------- internal/gkr/bls12-377/gkr.go | 64 +++++++++++-------- internal/gkr/bls12-381/gkr.go | 64 +++++++++++-------- internal/gkr/bls24-315/gkr.go | 64 +++++++++++-------- internal/gkr/bls24-317/gkr.go | 64 +++++++++++-------- internal/gkr/bn254/gkr.go | 64 +++++++++++-------- internal/gkr/bw6-633/gkr.go | 64 +++++++++++-------- internal/gkr/bw6-761/gkr.go | 64 +++++++++++-------- internal/gkr/small_rational/gkr.go | 64 +++++++++++-------- 9 files changed, 324 insertions(+), 252 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 0b77db2819..5a164c27fc 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -273,57 +273,65 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.inputPreprocessors) + + sumSize := len(c.eq) / 2; // the range of h, over which we sum // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 gJ := make([]{{.ElementType}}, degGJ) var mu sync.Mutex - computeAll := func(start, end int) { + computeAll := func(start, end int) { // compute method to allow parallelization across instances var step {{.ElementType}} res := make([]{{.ElementType}}, degGJ) - operands := make([]{{.ElementType}}, degGJ*nbInner) // the eq value, followed by input to the gate - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &s[j][i]) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]{{.ElementType}}, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h])// step = ml[k](1) - ml[k](0) for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if nbOuter < minBlockSize { + if sumSize < minBlockSize { // no parallelization - computeAll(0, nbOuter) + computeAll(0, sumSize) } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } return gJ diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 9d68efa299..9374bdf48a 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -277,57 +277,65 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.inputPreprocessors) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 gJ := make([]fr.Element, degGJ) var mu sync.Mutex - computeAll := func(start, end int) { + computeAll := func(start, end int) { // compute method to allow parallelization across instances var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &s[j][i]) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if nbOuter < minBlockSize { + if sumSize < minBlockSize { // no parallelization - computeAll(0, nbOuter) + computeAll(0, sumSize) } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } return gJ diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 7f601734f0..933227d798 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -277,57 +277,65 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.inputPreprocessors) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 gJ := make([]fr.Element, degGJ) var mu sync.Mutex - computeAll := func(start, end int) { + computeAll := func(start, end int) { // compute method to allow parallelization across instances var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &s[j][i]) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if nbOuter < minBlockSize { + if sumSize < minBlockSize { // no parallelization - computeAll(0, nbOuter) + computeAll(0, sumSize) } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } return gJ diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index deae88a8b2..7857b9c46c 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -277,57 +277,65 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.inputPreprocessors) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 gJ := make([]fr.Element, degGJ) var mu sync.Mutex - computeAll := func(start, end int) { + computeAll := func(start, end int) { // compute method to allow parallelization across instances var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &s[j][i]) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if nbOuter < minBlockSize { + if sumSize < minBlockSize { // no parallelization - computeAll(0, nbOuter) + computeAll(0, sumSize) } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } return gJ diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 86367774c8..290832c9f9 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -277,57 +277,65 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.inputPreprocessors) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 gJ := make([]fr.Element, degGJ) var mu sync.Mutex - computeAll := func(start, end int) { + computeAll := func(start, end int) { // compute method to allow parallelization across instances var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &s[j][i]) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if nbOuter < minBlockSize { + if sumSize < minBlockSize { // no parallelization - computeAll(0, nbOuter) + computeAll(0, sumSize) } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } return gJ diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 6e3ffc917c..9b7d00b3f7 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -277,57 +277,65 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.inputPreprocessors) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 gJ := make([]fr.Element, degGJ) var mu sync.Mutex - computeAll := func(start, end int) { + computeAll := func(start, end int) { // compute method to allow parallelization across instances var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &s[j][i]) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if nbOuter < minBlockSize { + if sumSize < minBlockSize { // no parallelization - computeAll(0, nbOuter) + computeAll(0, sumSize) } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } return gJ diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 8be28f5048..4662a7da5a 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -277,57 +277,65 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.inputPreprocessors) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 gJ := make([]fr.Element, degGJ) var mu sync.Mutex - computeAll := func(start, end int) { + computeAll := func(start, end int) { // compute method to allow parallelization across instances var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &s[j][i]) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if nbOuter < minBlockSize { + if sumSize < minBlockSize { // no parallelization - computeAll(0, nbOuter) + computeAll(0, sumSize) } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } return gJ diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 6b912ea55a..753623fba8 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -277,57 +277,65 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.inputPreprocessors) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 gJ := make([]fr.Element, degGJ) var mu sync.Mutex - computeAll := func(start, end int) { + computeAll := func(start, end int) { // compute method to allow parallelization across instances var step fr.Element res := make([]fr.Element, degGJ) - operands := make([]fr.Element, degGJ*nbInner) // the eq value, followed by input to the gate - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &s[j][i]) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if nbOuter < minBlockSize { + if sumSize < minBlockSize { // no parallelization - computeAll(0, nbOuter) + computeAll(0, sumSize) } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } return gJ diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 2b9c13039d..c192205b62 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -277,57 +277,65 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. - // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables - s := make([]polynomial.MultiLin, nbGateIn+1) - s[0] = c.eq - copy(s[1:], c.inputPreprocessors) + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.inputPreprocessors) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called - nbInner := len(s) // wrt output, which has high nbOuter and low nbInner - nbOuter := len(s[0]) / 2 gJ := make([]small_rational.SmallRational, degGJ) var mu sync.Mutex - computeAll := func(start, end int) { + computeAll := func(start, end int) { // compute method to allow parallelization across instances var step small_rational.SmallRational res := make([]small_rational.SmallRational, degGJ) - operands := make([]small_rational.SmallRational, degGJ*nbInner) // the eq value, followed by input to the gate - - for i := start; i < end; i++ { - - block := nbOuter + i - for j := 0; j < nbInner; j++ { - operands[j].Set(&s[j][block]) - step.Sub(&operands[j], &s[j][i]) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]small_rational.SmallRational, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) for d := 1; d < degGJ; d++ { - operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) } } - _s := 0 - _e := nbInner - for d := 0; d < degGJ; d++ { - summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) - summand.Mul(&summand, &operands[_s]) - res[d].Add(&res[d], &summand) - _s, _e = _e, _e+nbInner + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) } } mu.Lock() - for i := 0; i < len(gJ); i++ { - gJ[i].Add(&gJ[i], &res[i]) + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum } mu.Unlock() } const minBlockSize = 64 - if nbOuter < minBlockSize { + if sumSize < minBlockSize { // no parallelization - computeAll(0, nbOuter) + computeAll(0, sumSize) } else { - c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() } return gJ From f4c074aaef5fc9d4eaf6ab9607defd1bfccffa06 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Mon, 7 Apr 2025 13:44:31 -0500 Subject: [PATCH 51/62] docs: next, proveFinalEval --- .../backend/template/gkr/gkr.go.tmpl | 44 ++++++++++--------- internal/gkr/bls12-377/gkr.go | 44 ++++++++++--------- internal/gkr/bls12-381/gkr.go | 44 ++++++++++--------- internal/gkr/bls24-315/gkr.go | 44 ++++++++++--------- internal/gkr/bls24-317/gkr.go | 44 ++++++++++--------- internal/gkr/bn254/gkr.go | 44 ++++++++++--------- internal/gkr/bw6-633/gkr.go | 44 ++++++++++--------- internal/gkr/bw6-761/gkr.go | 44 ++++++++++--------- internal/gkr/small_rational/gkr.go | 44 ++++++++++--------- 9 files changed, 207 insertions(+), 189 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 5a164c27fc..0d63e721e3 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -179,7 +179,7 @@ type eqTimesGateEvalSumcheckClaims struct { claimedEvaluations []{{.ElementType}} // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } @@ -269,7 +269,7 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.E func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.inputPreprocessors) + nbGateIn := len(c.input) // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. @@ -277,7 +277,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. ml := make([]polynomial.MultiLin, nbGateIn+1) ml[0] = c.eq - copy(ml[1:], c.inputPreprocessors) + copy(ml[1:], c.input) sumSize := len(c.eq) / 2; // the range of h, over which we sum @@ -337,22 +337,23 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ -func (c *eqTimesGateEvalSumcheckClaims) Next(element {{.ElementType}}) polynomial.Polynomial { +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge {{.ElementType}}) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 if n < minBlockSize { // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) } - c.eq.Fold(element) + c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } @@ -369,22 +370,23 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { return len(c.claimedEvaluations) } +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) interface{} { //defer the proof, return list of claims evaluations := make([]{{.ElementType}}, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. noMoreClaimsAllowed[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] + wI := c.input[inI] if _, found := noMoreClaimsAllowed[in]; !found { noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) } - c.manager.memPool.Dump(puI) + c.manager.memPool.Dump(wI) } c.manager.memPool.Dump(c.claimedEvaluations, c.eq) @@ -439,12 +441,12 @@ func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { } if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied } } return res diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 9374bdf48a..0ba1e96e80 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -184,7 +184,7 @@ type eqTimesGateEvalSumcheckClaims struct { claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } @@ -273,7 +273,7 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.inputPreprocessors) + nbGateIn := len(c.input) // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. @@ -281,7 +281,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. ml := make([]polynomial.MultiLin, nbGateIn+1) ml[0] = c.eq - copy(ml[1:], c.inputPreprocessors) + copy(ml[1:], c.input) sumSize := len(c.eq) / 2 // the range of h, over which we sum @@ -341,22 +341,23 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 if n < minBlockSize { // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) } - c.eq.Fold(element) + c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } @@ -373,22 +374,23 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { return len(c.claimedEvaluations) } +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. noMoreClaimsAllowed[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] + wI := c.input[inI] if _, found := noMoreClaimsAllowed[in]; !found { noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) } - c.manager.memPool.Dump(puI) + c.manager.memPool.Dump(wI) } c.manager.memPool.Dump(c.claimedEvaluations, c.eq) @@ -443,12 +445,12 @@ func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { } if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied } } return res diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 933227d798..b1b259f51d 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -184,7 +184,7 @@ type eqTimesGateEvalSumcheckClaims struct { claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } @@ -273,7 +273,7 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.inputPreprocessors) + nbGateIn := len(c.input) // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. @@ -281,7 +281,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. ml := make([]polynomial.MultiLin, nbGateIn+1) ml[0] = c.eq - copy(ml[1:], c.inputPreprocessors) + copy(ml[1:], c.input) sumSize := len(c.eq) / 2 // the range of h, over which we sum @@ -341,22 +341,23 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 if n < minBlockSize { // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) } - c.eq.Fold(element) + c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } @@ -373,22 +374,23 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { return len(c.claimedEvaluations) } +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. noMoreClaimsAllowed[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] + wI := c.input[inI] if _, found := noMoreClaimsAllowed[in]; !found { noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) } - c.manager.memPool.Dump(puI) + c.manager.memPool.Dump(wI) } c.manager.memPool.Dump(c.claimedEvaluations, c.eq) @@ -443,12 +445,12 @@ func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { } if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied } } return res diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7857b9c46c..d206d3f511 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -184,7 +184,7 @@ type eqTimesGateEvalSumcheckClaims struct { claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } @@ -273,7 +273,7 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.inputPreprocessors) + nbGateIn := len(c.input) // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. @@ -281,7 +281,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. ml := make([]polynomial.MultiLin, nbGateIn+1) ml[0] = c.eq - copy(ml[1:], c.inputPreprocessors) + copy(ml[1:], c.input) sumSize := len(c.eq) / 2 // the range of h, over which we sum @@ -341,22 +341,23 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 if n < minBlockSize { // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) } - c.eq.Fold(element) + c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } @@ -373,22 +374,23 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { return len(c.claimedEvaluations) } +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. noMoreClaimsAllowed[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] + wI := c.input[inI] if _, found := noMoreClaimsAllowed[in]; !found { noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) } - c.manager.memPool.Dump(puI) + c.manager.memPool.Dump(wI) } c.manager.memPool.Dump(c.claimedEvaluations, c.eq) @@ -443,12 +445,12 @@ func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { } if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied } } return res diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 290832c9f9..4b19c7521f 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -184,7 +184,7 @@ type eqTimesGateEvalSumcheckClaims struct { claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } @@ -273,7 +273,7 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.inputPreprocessors) + nbGateIn := len(c.input) // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. @@ -281,7 +281,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. ml := make([]polynomial.MultiLin, nbGateIn+1) ml[0] = c.eq - copy(ml[1:], c.inputPreprocessors) + copy(ml[1:], c.input) sumSize := len(c.eq) / 2 // the range of h, over which we sum @@ -341,22 +341,23 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 if n < minBlockSize { // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) } - c.eq.Fold(element) + c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } @@ -373,22 +374,23 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { return len(c.claimedEvaluations) } +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. noMoreClaimsAllowed[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] + wI := c.input[inI] if _, found := noMoreClaimsAllowed[in]; !found { noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) } - c.manager.memPool.Dump(puI) + c.manager.memPool.Dump(wI) } c.manager.memPool.Dump(c.claimedEvaluations, c.eq) @@ -443,12 +445,12 @@ func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { } if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied } } return res diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 9b7d00b3f7..9c64ac13f3 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -184,7 +184,7 @@ type eqTimesGateEvalSumcheckClaims struct { claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } @@ -273,7 +273,7 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.inputPreprocessors) + nbGateIn := len(c.input) // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. @@ -281,7 +281,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. ml := make([]polynomial.MultiLin, nbGateIn+1) ml[0] = c.eq - copy(ml[1:], c.inputPreprocessors) + copy(ml[1:], c.input) sumSize := len(c.eq) / 2 // the range of h, over which we sum @@ -341,22 +341,23 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 if n < minBlockSize { // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) } - c.eq.Fold(element) + c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } @@ -373,22 +374,23 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { return len(c.claimedEvaluations) } +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. noMoreClaimsAllowed[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] + wI := c.input[inI] if _, found := noMoreClaimsAllowed[in]; !found { noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) } - c.manager.memPool.Dump(puI) + c.manager.memPool.Dump(wI) } c.manager.memPool.Dump(c.claimedEvaluations, c.eq) @@ -443,12 +445,12 @@ func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { } if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied } } return res diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 4662a7da5a..fe439ce9ee 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -184,7 +184,7 @@ type eqTimesGateEvalSumcheckClaims struct { claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } @@ -273,7 +273,7 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.inputPreprocessors) + nbGateIn := len(c.input) // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. @@ -281,7 +281,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. ml := make([]polynomial.MultiLin, nbGateIn+1) ml[0] = c.eq - copy(ml[1:], c.inputPreprocessors) + copy(ml[1:], c.input) sumSize := len(c.eq) / 2 // the range of h, over which we sum @@ -341,22 +341,23 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 if n < minBlockSize { // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) } - c.eq.Fold(element) + c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } @@ -373,22 +374,23 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { return len(c.claimedEvaluations) } +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. noMoreClaimsAllowed[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] + wI := c.input[inI] if _, found := noMoreClaimsAllowed[in]; !found { noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) } - c.manager.memPool.Dump(puI) + c.manager.memPool.Dump(wI) } c.manager.memPool.Dump(c.claimedEvaluations, c.eq) @@ -443,12 +445,12 @@ func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { } if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied } } return res diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 753623fba8..7149c63014 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -184,7 +184,7 @@ type eqTimesGateEvalSumcheckClaims struct { claimedEvaluations []fr.Element // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } @@ -273,7 +273,7 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.E func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.inputPreprocessors) + nbGateIn := len(c.input) // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. @@ -281,7 +281,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. ml := make([]polynomial.MultiLin, nbGateIn+1) ml[0] = c.eq - copy(ml[1:], c.inputPreprocessors) + copy(ml[1:], c.input) sumSize := len(c.eq) / 2 // the range of h, over which we sum @@ -341,22 +341,23 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ -func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 if n < minBlockSize { // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) } - c.eq.Fold(element) + c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } @@ -373,22 +374,23 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { return len(c.claimedEvaluations) } +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. noMoreClaimsAllowed[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] + wI := c.input[inI] if _, found := noMoreClaimsAllowed[in]; !found { noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) } - c.manager.memPool.Dump(puI) + c.manager.memPool.Dump(wI) } c.manager.memPool.Dump(c.claimedEvaluations, c.eq) @@ -443,12 +445,12 @@ func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { } if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied } } return res diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index c192205b62..d69e925d12 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -184,7 +184,7 @@ type eqTimesGateEvalSumcheckClaims struct { claimedEvaluations []small_rational.SmallRational // yᵢ = w(xᵢ) manager *claimsManager - inputPreprocessors []polynomial.MultiLin // the values of wᵢ (input to the gate of w) over the hypercube (across all instances) + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) } @@ -273,7 +273,7 @@ func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []smal func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) - nbGateIn := len(c.inputPreprocessors) + nbGateIn := len(c.input) // Both E and wᵢ (the input wires and the eq table) are multilinear, thus // they are linear in Xⱼ. @@ -281,7 +281,7 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. ml := make([]polynomial.MultiLin, nbGateIn+1) ml[0] = c.eq - copy(ml[1:], c.inputPreprocessors) + copy(ml[1:], c.input) sumSize := len(c.eq) / 2 // the range of h, over which we sum @@ -341,22 +341,23 @@ func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { return gJ } -// Next first folds the "preprocessing" and "eq" polynomials then compute the new gⱼ -func (c *eqTimesGateEvalSumcheckClaims) Next(element small_rational.SmallRational) polynomial.Polynomial { +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge small_rational.SmallRational) polynomial.Polynomial { const minBlockSize = 512 n := len(c.eq) / 2 if n < minBlockSize { // no parallelization - for i := 0; i < len(c.inputPreprocessors); i++ { - c.inputPreprocessors[i].Fold(element) + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) } - c.eq.Fold(element) + c.eq.Fold(challenge) } else { - wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) - for i := 0; i < len(c.inputPreprocessors); i++ { - wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) } - c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() for _, wg := range wgs { wg.Wait() } @@ -373,22 +374,23 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { return len(c.claimedEvaluations) } +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) interface{} { //defer the proof, return list of claims evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. noMoreClaimsAllowed[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { - puI := c.inputPreprocessors[inI] + wI := c.input[inI] if _, found := noMoreClaimsAllowed[in]; !found { noMoreClaimsAllowed[in] = struct{}{} - puI.Fold(r[len(r)-1]) - c.manager.add(in, r, puI[0]) - evaluations = append(evaluations, puI[0]) + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) } - c.manager.memPool.Dump(puI) + c.manager.memPool.Dump(wI) } c.manager.memPool.Dump(c.claimedEvaluations, c.eq) @@ -443,12 +445,12 @@ func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { } if wire.IsInput() { - res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} } else { - res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) for inputI, inputW := range wire.Inputs { - res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied } } return res From c703ae68158b17ff2f4f1a54cdfa2d1cdc12319f Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 8 Apr 2025 11:31:02 -0500 Subject: [PATCH 52/62] refactor: finalEvalProof as fr slice --- internal/gkr/bn254/gkr.go | 42 ++++++++------------ internal/gkr/bn254/sumcheck/sumcheck.go | 8 ++-- internal/gkr/bn254/sumcheck/sumcheck_test.go | 4 +- 3 files changed, 23 insertions(+), 31 deletions(-) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 9c64ac13f3..e55020c164 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -128,9 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, inputEvaluations []fr.Element) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -146,7 +144,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + indexesInProof := make(map[*Wire]int, len(inputEvaluations)) proofI := 0 for inI, in := range e.wire.Inputs { @@ -156,13 +154,13 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb indexesInProof[in] = indexInProof // defer verification, store new claim - e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + e.manager.add(in, r, inputEvaluations[indexInProof]) proofI++ } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? - inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + inputEvaluations[inI] = inputEvaluations[indexInProof] } - if proofI != len(inputEvaluationsNoRedundancy) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + if proofI != len(inputEvaluations) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluations), proofI) } gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) } @@ -375,8 +373,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { } // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { - +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. @@ -667,11 +664,9 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S return proof, err } - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) + for j := range proof[i].FinalEvalProof { + baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() } } // the verifier checks a single claim about input wires itself @@ -708,11 +703,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -725,11 +719,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) + for j := range baseChallenge { + baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() } } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? @@ -883,9 +876,8 @@ func (p Proof) SerializeToBigInts(outs []*big.Int) { offset += len(poly) } if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) } } } diff --git a/internal/gkr/bn254/sumcheck/sumcheck.go b/internal/gkr/bn254/sumcheck/sumcheck.go index 821399b4f3..fc57a1f3cd 100644 --- a/internal/gkr/bn254/sumcheck/sumcheck.go +++ b/internal/gkr/bn254/sumcheck/sumcheck.go @@ -23,7 +23,7 @@ type Claims interface { Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ VarsNum() int //number of variables ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. @@ -32,13 +32,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { @@ -149,7 +149,7 @@ func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settin gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { + for j := range claims.VarsNum() { if len(proof.PartialSumPolys[j]) != claims.Degree(j) { return errors.New("malformed proof") } diff --git a/internal/gkr/bn254/sumcheck/sumcheck_test.go b/internal/gkr/bn254/sumcheck/sumcheck_test.go index 8053589b35..9c3c6c5dc2 100644 --- a/internal/gkr/bn254/sumcheck/sumcheck_test.go +++ b/internal/gkr/bn254/sumcheck/sumcheck_test.go @@ -22,7 +22,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } @@ -56,7 +56,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil From 8733799dc3789edbb0f46133416fdbccad93b2ba Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 8 Apr 2025 12:28:00 -0500 Subject: [PATCH 53/62] fix: inputevaluations --- internal/gkr/bn254/gkr.go | 12 ++++++------ internal/gkr/bn254/gkr_test.go | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index e55020c164..a7cf2fefce 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -128,7 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, inputEvaluations []fr.Element) error { +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -144,7 +144,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) - indexesInProof := make(map[*Wire]int, len(inputEvaluations)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) proofI := 0 for inI, in := range e.wire.Inputs { @@ -154,13 +154,13 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb indexesInProof[in] = indexInProof // defer verification, store new claim - e.manager.add(in, r, inputEvaluations[indexInProof]) + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? - inputEvaluations[inI] = inputEvaluations[indexInProof] + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } - if proofI != len(inputEvaluations) { - return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluations), proofI) + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) } gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) } diff --git a/internal/gkr/bn254/gkr_test.go b/internal/gkr/bn254/gkr_test.go index 09fd0a9be5..095c78218d 100644 --- a/internal/gkr/bn254/gkr_test.go +++ b/internal/gkr/bn254/gkr_test.go @@ -428,11 +428,11 @@ func proofEquals(expected Proof, seen Proof) error { xSeen := seen[i] if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) } } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { return fmt.Errorf("final evaluation proof mismatch") } } From f8715b1c5e44d68f9c2e9c79d9968ce1e894ece5 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 8 Apr 2025 12:41:10 -0500 Subject: [PATCH 54/62] chore: generify changes --- .../backend/template/gkr/gkr.go.tmpl | 31 +++++++------------ .../backend/template/gkr/gkr.test.go.tmpl | 4 +-- .../backend/template/gkr/sumcheck.go.tmpl | 8 ++--- .../template/gkr/sumcheck.test.go.tmpl | 4 +-- internal/gkr/bls12-377/gkr.go | 31 +++++++------------ internal/gkr/bls12-377/gkr_test.go | 4 +-- internal/gkr/bls12-377/sumcheck/sumcheck.go | 8 ++--- .../gkr/bls12-377/sumcheck/sumcheck_test.go | 4 +-- internal/gkr/bls12-381/gkr.go | 31 +++++++------------ internal/gkr/bls12-381/gkr_test.go | 4 +-- internal/gkr/bls12-381/sumcheck/sumcheck.go | 8 ++--- .../gkr/bls12-381/sumcheck/sumcheck_test.go | 4 +-- internal/gkr/bls24-315/gkr.go | 31 +++++++------------ internal/gkr/bls24-315/gkr_test.go | 4 +-- internal/gkr/bls24-315/sumcheck/sumcheck.go | 8 ++--- .../gkr/bls24-315/sumcheck/sumcheck_test.go | 4 +-- internal/gkr/bls24-317/gkr.go | 31 +++++++------------ internal/gkr/bls24-317/gkr_test.go | 4 +-- internal/gkr/bls24-317/sumcheck/sumcheck.go | 8 ++--- .../gkr/bls24-317/sumcheck/sumcheck_test.go | 4 +-- internal/gkr/bn254/gkr.go | 3 +- internal/gkr/bw6-633/gkr.go | 31 +++++++------------ internal/gkr/bw6-633/gkr_test.go | 4 +-- internal/gkr/bw6-633/sumcheck/sumcheck.go | 8 ++--- .../gkr/bw6-633/sumcheck/sumcheck_test.go | 4 +-- internal/gkr/bw6-761/gkr.go | 31 +++++++------------ internal/gkr/bw6-761/gkr_test.go | 4 +-- internal/gkr/bw6-761/sumcheck/sumcheck.go | 8 ++--- .../gkr/bw6-761/sumcheck/sumcheck_test.go | 4 +-- internal/gkr/small_rational/gkr.go | 31 +++++++------------ .../gkr/small_rational/sumcheck/sumcheck.go | 18 +++++------ .../small_rational/sumcheck/sumcheck_test.go | 4 +-- .../sumcheck/sumcheck-gen-vectors.go | 4 +-- internal/small_rational/small-rational.go | 5 +++ 34 files changed, 172 insertions(+), 222 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 0d63e721e3..0eb794e5fa 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -123,9 +123,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]{{.ElementType}}) - +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff, purportedValue {{.ElementType}}, inputEvaluationsNoRedundancy []{{.ElementType}}) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -371,7 +369,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { } // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) interface{} { +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) []{{.ElementType}} { //defer the proof, return list of claims evaluations := make([]{{.ElementType}}, 0, len(c.wire.Inputs)) @@ -663,11 +661,9 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S return proof, err } - finalEvalProof := proof[i].FinalEvalProof.([]{{.ElementType}}) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) + for j := range proof[i].FinalEvalProof { + baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() } } // the verifier checks a single claim about input wires itself @@ -704,11 +700,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]{{.ElementType}}) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -721,11 +716,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) + for j := range baseChallenge { + baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() } } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? @@ -879,9 +873,8 @@ func (p Proof) SerializeToBigInts(outs []*big.Int) { offset += len(poly) } if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]{{.ElementType}}) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) } } } diff --git a/internal/generator/backend/template/gkr/gkr.test.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.go.tmpl index 6d7a58b0aa..09d5a4b01b 100644 --- a/internal/generator/backend/template/gkr/gkr.test.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.test.go.tmpl @@ -424,11 +424,11 @@ func proofEquals(expected Proof, seen Proof) error { xSeen := seen[i] if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) } } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { return fmt.Errorf("final evaluation proof mismatch") } } diff --git a/internal/generator/backend/template/gkr/sumcheck.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.go.tmpl index 2ca7ec4975..3e8444add6 100644 --- a/internal/generator/backend/template/gkr/sumcheck.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.go.tmpl @@ -16,7 +16,7 @@ type Claims interface { Next({{.ElementType}}) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ VarsNum() int //number of variables ClaimsNum() int //number of claims - ProveFinalEval(r []{{.ElementType}}) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + ProveFinalEval(r []{{.ElementType}}) []{{.ElementType}} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. @@ -25,13 +25,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(a {{.ElementType}}) {{.ElementType}} // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error + VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof []{{.ElementType}}) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + FinalEvalProof []{{.ElementType}} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { @@ -142,7 +142,7 @@ func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settin gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { + for j := range claims.VarsNum() { if len(proof.PartialSumPolys[j]) != claims.Degree(j) { return errors.New("malformed proof") } diff --git a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl index f85214d1cd..e7194ed89e 100644 --- a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl +++ b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl @@ -15,7 +15,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval(r []{{.ElementType}}) interface{} { +func (c singleMultilinClaim) ProveFinalEval(r []{{.ElementType}}) []{{.ElementType}} { return nil // verifier can compute the final eval itself } @@ -49,7 +49,7 @@ type singleMultilinLazyClaim struct { claimedSum {{.ElementType}} } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof []{{.ElementType}}) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 0ba1e96e80..53f660ee1e 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -128,9 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -375,7 +373,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { } // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) @@ -667,11 +665,9 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S return proof, err } - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) + for j := range proof[i].FinalEvalProof { + baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() } } // the verifier checks a single claim about input wires itself @@ -708,11 +704,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -725,11 +720,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) + for j := range baseChallenge { + baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() } } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? @@ -883,9 +877,8 @@ func (p Proof) SerializeToBigInts(outs []*big.Int) { offset += len(poly) } if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) } } } diff --git a/internal/gkr/bls12-377/gkr_test.go b/internal/gkr/bls12-377/gkr_test.go index 5b63fd1c80..209b77cc0d 100644 --- a/internal/gkr/bls12-377/gkr_test.go +++ b/internal/gkr/bls12-377/gkr_test.go @@ -428,11 +428,11 @@ func proofEquals(expected Proof, seen Proof) error { xSeen := seen[i] if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) } } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { return fmt.Errorf("final evaluation proof mismatch") } } diff --git a/internal/gkr/bls12-377/sumcheck/sumcheck.go b/internal/gkr/bls12-377/sumcheck/sumcheck.go index d7be95ccb8..3a0c516cc9 100644 --- a/internal/gkr/bls12-377/sumcheck/sumcheck.go +++ b/internal/gkr/bls12-377/sumcheck/sumcheck.go @@ -23,7 +23,7 @@ type Claims interface { Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ VarsNum() int //number of variables ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. @@ -32,13 +32,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { @@ -149,7 +149,7 @@ func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settin gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { + for j := range claims.VarsNum() { if len(proof.PartialSumPolys[j]) != claims.Degree(j) { return errors.New("malformed proof") } diff --git a/internal/gkr/bls12-377/sumcheck/sumcheck_test.go b/internal/gkr/bls12-377/sumcheck/sumcheck_test.go index a9d152c10a..24011c9552 100644 --- a/internal/gkr/bls12-377/sumcheck/sumcheck_test.go +++ b/internal/gkr/bls12-377/sumcheck/sumcheck_test.go @@ -22,7 +22,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } @@ -56,7 +56,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index b1b259f51d..f62a981e1b 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -128,9 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -375,7 +373,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { } // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) @@ -667,11 +665,9 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S return proof, err } - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) + for j := range proof[i].FinalEvalProof { + baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() } } // the verifier checks a single claim about input wires itself @@ -708,11 +704,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -725,11 +720,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) + for j := range baseChallenge { + baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() } } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? @@ -883,9 +877,8 @@ func (p Proof) SerializeToBigInts(outs []*big.Int) { offset += len(poly) } if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) } } } diff --git a/internal/gkr/bls12-381/gkr_test.go b/internal/gkr/bls12-381/gkr_test.go index 8c932961dc..dd2bebb645 100644 --- a/internal/gkr/bls12-381/gkr_test.go +++ b/internal/gkr/bls12-381/gkr_test.go @@ -428,11 +428,11 @@ func proofEquals(expected Proof, seen Proof) error { xSeen := seen[i] if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) } } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { return fmt.Errorf("final evaluation proof mismatch") } } diff --git a/internal/gkr/bls12-381/sumcheck/sumcheck.go b/internal/gkr/bls12-381/sumcheck/sumcheck.go index 6ecb1722a6..800a67938e 100644 --- a/internal/gkr/bls12-381/sumcheck/sumcheck.go +++ b/internal/gkr/bls12-381/sumcheck/sumcheck.go @@ -23,7 +23,7 @@ type Claims interface { Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ VarsNum() int //number of variables ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. @@ -32,13 +32,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { @@ -149,7 +149,7 @@ func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settin gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { + for j := range claims.VarsNum() { if len(proof.PartialSumPolys[j]) != claims.Degree(j) { return errors.New("malformed proof") } diff --git a/internal/gkr/bls12-381/sumcheck/sumcheck_test.go b/internal/gkr/bls12-381/sumcheck/sumcheck_test.go index 4d98d79437..3d8c096e8c 100644 --- a/internal/gkr/bls12-381/sumcheck/sumcheck_test.go +++ b/internal/gkr/bls12-381/sumcheck/sumcheck_test.go @@ -22,7 +22,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } @@ -56,7 +56,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index d206d3f511..a2b44dd5e9 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -128,9 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -375,7 +373,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { } // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) @@ -667,11 +665,9 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S return proof, err } - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) + for j := range proof[i].FinalEvalProof { + baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() } } // the verifier checks a single claim about input wires itself @@ -708,11 +704,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -725,11 +720,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) + for j := range baseChallenge { + baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() } } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? @@ -883,9 +877,8 @@ func (p Proof) SerializeToBigInts(outs []*big.Int) { offset += len(poly) } if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) } } } diff --git a/internal/gkr/bls24-315/gkr_test.go b/internal/gkr/bls24-315/gkr_test.go index 350d807e56..458ab5ed30 100644 --- a/internal/gkr/bls24-315/gkr_test.go +++ b/internal/gkr/bls24-315/gkr_test.go @@ -428,11 +428,11 @@ func proofEquals(expected Proof, seen Proof) error { xSeen := seen[i] if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) } } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { return fmt.Errorf("final evaluation proof mismatch") } } diff --git a/internal/gkr/bls24-315/sumcheck/sumcheck.go b/internal/gkr/bls24-315/sumcheck/sumcheck.go index 4d6fd2a15a..d2ca7f2d5c 100644 --- a/internal/gkr/bls24-315/sumcheck/sumcheck.go +++ b/internal/gkr/bls24-315/sumcheck/sumcheck.go @@ -23,7 +23,7 @@ type Claims interface { Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ VarsNum() int //number of variables ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. @@ -32,13 +32,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { @@ -149,7 +149,7 @@ func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settin gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { + for j := range claims.VarsNum() { if len(proof.PartialSumPolys[j]) != claims.Degree(j) { return errors.New("malformed proof") } diff --git a/internal/gkr/bls24-315/sumcheck/sumcheck_test.go b/internal/gkr/bls24-315/sumcheck/sumcheck_test.go index f41552f57c..d7219b08e1 100644 --- a/internal/gkr/bls24-315/sumcheck/sumcheck_test.go +++ b/internal/gkr/bls24-315/sumcheck/sumcheck_test.go @@ -22,7 +22,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } @@ -56,7 +56,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 4b19c7521f..caa628f606 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -128,9 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -375,7 +373,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { } // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) @@ -667,11 +665,9 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S return proof, err } - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) + for j := range proof[i].FinalEvalProof { + baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() } } // the verifier checks a single claim about input wires itself @@ -708,11 +704,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -725,11 +720,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) + for j := range baseChallenge { + baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() } } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? @@ -883,9 +877,8 @@ func (p Proof) SerializeToBigInts(outs []*big.Int) { offset += len(poly) } if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) } } } diff --git a/internal/gkr/bls24-317/gkr_test.go b/internal/gkr/bls24-317/gkr_test.go index e44c8ccab1..d4749815b0 100644 --- a/internal/gkr/bls24-317/gkr_test.go +++ b/internal/gkr/bls24-317/gkr_test.go @@ -428,11 +428,11 @@ func proofEquals(expected Proof, seen Proof) error { xSeen := seen[i] if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) } } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { return fmt.Errorf("final evaluation proof mismatch") } } diff --git a/internal/gkr/bls24-317/sumcheck/sumcheck.go b/internal/gkr/bls24-317/sumcheck/sumcheck.go index 90dc85ffdf..aeb31069b3 100644 --- a/internal/gkr/bls24-317/sumcheck/sumcheck.go +++ b/internal/gkr/bls24-317/sumcheck/sumcheck.go @@ -23,7 +23,7 @@ type Claims interface { Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ VarsNum() int //number of variables ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. @@ -32,13 +32,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { @@ -149,7 +149,7 @@ func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settin gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { + for j := range claims.VarsNum() { if len(proof.PartialSumPolys[j]) != claims.Degree(j) { return errors.New("malformed proof") } diff --git a/internal/gkr/bls24-317/sumcheck/sumcheck_test.go b/internal/gkr/bls24-317/sumcheck/sumcheck_test.go index 7053f04844..58c43f491e 100644 --- a/internal/gkr/bls24-317/sumcheck/sumcheck_test.go +++ b/internal/gkr/bls24-317/sumcheck/sumcheck_test.go @@ -22,7 +22,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } @@ -56,7 +56,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index a7cf2fefce..a17d85d750 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -128,7 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -374,6 +374,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { + //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index fe439ce9ee..c1c5bcfdde 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -128,9 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -375,7 +373,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { } // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) @@ -667,11 +665,9 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S return proof, err } - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) + for j := range proof[i].FinalEvalProof { + baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() } } // the verifier checks a single claim about input wires itself @@ -708,11 +704,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -725,11 +720,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) + for j := range baseChallenge { + baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() } } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? @@ -883,9 +877,8 @@ func (p Proof) SerializeToBigInts(outs []*big.Int) { offset += len(poly) } if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) } } } diff --git a/internal/gkr/bw6-633/gkr_test.go b/internal/gkr/bw6-633/gkr_test.go index 0b018df326..cead952cbd 100644 --- a/internal/gkr/bw6-633/gkr_test.go +++ b/internal/gkr/bw6-633/gkr_test.go @@ -428,11 +428,11 @@ func proofEquals(expected Proof, seen Proof) error { xSeen := seen[i] if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) } } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { return fmt.Errorf("final evaluation proof mismatch") } } diff --git a/internal/gkr/bw6-633/sumcheck/sumcheck.go b/internal/gkr/bw6-633/sumcheck/sumcheck.go index 8a8c25f3c5..71c472de96 100644 --- a/internal/gkr/bw6-633/sumcheck/sumcheck.go +++ b/internal/gkr/bw6-633/sumcheck/sumcheck.go @@ -23,7 +23,7 @@ type Claims interface { Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ VarsNum() int //number of variables ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. @@ -32,13 +32,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { @@ -149,7 +149,7 @@ func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settin gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { + for j := range claims.VarsNum() { if len(proof.PartialSumPolys[j]) != claims.Degree(j) { return errors.New("malformed proof") } diff --git a/internal/gkr/bw6-633/sumcheck/sumcheck_test.go b/internal/gkr/bw6-633/sumcheck/sumcheck_test.go index 4c740ab0ec..357e169b4f 100644 --- a/internal/gkr/bw6-633/sumcheck/sumcheck_test.go +++ b/internal/gkr/bw6-633/sumcheck/sumcheck_test.go @@ -22,7 +22,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } @@ -56,7 +56,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 7149c63014..b68aa1641e 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -128,9 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]fr.Element) - +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -375,7 +373,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { } // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) @@ -667,11 +665,9 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S return proof, err } - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) + for j := range proof[i].FinalEvalProof { + baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() } } // the verifier checks a single claim about input wires itself @@ -708,11 +704,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]fr.Element) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -725,11 +720,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) + for j := range baseChallenge { + baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() } } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? @@ -883,9 +877,8 @@ func (p Proof) SerializeToBigInts(outs []*big.Int) { offset += len(poly) } if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) } } } diff --git a/internal/gkr/bw6-761/gkr_test.go b/internal/gkr/bw6-761/gkr_test.go index 2aa45ac3c0..99ea1ff5d7 100644 --- a/internal/gkr/bw6-761/gkr_test.go +++ b/internal/gkr/bw6-761/gkr_test.go @@ -428,11 +428,11 @@ func proofEquals(expected Proof, seen Proof) error { xSeen := seen[i] if xSeen.FinalEvalProof == nil { - if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) } } else { - if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { return fmt.Errorf("final evaluation proof mismatch") } } diff --git a/internal/gkr/bw6-761/sumcheck/sumcheck.go b/internal/gkr/bw6-761/sumcheck/sumcheck.go index ce9800a258..ddcc4d0057 100644 --- a/internal/gkr/bw6-761/sumcheck/sumcheck.go +++ b/internal/gkr/bw6-761/sumcheck/sumcheck.go @@ -23,7 +23,7 @@ type Claims interface { Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ VarsNum() int //number of variables ClaimsNum() int //number of claims - ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. @@ -32,13 +32,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { @@ -149,7 +149,7 @@ func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settin gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { + for j := range claims.VarsNum() { if len(proof.PartialSumPolys[j]) != claims.Degree(j) { return errors.New("malformed proof") } diff --git a/internal/gkr/bw6-761/sumcheck/sumcheck_test.go b/internal/gkr/bw6-761/sumcheck/sumcheck_test.go index d6f520fc19..1ca6bbb57c 100644 --- a/internal/gkr/bw6-761/sumcheck/sumcheck_test.go +++ b/internal/gkr/bw6-761/sumcheck/sumcheck_test.go @@ -22,7 +22,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { return nil // verifier can compute the final eval itself } @@ -56,7 +56,7 @@ type singleMultilinLazyClaim struct { claimedSum fr.Element } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index d69e925d12..5fb74adb15 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -128,9 +128,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { // The claims are communicated through the proof parameter. // The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with // the main claim, by checking E w(wᵢ(r)...) = purportedValue. -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]small_rational.SmallRational) - +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff, purportedValue small_rational.SmallRational, inputEvaluationsNoRedundancy []small_rational.SmallRational) error { // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) @@ -375,7 +373,7 @@ func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { } // ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) -func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) interface{} { +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational { //defer the proof, return list of claims evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) @@ -667,11 +665,9 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S return proof, err } - finalEvalProof := proof[i].FinalEvalProof.([]small_rational.SmallRational) - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) + for j := range proof[i].FinalEvalProof { + baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() } } // the verifier checks a single claim about input wires itself @@ -708,11 +704,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]small_rational.SmallRational) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -725,11 +720,10 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), - ); err == nil { - baseChallenge = make([][]byte, len(finalEvalProof)) - for j := range finalEvalProof { - bytes := finalEvalProof[j].Bytes() - baseChallenge[j] = bytes[:] + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) + for j := range baseChallenge { + baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() } } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? @@ -883,9 +877,8 @@ func (p Proof) SerializeToBigInts(outs []*big.Int) { offset += len(poly) } if p[i].FinalEvalProof != nil { - finalEvalProof := p[i].FinalEvalProof.([]small_rational.SmallRational) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) } } } diff --git a/internal/gkr/small_rational/sumcheck/sumcheck.go b/internal/gkr/small_rational/sumcheck/sumcheck.go index e491815a87..b0d233b1dd 100644 --- a/internal/gkr/small_rational/sumcheck/sumcheck.go +++ b/internal/gkr/small_rational/sumcheck/sumcheck.go @@ -19,11 +19,11 @@ import ( // Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. // Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) type Claims interface { - Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. - Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ - VarsNum() int //number of variables - ClaimsNum() int //number of claims - ProveFinalEval(r []small_rational.SmallRational) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } // LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. @@ -32,13 +32,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(a small_rational.SmallRational) small_rational.SmallRational // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error + VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error } // Proof of a multi-sumcheck statement. type Proof struct { - PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` - FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []small_rational.SmallRational `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof } func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { @@ -149,7 +149,7 @@ func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settin gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) - for j := 0; j < claims.VarsNum(); j++ { + for j := range claims.VarsNum() { if len(proof.PartialSumPolys[j]) != claims.Degree(j) { return errors.New("malformed proof") } diff --git a/internal/gkr/small_rational/sumcheck/sumcheck_test.go b/internal/gkr/small_rational/sumcheck/sumcheck_test.go index c2166b7c12..43cc89393e 100644 --- a/internal/gkr/small_rational/sumcheck/sumcheck_test.go +++ b/internal/gkr/small_rational/sumcheck/sumcheck_test.go @@ -22,7 +22,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) interface{} { +func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational { return nil // verifier can compute the final eval itself } @@ -56,7 +56,7 @@ type singleMultilinLazyClaim struct { claimedSum small_rational.SmallRational } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil diff --git a/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go index a264fd57d0..14542ef228 100644 --- a/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go +++ b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go @@ -145,7 +145,7 @@ type singleMultilinClaim struct { g polynomial.MultiLin } -func (c singleMultilinClaim) ProveFinalEval([]small_rational.SmallRational) interface{} { +func (c singleMultilinClaim) ProveFinalEval([]small_rational.SmallRational) []small_rational.SmallRational { return nil // verifier can compute the final eval itself } @@ -179,7 +179,7 @@ type singleMultilinLazyClaim struct { claimedSum small_rational.SmallRational } -func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, _ small_rational.SmallRational, purportedValue small_rational.SmallRational, _ interface{}) error { +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, _ small_rational.SmallRational, purportedValue small_rational.SmallRational, _ []small_rational.SmallRational) error { val := c.g.Evaluate(r, nil) if val.Equal(&purportedValue) { return nil diff --git a/internal/small_rational/small-rational.go b/internal/small_rational/small-rational.go index 39ee2bfe2e..6dbd87f1df 100644 --- a/internal/small_rational/small-rational.go +++ b/internal/small_rational/small-rational.go @@ -402,6 +402,11 @@ func (z *SmallRational) Bytes() [Bytes]byte { return res } +func (z *SmallRational) Marshal() []byte { + res := z.Bytes() + return res[:] +} + func bytesToBigIntSigned(src []byte) big.Int { var res big.Int res.SetBytes(src[1:]) From 05716d706aa1969e194696883af780ab48fbfc43 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Tue, 8 Apr 2025 12:42:38 -0500 Subject: [PATCH 55/62] refactor: remove casting to []fr.Element --- constraint/bls12-377/gkr.go | 5 ++--- constraint/bls12-381/gkr.go | 5 ++--- constraint/bls24-315/gkr.go | 5 ++--- constraint/bls24-317/gkr.go | 5 ++--- constraint/bn254/gkr.go | 5 ++--- constraint/bw6-633/gkr.go | 5 ++--- constraint/bw6-761/gkr.go | 5 ++--- .../generator/backend/template/representations/gkr.go.tmpl | 5 ++--- 8 files changed, 16 insertions(+), 24 deletions(-) diff --git a/constraint/bls12-377/gkr.go b/constraint/bls12-377/gkr.go index 63c04dce94..c798143286 100644 --- a/constraint/bls12-377/gkr.go +++ b/constraint/bls12-377/gkr.go @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bls12-381/gkr.go b/constraint/bls12-381/gkr.go index 2a516d1805..54784922ea 100644 --- a/constraint/bls12-381/gkr.go +++ b/constraint/bls12-381/gkr.go @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bls24-315/gkr.go b/constraint/bls24-315/gkr.go index a2b7297257..0c9cfeb271 100644 --- a/constraint/bls24-315/gkr.go +++ b/constraint/bls24-315/gkr.go @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bls24-317/gkr.go b/constraint/bls24-317/gkr.go index 9c269ff686..b171d09c64 100644 --- a/constraint/bls24-317/gkr.go +++ b/constraint/bls24-317/gkr.go @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bn254/gkr.go b/constraint/bn254/gkr.go index 2af2d3f035..b8d2052e2e 100644 --- a/constraint/bn254/gkr.go +++ b/constraint/bn254/gkr.go @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bw6-633/gkr.go b/constraint/bw6-633/gkr.go index fb693f4851..056d6d12d2 100644 --- a/constraint/bw6-633/gkr.go +++ b/constraint/bw6-633/gkr.go @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bw6-761/gkr.go b/constraint/bw6-761/gkr.go index 72c07a3774..6e97543e36 100644 --- a/constraint/bw6-761/gkr.go +++ b/constraint/bw6-761/gkr.go @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/internal/generator/backend/template/representations/gkr.go.tmpl b/internal/generator/backend/template/representations/gkr.go.tmpl index 89b4b8dbc5..83d73b666f 100644 --- a/internal/generator/backend/template/representations/gkr.go.tmpl +++ b/internal/generator/backend/template/representations/gkr.go.tmpl @@ -164,9 +164,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } From 694a3de7424d65c4f07e64eb3e89741461509ae7 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 9 Apr 2025 20:48:34 -0500 Subject: [PATCH 56/62] docs remove panicky todo comment --- internal/generator/backend/template/gkr/gkr.go.tmpl | 2 +- internal/gkr/bls12-377/gkr.go | 2 +- internal/gkr/bls12-381/gkr.go | 2 +- internal/gkr/bls24-315/gkr.go | 2 +- internal/gkr/bls24-317/gkr.go | 2 +- internal/gkr/bn254/gkr.go | 2 +- internal/gkr/bw6-633/gkr.go | 2 +- internal/gkr/bw6-761/gkr.go | 2 +- internal/gkr/small_rational/gkr.go | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 0eb794e5fa..76f847c309 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -151,7 +151,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}} // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? + } inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 53f660ee1e..91aae5961b 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -156,7 +156,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? + } inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index f62a981e1b..0ed0c5bf5e 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -156,7 +156,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? + } inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index a2b44dd5e9..7dfca036a0 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -156,7 +156,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? + } inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index caa628f606..7e0638a17b 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -156,7 +156,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? + } inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index a17d85d750..1d78f73131 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -156,7 +156,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? + } inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index c1c5bcfdde..de36d88dfe 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -156,7 +156,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? + } inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index b68aa1641e..8a7e3656b6 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -156,7 +156,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? + } inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 5fb74adb15..6f20e222c2 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -156,7 +156,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.S // defer verification, store new claim e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) proofI++ - } // TODO WHERE ARE THE INPUT EVALS ADDED TO FS TRANSCRIPT? + } inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] } if proofI != len(inputEvaluationsNoRedundancy) { From 0648dbc84336c19470d05e9d8d5259610a4743fa Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 9 Apr 2025 21:13:31 -0500 Subject: [PATCH 57/62] feat: solvable vars for bn254 --- internal/gkr/bn254/gkr.go | 66 +++++++++++++++++++++++++++++++-------- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 1d78f73131..a46b113e35 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -73,6 +73,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if w.Inputs[i].Gate.SolvableVar() != -1 { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + func (c Circuit) maxGateDegree() int { res := 1 for i := range c { @@ -141,6 +172,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // the w(...) term var gateEvaluation fr.Element if e.wire.IsInput() { // just compute w(r) + if inputEvaluationsNoRedundancy != nil { + return errors.New("final evaluation proof not needed for input wire") + } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) @@ -377,13 +411,13 @@ func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Elem //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. - noMoreClaimsAllowed[c.wire] = struct{}{} + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { wI := c.input[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} + if _, found := visited[in]; !found { + visited[in] = struct{}{} wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. c.manager.add(in, r, wI[0]) evaluations = append(evaluations, wI[0]) @@ -624,6 +658,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Elem return res, nil } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + // Prove consistency of the claimed assignment func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { o, err := setup(c, assignment, transcriptSettings, options...) @@ -664,11 +711,7 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S ); err != nil { return proof, err } - - baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) - for j := range proof[i].FinalEvalProof { - baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } // the verifier checks a single claim about input wires itself claims.deleteClaim(wire) @@ -721,10 +764,7 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? } From 88850d24b0e78879a79301522bbbd990c369ac5a Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 9 Apr 2025 21:15:12 -0500 Subject: [PATCH 58/62] fix: don't count on a slice being nil --- internal/gkr/bn254/gkr.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index a46b113e35..f36a2cd045 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -172,7 +172,7 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // the w(...) term var gateEvaluation fr.Element if e.wire.IsInput() { // just compute w(r) - if inputEvaluationsNoRedundancy != nil { + if len(inputEvaluationsNoRedundancy) != 0 { return errors.New("final evaluation proof not needed for input wire") } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) From f1b2b47235e91b243b756b18c65e8bdfe1f7db69 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 9 Apr 2025 21:20:09 -0500 Subject: [PATCH 59/62] chore: generify --- .../backend/template/gkr/gkr.go.tmpl | 66 +++++++++++++++---- internal/gkr/bls12-377/gkr.go | 66 +++++++++++++++---- internal/gkr/bls12-381/gkr.go | 66 +++++++++++++++---- internal/gkr/bls24-315/gkr.go | 66 +++++++++++++++---- internal/gkr/bls24-317/gkr.go | 66 +++++++++++++++---- internal/gkr/bw6-633/gkr.go | 66 +++++++++++++++---- internal/gkr/bw6-761/gkr.go | 66 +++++++++++++++---- internal/gkr/small_rational/gkr.go | 66 +++++++++++++++---- 8 files changed, 424 insertions(+), 104 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 76f847c309..37ab9ceb9e 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -68,6 +68,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if w.Inputs[i].Gate.SolvableVar() != -1 { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + func (c Circuit) maxGateDegree() int { res := 1 for i := range c { @@ -136,6 +167,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}} // the w(...) term var gateEvaluation {{.ElementType}} if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]{{.ElementType}}, len(e.wire.Inputs)) @@ -373,13 +407,13 @@ func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) []{ //defer the proof, return list of claims evaluations := make([]{{.ElementType}}, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. - noMoreClaimsAllowed[c.wire] = struct{}{} + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { wI := c.input[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} + if _, found := visited[in]; !found { + visited[in] = struct{}{} wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. c.manager.add(in, r, wI[0]) evaluations = append(evaluations, wI[0]) @@ -620,6 +654,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]{{.Elem return res, nil } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []{{.ElementType}}) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + // Prove consistency of the claimed assignment func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { o, err := setup(c, assignment, transcriptSettings, options...) @@ -660,11 +707,7 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S ); err != nil { return proof, err } - - baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) - for j := range proof[i].FinalEvalProof { - baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } // the verifier checks a single claim about input wires itself claims.deleteClaim(wire) @@ -717,10 +760,7 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? } diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index 91aae5961b..f30f4154d2 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -73,6 +73,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if w.Inputs[i].Gate.SolvableVar() != -1 { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + func (c Circuit) maxGateDegree() int { res := 1 for i := range c { @@ -141,6 +172,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // the w(...) term var gateEvaluation fr.Element if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) @@ -377,13 +411,13 @@ func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Elem //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. - noMoreClaimsAllowed[c.wire] = struct{}{} + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { wI := c.input[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} + if _, found := visited[in]; !found { + visited[in] = struct{}{} wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. c.manager.add(in, r, wI[0]) evaluations = append(evaluations, wI[0]) @@ -624,6 +658,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Elem return res, nil } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + // Prove consistency of the claimed assignment func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { o, err := setup(c, assignment, transcriptSettings, options...) @@ -664,11 +711,7 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S ); err != nil { return proof, err } - - baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) - for j := range proof[i].FinalEvalProof { - baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } // the verifier checks a single claim about input wires itself claims.deleteClaim(wire) @@ -721,10 +764,7 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? } diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 0ed0c5bf5e..fd6472358e 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -73,6 +73,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if w.Inputs[i].Gate.SolvableVar() != -1 { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + func (c Circuit) maxGateDegree() int { res := 1 for i := range c { @@ -141,6 +172,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // the w(...) term var gateEvaluation fr.Element if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) @@ -377,13 +411,13 @@ func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Elem //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. - noMoreClaimsAllowed[c.wire] = struct{}{} + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { wI := c.input[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} + if _, found := visited[in]; !found { + visited[in] = struct{}{} wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. c.manager.add(in, r, wI[0]) evaluations = append(evaluations, wI[0]) @@ -624,6 +658,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Elem return res, nil } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + // Prove consistency of the claimed assignment func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { o, err := setup(c, assignment, transcriptSettings, options...) @@ -664,11 +711,7 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S ); err != nil { return proof, err } - - baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) - for j := range proof[i].FinalEvalProof { - baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } // the verifier checks a single claim about input wires itself claims.deleteClaim(wire) @@ -721,10 +764,7 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? } diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 7dfca036a0..6ef742b2a5 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -73,6 +73,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if w.Inputs[i].Gate.SolvableVar() != -1 { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + func (c Circuit) maxGateDegree() int { res := 1 for i := range c { @@ -141,6 +172,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // the w(...) term var gateEvaluation fr.Element if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) @@ -377,13 +411,13 @@ func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Elem //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. - noMoreClaimsAllowed[c.wire] = struct{}{} + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { wI := c.input[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} + if _, found := visited[in]; !found { + visited[in] = struct{}{} wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. c.manager.add(in, r, wI[0]) evaluations = append(evaluations, wI[0]) @@ -624,6 +658,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Elem return res, nil } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + // Prove consistency of the claimed assignment func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { o, err := setup(c, assignment, transcriptSettings, options...) @@ -664,11 +711,7 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S ); err != nil { return proof, err } - - baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) - for j := range proof[i].FinalEvalProof { - baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } // the verifier checks a single claim about input wires itself claims.deleteClaim(wire) @@ -721,10 +764,7 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? } diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 7e0638a17b..3da846a6c2 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -73,6 +73,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if w.Inputs[i].Gate.SolvableVar() != -1 { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + func (c Circuit) maxGateDegree() int { res := 1 for i := range c { @@ -141,6 +172,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // the w(...) term var gateEvaluation fr.Element if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) @@ -377,13 +411,13 @@ func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Elem //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. - noMoreClaimsAllowed[c.wire] = struct{}{} + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { wI := c.input[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} + if _, found := visited[in]; !found { + visited[in] = struct{}{} wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. c.manager.add(in, r, wI[0]) evaluations = append(evaluations, wI[0]) @@ -624,6 +658,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Elem return res, nil } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + // Prove consistency of the claimed assignment func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { o, err := setup(c, assignment, transcriptSettings, options...) @@ -664,11 +711,7 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S ); err != nil { return proof, err } - - baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) - for j := range proof[i].FinalEvalProof { - baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } // the verifier checks a single claim about input wires itself claims.deleteClaim(wire) @@ -721,10 +764,7 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? } diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index de36d88dfe..709af1fdaf 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -73,6 +73,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if w.Inputs[i].Gate.SolvableVar() != -1 { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + func (c Circuit) maxGateDegree() int { res := 1 for i := range c { @@ -141,6 +172,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // the w(...) term var gateEvaluation fr.Element if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) @@ -377,13 +411,13 @@ func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Elem //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. - noMoreClaimsAllowed[c.wire] = struct{}{} + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { wI := c.input[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} + if _, found := visited[in]; !found { + visited[in] = struct{}{} wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. c.manager.add(in, r, wI[0]) evaluations = append(evaluations, wI[0]) @@ -624,6 +658,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Elem return res, nil } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + // Prove consistency of the claimed assignment func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { o, err := setup(c, assignment, transcriptSettings, options...) @@ -664,11 +711,7 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S ); err != nil { return proof, err } - - baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) - for j := range proof[i].FinalEvalProof { - baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } // the verifier checks a single claim about input wires itself claims.deleteClaim(wire) @@ -721,10 +764,7 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? } diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index 8a7e3656b6..b1382004a3 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -73,6 +73,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if w.Inputs[i].Gate.SolvableVar() != -1 { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + func (c Circuit) maxGateDegree() int { res := 1 for i := range c { @@ -141,6 +172,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, comb // the w(...) term var gateEvaluation fr.Element if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) @@ -377,13 +411,13 @@ func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Elem //defer the proof, return list of claims evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. - noMoreClaimsAllowed[c.wire] = struct{}{} + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { wI := c.input[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} + if _, found := visited[in]; !found { + visited[in] = struct{}{} wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. c.manager.add(in, r, wI[0]) evaluations = append(evaluations, wI[0]) @@ -624,6 +658,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Elem return res, nil } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + // Prove consistency of the claimed assignment func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { o, err := setup(c, assignment, transcriptSettings, options...) @@ -664,11 +711,7 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S ); err != nil { return proof, err } - - baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) - for j := range proof[i].FinalEvalProof { - baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } // the verifier checks a single claim about input wires itself claims.deleteClaim(wire) @@ -721,10 +764,7 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 6f20e222c2..1c456531ab 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -73,6 +73,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if w.Inputs[i].Gate.SolvableVar() != -1 { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + func (c Circuit) maxGateDegree() int { res := 1 for i := range c { @@ -141,6 +172,9 @@ func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.S // the w(...) term var gateEvaluation small_rational.SmallRational if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) @@ -377,13 +411,13 @@ func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallR //defer the proof, return list of claims evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) - noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. - noMoreClaimsAllowed[c.wire] = struct{}{} + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} for inI, in := range c.wire.Inputs { wI := c.input[inI] - if _, found := noMoreClaimsAllowed[in]; !found { - noMoreClaimsAllowed[in] = struct{}{} + if _, found := visited[in]; !found { + visited[in] = struct{}{} wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. c.manager.add(in, r, wI[0]) evaluations = append(evaluations, wI[0]) @@ -624,6 +658,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_r return res, nil } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []small_rational.SmallRational) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + // Prove consistency of the claimed assignment func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { o, err := setup(c, assignment, transcriptSettings, options...) @@ -664,11 +711,7 @@ func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.S ); err != nil { return proof, err } - - baseChallenge = make([][]byte, len(proof[i].FinalEvalProof)) - for j := range proof[i].FinalEvalProof { - baseChallenge[j] = proof[i].FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } // the verifier checks a single claim about input wires itself claims.deleteClaim(wire) @@ -721,10 +764,7 @@ func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSetting } else if err = sumcheck.Verify( claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { // incorporate prover claims about w's input into the transcript - baseChallenge = make([][]byte, len(proofW.FinalEvalProof)) - for j := range baseChallenge { - baseChallenge[j] = proofW.FinalEvalProof[j].Marshal() - } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) } else { return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? } From 99775e6e59598e4b4883a798a7e41796e7ed7c0f Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 9 Apr 2025 21:59:14 -0500 Subject: [PATCH 60/62] fix: elemIndex --- internal/generator/backend/template/gkr/gkr.go.tmpl | 2 +- internal/gkr/bls12-377/gkr.go | 2 +- internal/gkr/bls12-381/gkr.go | 2 +- internal/gkr/bls24-315/gkr.go | 2 +- internal/gkr/bls24-317/gkr.go | 2 +- internal/gkr/bn254/gkr.go | 2 +- internal/gkr/bw6-633/gkr.go | 2 +- internal/gkr/bw6-761/gkr.go | 2 +- internal/gkr/small_rational/gkr.go | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 37ab9ceb9e..c4ff3f0c99 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -89,7 +89,7 @@ func (w Wire) unhashedFinalEvalProofElemIndex() int { continue } - if w.Inputs[i].Gate.SolvableVar() != -1 { + if i == w.Gate.SolvableVar() { return indexInProof } diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index f30f4154d2..af779b1dab 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -94,7 +94,7 @@ func (w Wire) unhashedFinalEvalProofElemIndex() int { continue } - if w.Inputs[i].Gate.SolvableVar() != -1 { + if i == w.Gate.SolvableVar() { return indexInProof } diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index fd6472358e..ce8281d8e3 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -94,7 +94,7 @@ func (w Wire) unhashedFinalEvalProofElemIndex() int { continue } - if w.Inputs[i].Gate.SolvableVar() != -1 { + if i == w.Gate.SolvableVar() { return indexInProof } diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index 6ef742b2a5..5a27e68471 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -94,7 +94,7 @@ func (w Wire) unhashedFinalEvalProofElemIndex() int { continue } - if w.Inputs[i].Gate.SolvableVar() != -1 { + if i == w.Gate.SolvableVar() { return indexInProof } diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index 3da846a6c2..252c1c7c97 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -94,7 +94,7 @@ func (w Wire) unhashedFinalEvalProofElemIndex() int { continue } - if w.Inputs[i].Gate.SolvableVar() != -1 { + if i == w.Gate.SolvableVar() { return indexInProof } diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index f36a2cd045..8f78c2ee9e 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -94,7 +94,7 @@ func (w Wire) unhashedFinalEvalProofElemIndex() int { continue } - if w.Inputs[i].Gate.SolvableVar() != -1 { + if i == w.Gate.SolvableVar() { return indexInProof } diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index 709af1fdaf..6028a22ed4 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -94,7 +94,7 @@ func (w Wire) unhashedFinalEvalProofElemIndex() int { continue } - if w.Inputs[i].Gate.SolvableVar() != -1 { + if i == w.Gate.SolvableVar() { return indexInProof } diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index b1382004a3..dc08567116 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -94,7 +94,7 @@ func (w Wire) unhashedFinalEvalProofElemIndex() int { continue } - if w.Inputs[i].Gate.SolvableVar() != -1 { + if i == w.Gate.SolvableVar() { return indexInProof } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index 1c456531ab..8efda8508f 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -94,7 +94,7 @@ func (w Wire) unhashedFinalEvalProofElemIndex() int { continue } - if w.Inputs[i].Gate.SolvableVar() != -1 { + if i == w.Gate.SolvableVar() { return indexInProof } From 3d3a69ccd6bdc8021b23abf923b4ecef97d85c2e Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 9 Apr 2025 22:27:36 -0500 Subject: [PATCH 61/62] docs: comments for gkr snark verifier --- std/gkr/gkr.go | 34 +++++++++++++++++++++++----------- std/sumcheck/sumcheck.go | 4 ++-- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index bd08a33e18..d859eba7ee 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -55,11 +55,11 @@ type GateFunction func(GateAPI, ...frontend.Variable) frontend.Variable type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a variable whose value can be uniquely determined from the value of the gate and the other inputs, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } @@ -114,6 +114,9 @@ type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { wire *Wire evaluationPoints [][]frontend.Variable @@ -121,10 +124,20 @@ type eqTimesGateEvalSumcheckLazyClaims struct { manager *claimsManager // WARNING: Circular references } -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(api frontend.API, r []frontend.Variable, combinationCoeff, purportedValue frontend.Variable, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]frontend.Variable) - - // the eq terms +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(api frontend.API, r []frontend.Variable, combinationCoeff, purportedValue frontend.Variable, inputEvaluationsNoRedundancy []frontend.Variable) error { + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(api, e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -370,11 +383,10 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]frontend.Variable) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -386,7 +398,7 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, } else if err = sumcheck.Verify( api, claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { - baseChallenge = finalEvalProof + baseChallenge = proofW.FinalEvalProof } else { return err } @@ -510,7 +522,7 @@ func (p Proof) Serialize() []frontend.Variable { for j := range p[i].PartialSumPolys { size += len(p[i].PartialSumPolys[j]) } - size += len(p[i].FinalEvalProof.([]frontend.Variable)) + size += len(p[i].FinalEvalProof) } res := make([]frontend.Variable, 0, size) @@ -518,7 +530,7 @@ func (p Proof) Serialize() []frontend.Variable { for j := range p[i].PartialSumPolys { res = append(res, p[i].PartialSumPolys[j]...) } - res = append(res, p[i].FinalEvalProof.([]frontend.Variable)...) + res = append(res, p[i].FinalEvalProof...) } if len(res) != size { panic("bug") // TODO: Remove diff --git a/std/sumcheck/sumcheck.go b/std/sumcheck/sumcheck.go index ad96621c96..ec3f724130 100644 --- a/std/sumcheck/sumcheck.go +++ b/std/sumcheck/sumcheck.go @@ -15,13 +15,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(api frontend.API, a frontend.Variable) frontend.Variable // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(api frontend.API, r []frontend.Variable, combinationCoeff, purportedValue frontend.Variable, proof interface{}) error + VerifyFinalEval(api frontend.API, r []frontend.Variable, combinationCoeff, purportedValue frontend.Variable, proof []frontend.Variable) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial - FinalEvalProof interface{} + FinalEvalProof []frontend.Variable } func setupTranscript(api frontend.API, claimsNum int, varsNum int, settings *fiatshamir.Settings) ([]string, error) { From cf46fc3ca980a1303c98fbcfd4b097be4fa58448 Mon Sep 17 00:00:00 2001 From: Tabaie Date: Wed, 9 Apr 2025 22:32:37 -0500 Subject: [PATCH 62/62] fix: getBaseChallenge in snark verifier --- std/gkr/gkr.go | 46 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index d859eba7ee..fa6c6020c6 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -109,6 +109,37 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + // WireAssignment is assignment of values to the same wire across many instances of the circuit type WireAssignment map[*Wire]polynomial.MultiLin @@ -357,6 +388,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) (challenge return } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []frontend.Variable) []frontend.Variable { + baseChallenge := make([]frontend.Variable, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j]) + } + } + return baseChallenge +} + // Verify the consistency of the claimed output with the claimed input // Unlike in Prove, the assignment argument need not be complete func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { @@ -398,7 +442,7 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, } else if err = sumcheck.Verify( api, claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { - baseChallenge = proofW.FinalEvalProof + baseChallenge = getBaseChallenge(wire, proofW.FinalEvalProof) } else { return err }