diff --git a/constraint/bls12-377/gkr.go b/constraint/bls12-377/gkr.go index 744f22525c..c798143286 100644 --- a/constraint/bls12-377/gkr.go +++ b/constraint/bls12-377/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bls12-377" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bls12-381/gkr.go b/constraint/bls12-381/gkr.go index b3b22b9a95..54784922ea 100644 --- a/constraint/bls12-381/gkr.go +++ b/constraint/bls12-381/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bls12-381" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bls24-315/gkr.go b/constraint/bls24-315/gkr.go index ba328c8bb4..0c9cfeb271 100644 --- a/constraint/bls24-315/gkr.go +++ b/constraint/bls24-315/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bls24-315" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bls24-317/gkr.go b/constraint/bls24-317/gkr.go index be02e3455c..b171d09c64 100644 --- a/constraint/bls24-317/gkr.go +++ b/constraint/bls24-317/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bls24-317" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bn254/gkr.go b/constraint/bn254/gkr.go index 21731b8ac9..b8d2052e2e 100644 --- a/constraint/bn254/gkr.go +++ b/constraint/bn254/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bn254" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bw6-633/gkr.go b/constraint/bw6-633/gkr.go index 125da817df..056d6d12d2 100644 --- a/constraint/bw6-633/gkr.go +++ b/constraint/bw6-633/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bw6-633" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/constraint/bw6-761/gkr.go b/constraint/bw6-761/gkr.go index f40856cc36..6e97543e36 100644 --- a/constraint/bw6-761/gkr.go +++ b/constraint/bw6-761/gkr.go @@ -8,12 +8,12 @@ package cs import ( "fmt" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" "github.com/consensys/gnark-crypto/utils" "github.com/consensys/gnark/constraint" hint "github.com/consensys/gnark/constraint/solver" + gkr "github.com/consensys/gnark/internal/gkr/bw6-761" algo_utils "github.com/consensys/gnark/internal/utils" "hash" "math/big" @@ -171,9 +171,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/go.mod b/go.mod index efe1278a23..8ab4e0b6b9 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/blang/semver/v4 v4.0.0 github.com/consensys/bavard v0.1.31-0.20250314194434-b30d4344e6d4 github.com/consensys/compress v0.2.5 - github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1 + github.com/consensys/gnark-crypto v0.17.1-0.20250331132656-820ac1d108bd github.com/fxamacker/cbor/v2 v2.7.0 github.com/google/go-cmp v0.6.0 github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 diff --git a/go.sum b/go.sum index d460981671..1eacc9d554 100644 --- a/go.sum +++ b/go.sum @@ -63,6 +63,10 @@ github.com/consensys/compress v0.2.5 h1:gJr1hKzbOD36JFsF1AN8lfXz1yevnJi1YolffY19 github.com/consensys/compress v0.2.5/go.mod h1:pyM+ZXiNUh7/0+AUjUf9RKUM6vSH7T/fsn5LLS0j1Tk= github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1 h1:6cK71BoMAjWHNl+EpvBh2PDDa0PIeoz1KFJ/6R16DjQ= github.com/consensys/gnark-crypto v0.17.1-0.20250326164229-5fd6610ac2a1/go.mod h1:uV1HwfBwGRj50DGK3LbDLeCvq0RX/vFXST3CRSAu0Fs= +github.com/consensys/gnark-crypto v0.17.1-0.20250327163404-2fc9f58298e2 h1:vWEj2nIXK3dG2oyNlqLXVOFRGi1E3BQn+YVkwqf7GnM= +github.com/consensys/gnark-crypto v0.17.1-0.20250327163404-2fc9f58298e2/go.mod h1:uV1HwfBwGRj50DGK3LbDLeCvq0RX/vFXST3CRSAu0Fs= +github.com/consensys/gnark-crypto v0.17.1-0.20250331132656-820ac1d108bd h1:og4X8KhBpFv37u0PuXLz7yOLz9vwOcpTX4pKfbbwtgM= +github.com/consensys/gnark-crypto v0.17.1-0.20250331132656-820ac1d108bd/go.mod h1:uV1HwfBwGRj50DGK3LbDLeCvq0RX/vFXST3CRSAu0Fs= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= diff --git a/internal/generator/backend/main.go b/internal/generator/backend/main.go index 027ed4c4e3..6d805d22c5 100644 --- a/internal/generator/backend/main.go +++ b/internal/generator/backend/main.go @@ -1,6 +1,8 @@ package main import ( + "fmt" + "github.com/consensys/gnark-crypto/field/generator/config" "os" "os/exec" "path/filepath" @@ -9,7 +11,6 @@ import ( "github.com/consensys/bavard" "github.com/consensys/gnark-crypto/field/generator" - "github.com/consensys/gnark-crypto/field/generator/config" ) const copyrightHolder = "Consensys Software Inc." @@ -79,7 +80,7 @@ func main() { panic(err) } - datas := []templateData{ + data := []templateData{ bls12_377, bls12_381, bn254, @@ -91,10 +92,9 @@ func main() { } const importCurve = "../imports.go.tmpl" - var wg sync.WaitGroup - for _, d := range datas { + for _, d := range data { wg.Add(1) @@ -129,10 +129,24 @@ func main() { // gkr backend if d.Curve != "tinyfield" { + // solver and proof delegator TODO merge with "backend" below entries = []bavard.Entry{{File: filepath.Join(csDir, "gkr.go"), Templates: []string{"gkr.go.tmpl", importCurve}}} - if err := bgen.Generate(d, "cs", "./template/representations/", entries...); err != nil { - panic(err) + err := bgen.Generate(d, "cs", "./template/representations/", entries...) + assertNoError(err) + + curvePackageName := strings.ToLower(d.Curve) + + cfg := gkrConfig{ + FieldDependency: config.FieldDependency{ + ElementType: "fr.Element", + FieldPackageName: "fr", + FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + curvePackageName + "/fr", + }, + GkrPackageRelativePath: "internal/gkr/" + curvePackageName, + CanUseFFT: true, } + + assertNoError(generateGkrBackend(cfg)) } entries = []bavard.Entry{ @@ -203,6 +217,41 @@ func main() { } + wg.Add(1) + // GKR test vectors + go func() { + // generate gkr and sumcheck for small-rational + cfg := gkrConfig{ + FieldDependency: config.FieldDependency{ + ElementType: "small_rational.SmallRational", + FieldPackagePath: "github.com/consensys/gnark/internal/small_rational", + FieldPackageName: "small_rational", + }, + GkrPackageRelativePath: "internal/gkr/small_rational", + CanUseFFT: false, + NoGkrTests: true, + } + assertNoError(generateGkrBackend(cfg)) + + // generate gkr test vector generator + cfg.GenerateTestVectors = true + cfg.OutsideGkrPackage = true + + assertNoError(bgen.Generate(cfg, "gkr", "./template/gkr/", + bavard.Entry{ + File: "../../gkr/test_vectors/gkr/gkr-gen-vectors.go", + Templates: []string{"gkr.test.vectors.gen.go.tmpl", "gkr.test.vectors.go.tmpl"}, + }, + )) + + fmt.Println("generating test vectors for gkr and sumcheck") + cmd := exec.Command("go", "run", "../../gkr/test_vectors") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + assertNoError(cmd.Run()) + wg.Done() + }() + wg.Wait() // run go fmt on whole directory @@ -223,3 +272,63 @@ type templateData struct { noBackend bool NoGKR bool } + +func generateGkrBackend(cfg gkrConfig) error { + const repoRoot = "../../../" + packageOutPath := filepath.Join(repoRoot, cfg.GkrPackageRelativePath) + + // test vector utils + packageDir := filepath.Join(packageOutPath, "test_vector_utils") + entries := []bavard.Entry{ + {File: filepath.Join(packageDir, "test_vector_utils.go"), Templates: []string{"test_vector_utils.go.tmpl"}}, + } + + if err := bgen.Generate(cfg, "test_vector_utils", "./template/gkr/", entries...); err != nil { + return err + } + + // sumcheck backend + packageDir = filepath.Join(packageOutPath, "sumcheck") + entries = []bavard.Entry{ + {File: filepath.Join(packageDir, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, + {File: filepath.Join(packageDir, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl"}}, + } + + if err := bgen.Generate(cfg, "sumcheck", "./template/gkr/", entries...); err != nil { + return err + } + + // gkr backend + packageDir = packageOutPath + entries = []bavard.Entry{ + {File: filepath.Join(packageDir, "gkr.go"), Templates: []string{"gkr.go.tmpl"}}, + {File: filepath.Join(packageDir, "registry.go"), Templates: []string{"registry.go.tmpl"}}, + } + + if !cfg.NoGkrTests { + entries = append(entries, bavard.Entry{ + File: filepath.Join(packageDir, "gkr_test.go"), Templates: []string{"gkr.test.go.tmpl", "gkr.test.vectors.go.tmpl"}, + }) + } + + if err := bgen.Generate(cfg, "gkr", "./template/gkr/", entries...); err != nil { + return err + } + + return nil +} + +type gkrConfig struct { + config.FieldDependency + GkrPackageRelativePath string // the GKR package, relative to the repo root + CanUseFFT bool + OutsideGkrPackage bool + GenerateTestVectors bool + NoGkrTests bool +} + +func assertNoError(err error) { + if err != nil { + panic(err) + } +} diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl new file mode 100644 index 0000000000..c4ff3f0c99 --- /dev/null +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -0,0 +1,926 @@ +import ( + "errors" + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "math/big" + "strconv" + "sync" +) + +{{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...{{.ElementType}}) {{.ElementType}} + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of the polynomial + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]{{.ElementType}} // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []{{.ElementType}} // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +// CombinedSum returns ∑ᵢ aⁱ yᵢ +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a {{.ElementType}}) {{.ElementType}} { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff, purportedValue {{.ElementType}}, inputEvaluationsNoRedundancy []{{.ElementType}}) error { + // the eq terms ( E ) + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the w(...) term + var gateEvaluation {{.ElementType}} + if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire + inputEvaluations := make([]{{.ElementType}}, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]{{.ElementType}} // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []{{.ElementType}} // yᵢ = w(xᵢ) + manager *claimsManager + + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) + + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) +} + +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff {{.ElementType}}) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables ( E ) + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + // E := eq(x₀, -) + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq,c.evaluationPoints[k]) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.ElementType}}) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() +} + + +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) + nbGateIn := len(c.input) + + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.input) + + sumSize := len(c.eq) / 2; // the range of h, over which we sum + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + + gJ := make([]{{.ElementType}}, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { // compute method to allow parallelization across instances + var step {{.ElementType}} + + res := make([]{{.ElementType}}, degGJ) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]{{.ElementType}}, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h])// step = ml[k](1) - ml[k](0) + for d := 1; d < degGJ; d++ { + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) + } + } + + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + } + } + mu.Lock() + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + } + mu.Unlock() + } + + const minBlockSize = 64 + + if sumSize < minBlockSize { + // no parallelization + computeAll(0, sumSize) + } else { + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return gJ +} + +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge {{.ElementType}}) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) + } + c.eq.Fold(challenge) + } else { + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) []{{.ElementType}} { + + //defer the proof, return list of claims + evaluations := make([]{{.ElementType}}, 0, len(c.wire.Inputs)) + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + wI := c.input[inI] + if _, found := visited[in]; !found { + visited[in] = struct{}{} + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) + } + c.manager.memPool.Dump(wI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]{{.ElementType}}, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []{{.ElementType}}, evaluation {{.ElementType}}) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func (options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]{{.ElementType}}, error) { + res := make([]{{.ElementType}}, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []{{.ElementType}}) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []{{.ElementType}} + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []{{.ElementType}}{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []{{.ElementType}} + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// {{$topologicalSort}} sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func {{$topologicalSort}}(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := {{$topologicalSort}}(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]{{.ElementType}}, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]{{.ElementType}}, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []{{.ElementType}}) { + for i := range src { + src[i].BigInt(dst[i]) + } +} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/gkr.test.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.go.tmpl new file mode 100644 index 0000000000..09d5a4b01b --- /dev/null +++ b/internal/generator/backend/template/gkr/gkr.test.go.tmpl @@ -0,0 +1,607 @@ + +import ( + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/mimc" + "{{.FieldPackagePath}}/polynomial" + "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/sumcheck" + "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/stretchr/testify/assert" + "fmt" + "hash" + "os" + "strconv" + "testing" + "path/filepath" + "encoding/json" + "reflect" + "time" +) + +{{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []{{.ElementType}}{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []{{.ElementType}}{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []{{.ElementType}}{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]{{.ElementType}}) { + return func(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{ Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + } } + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []{{.ElementType}}{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []{{.ElementType}}{three}, five) + manager.add(wire, []{{.ElementType}}{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six {{.ElementType}} + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]{{.ElementType}})) { + fullAssignments := make([][]{{.ElementType}}, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]{{.ElementType}}, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t,err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []{{.ElementType}}) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + const testDirPath = "../test_vectors/gkr" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := {{$topologicalSort}}(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := {{$topologicalSort}}(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := {{$topologicalSort}}(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +{{template "gkrTestVectors" .}} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl new file mode 100644 index 0000000000..a747c5ed04 --- /dev/null +++ b/internal/generator/backend/template/gkr/gkr.test.vectors.gen.go.tmpl @@ -0,0 +1,114 @@ +import ( + "encoding/json" + "fmt" + "github.com/consensys/bavard" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/small_rational" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" + "hash" + "os" + "path/filepath" + "reflect" +) + +func GenerateVectors() error { + testDirPath, err := filepath.Abs("../../gkr/test_vectors/gkr") + if err != nil { + return err + } + + fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) + + dirEntries, err := os.ReadDir(testDirPath) + if err != nil { + return err + } + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + if !bavard.ShouldGenerate(path) { + continue + } + fmt.Println("\tprocessing", dirEntry.Name()) + if err = run(path); err != nil { + return err + } + } + } + } + + return nil +} + +func run(absPath string) error { + testCase, err := newTestCase(absPath) + if err != nil { + return err + } + + transcriptSetting := fiatshamir.WithHash(testCase.Hash) + + var proof gkr.Proof + proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) + if err != nil { + return err + } + + if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { + return err + } + var outBytes []byte + if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { + if err = os.WriteFile(absPath, outBytes, 0); err != nil { + return err + } + } else { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) + if err != nil { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + if err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { + res := make(PrintableProof, len(proof)) + + for i := range proof { + + partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) + for k, partialK := range proof[i].PartialSumPolys { + partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) + } + + res[i] = PrintableSumcheckProof{ + FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), + PartialSumPolys: partialSumPolys, + } + } + return res, nil +} + +{{template "gkrTestVectors" .}} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl b/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl new file mode 100644 index 0000000000..b1f8bbdf9b --- /dev/null +++ b/internal/generator/backend/template/gkr/gkr.test.vectors.go.tmpl @@ -0,0 +1,254 @@ +{{define "gkrTestVectors"}} + +{{$GkrPackagePrefix := select .OutsideGkrPackage "" "gkr."}} +{{$CheckOutputCorrectness := not .OutsideGkrPackage}} + +{{$Circuit := print $GkrPackagePrefix "Circuit"}} +{{$Gate := print $GkrPackagePrefix "Gate"}} +{{$Proof := print $GkrPackagePrefix "Proof"}} +{{$WireAssignment := print $GkrPackagePrefix "WireAssignment"}} +{{$Wire := print $GkrPackagePrefix "Wire"}} +{{$CircuitLayer := print $GkrPackagePrefix "CircuitLayer"}} + +{{$PackagePrefix := ""}} +{{- if .OutsideGkrPackage}} + {{$PackagePrefix = "gkr."}} +{{end}} + +type WireInfo struct { + Gate {{$PackagePrefix}}GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]{{$Circuit}}) + +func getCircuit(path string) ({{$Circuit}}, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit {{$Circuit}}) { + circuit = make({{$Circuit}}, len(c)) + for i := range c { + circuit[i].Gate = {{$PackagePrefix}}GetGate(c[i].Gate) + circuit[i].Inputs = make([]*{{$Wire}}, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...{{.ElementType}}) (res {{.ElementType}}) { + var sum {{.ElementType}} + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC {{$PackagePrefix}}GateName = "mimc" + SelectInput3 {{$PackagePrefix}}GateName = "select-input-3" +) + +func init() { + if err := {{$PackagePrefix}}RegisterGate(MiMC, mimcRound, 2, {{$PackagePrefix}}WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := {{$PackagePrefix}}RegisterGate(SelectInput3, func(input ...{{.ElementType}}) {{.ElementType}} { + return input[2] + }, 3, {{$PackagePrefix}}WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) ({{$Proof}}, error) { + proof := make({{$Proof}}, len(printable)) + for i := range printable { + finalEvalProof := []{{.ElementType}}(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]{{.ElementType}}, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := {{ setElement "finalEvalProof[k]" "finalEvalSlice.Index(k).Interface()" .ElementType}}; err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit {{$Circuit}} + Hash hash.Hash + Proof {{$Proof}} + FullAssignment {{$WireAssignment}} + InOutAssignment {{$WireAssignment}} + {{if .GenerateTestVectors}}Info TestCaseInfo // we are generating the test vectors, so we need to keep the circuit instance info to ADD the proof to it and resave it{{end}} +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit {{$Circuit}} + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof {{$Proof}} + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make({{$WireAssignment}}) + inOutAssignment := make({{$WireAssignment}}) + + sorted := {{select .OutsideGkrPackage "t" "gkr.T"}}opologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []{{.ElementType}} + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + {{if not $CheckOutputCorrectness}} + info.Output = make([][]interface{}, 0, outI) + {{end}} + + for _, w := range sorted { + if w.IsOutput() { + {{if $CheckOutputCorrectness}} + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + {{else}} + info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) + {{end}} + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + {{if .GenerateTestVectors }}Info: info,{{end}} + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +{{end}} + +{{- define "setElement element value elementType"}} +{{- if eq .elementType "fr.Element"}} test_vector_utils.SetElement(&{{.element}}, {{.value}}) +{{- else if eq .elementType "small_rational.SmallRational"}} {{.element}}.SetInterface({{.value}}) +{{- else}} +{{print "\"UNEXPECTED TYPE" .elementType "\""}} +{{- end}} +{{- end}} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/registry.go.tmpl b/internal/generator/backend/template/gkr/registry.go.tmpl new file mode 100644 index 0000000000..75ca8d0267 --- /dev/null +++ b/internal/generator/backend/template/gkr/registry.go.tmpl @@ -0,0 +1,390 @@ +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "{{.FieldPackagePath}}" + {{- if .CanUseFFT }} + "{{.FieldPackagePath}}/fft"{{- else}} + "errors"{{- end }} + "{{.FieldPackagePath}}/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make({{.FieldPackageName}}.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]{{.ElementType}}, nbIn) + consts := make({{.FieldPackageName}}.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + {{- if .CanUseFFT }} + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := {{.FieldPackageName}}.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + {{- else }} + x := make({{.FieldPackageName}}.Vector, degreeBound) + x.MustSetRandom() + for i := range x { + fIn[0] = x[i] + for j := range consts { + fIn[j+1].Mul(&x[i], &consts[j]) + } + p[i] = f(fIn...) + } + + // obtain p's coefficients + p, err := interpolate(x, p) + if err != nil { + panic(err) + } + {{- end }} + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max)+1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +{{- if not .CanUseFFT }} +// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) +// Note that the runtime is O(len(X)³) +func interpolate(X, Y []{{.ElementType}}) (polynomial.Polynomial, error) { + if len(X) != len(Y) { + return nil, errors.New("X and Y must have the same length") + } + + // solve the system of equations by Gaussian elimination + augmentedRows := make([][]{{.ElementType}}, len(X)) // the last column is the Y values + for i := range augmentedRows { + augmentedRows[i] = make([]{{.ElementType}}, len(X)+1) + augmentedRows[i][0].SetOne() + augmentedRows[i][1].Set(&X[i]) + for j := 2; j < len(augmentedRows[i])-1; j++ { + augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) + } + augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) + } + + // make the upper triangle + for i := range len(augmentedRows) - 1 { + // use row i to eliminate the ith element in all rows below + var negInv {{.ElementType}} + if augmentedRows[i][i].IsZero() { + return nil, errors.New("singular matrix") + } + negInv.Inverse(&augmentedRows[i][i]) + negInv.Neg(&negInv) + for j := i + 1; j < len(augmentedRows); j++ { + var c {{.ElementType}} + c.Mul(&augmentedRows[j][i], &negInv) + // augmentedRows[j][i].SetZero() omitted + for k := i + 1; k < len(augmentedRows[i]); k++ { + var t {{.ElementType}} + t.Mul(&augmentedRows[i][k], &c) + augmentedRows[j][k].Add(&augmentedRows[j][k], &t) + } + } + } + + // back substitution + res := make(polynomial.Polynomial, len(X)) + for i := len(augmentedRows) - 1; i >= 0; i-- { + res[i] = augmentedRows[i][len(augmentedRows[i])-1] + for j := i + 1; j < len(augmentedRows[i])-1; j++ { + var t {{.ElementType}} + t.Mul(&res[j], &augmentedRows[i][j]) + res[i].Sub(&res[i], &t) + } + res[i].Div(&res[i], &augmentedRows[i][i]) + } + + return res, nil +} +{{- end }} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...{{.ElementType}}) {{.ElementType}} { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...{{.ElementType}}) {{.ElementType}} { + var res {{.ElementType}} + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} \ No newline at end of file diff --git a/internal/generator/backend/template/gkr/sumcheck.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.go.tmpl new file mode 100644 index 0000000000..3e8444add6 --- /dev/null +++ b/internal/generator/backend/template/gkr/sumcheck.go.tmpl @@ -0,0 +1,163 @@ +import ( + "errors" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a {{.ElementType}}) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next({{.ElementType}}) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []{{.ElementType}}) []{{.ElementType}} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a {{.ElementType}}) {{.ElementType}} // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof []{{.ElementType}}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []{{.ElementType}} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []{{.ElementType}}, remainingChallengeNames *[]string) ({{.ElementType}}, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return {{.ElementType}}{}, err + } + } + var res {{.ElementType}} + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff {{.ElementType}} + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]{{.ElementType}}, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff {{.ElementType}} + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]{{.ElementType}}, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := range claims.VarsNum() { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl new file mode 100644 index 0000000000..e7194ed89e --- /dev/null +++ b/internal/generator/backend/template/gkr/sumcheck.test.go.tmpl @@ -0,0 +1,142 @@ +import ( + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/{{.GkrPackageRelativePath}}/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []{{.ElementType}}) []{{.ElementType}} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []{{.ElementType}}{sum} +} + +func (c singleMultilinClaim) Combine({{.ElementType}}) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r {{.ElementType}}) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum {{.ElementType}} +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof []{{.ElementType}}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs {{.ElementType}}) {{.ElementType}} { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/generator/backend/template/gkr/test_vector_utils.go.tmpl b/internal/generator/backend/template/gkr/test_vector_utils.go.tmpl new file mode 100644 index 0000000000..5b7495eec3 --- /dev/null +++ b/internal/generator/backend/template/gkr/test_vector_utils.go.tmpl @@ -0,0 +1,220 @@ +import ( + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + "hash" + "reflect" + {{if eq .ElementType "fr.Element"}}"strings"{{- end}} +) + +func ToElement(i int64) *{{.ElementType}} { + var res {{.ElementType}} + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter {startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/{{.FieldPackageName}}.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/{{.FieldPackageName}}.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res {{.ElementType}} + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return {{.FieldPackageName}}.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return {{.FieldPackageName}}.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []{{.ElementType}} + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return {{.FieldPackageName}}.Bytes +} + +func (h *ListHash) BlockSize() int { +return {{.FieldPackageName}}.Bytes +} + +{{- if eq .ElementType "fr.Element"}} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} +{{- end}} + +{{- define "setElement element value elementType"}} +{{- if eq .elementType "fr.Element"}} SetElement(&{{.element}}, {{.value}}) +{{- else if eq .elementType "small_rational.SmallRational"}} {{.element}}.SetInterface({{.value}}) +{{- else}} + {{print "\"UNEXPECTED TYPE" .elementType "\""}} +{{- end}} +{{- end}} + +func SliceToElementSlice[T any](slice []T) ([]{{.ElementType}}, error) { + elementSlice := make([]{{.ElementType}}, len(slice)) + for i, v := range slice { + if _, err := {{setElement "elementSlice[i]" "v" .ElementType}}; err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []{{.ElementType}}, b []{{.ElementType}}) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]{{.ElementType}}, b [][]{{.ElementType}}) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i],b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i],b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *{{.ElementType}}) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().({{.ElementType}}) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/generator/backend/template/imports.go.tmpl b/internal/generator/backend/template/imports.go.tmpl index c1cac6c90a..ce8ee23954 100644 --- a/internal/generator/backend/template/imports.go.tmpl +++ b/internal/generator/backend/template/imports.go.tmpl @@ -68,7 +68,7 @@ {{- end}} {{- define "import_gkr"}} - "github.com/consensys/gnark-crypto/ecc/{{ toLower .Curve }}/fr/gkr" + gkr "github.com/consensys/gnark/internal/gkr/{{ toLower .Curve }}" {{- end}} {{- define "import_hash_to_field" }} diff --git a/internal/generator/backend/template/representations/gkr.go.tmpl b/internal/generator/backend/template/representations/gkr.go.tmpl index 89b4b8dbc5..83d73b666f 100644 --- a/internal/generator/backend/template/representations/gkr.go.tmpl +++ b/internal/generator/backend/template/representations/gkr.go.tmpl @@ -164,9 +164,8 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { offset += len(poly) } if proof[i].FinalEvalProof != nil { - finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) - frToBigInts(outs[offset:], finalEvalProof) - offset += len(finalEvalProof) + frToBigInts(outs[offset:], proof[i].FinalEvalProof) + offset += len(proof[i].FinalEvalProof) } } diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go new file mode 100644 index 0000000000..af779b1dab --- /dev/null +++ b/internal/gkr/bls12-377/gkr.go @@ -0,0 +1,930 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls12-377/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of the polynomial + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +// CombinedSum returns ∑ᵢ aⁱ yᵢ +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { + // the eq terms ( E ) + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the w(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) + manager *claimsManager + + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) + + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) +} + +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables ( E ) + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + // E := eq(x₀, -) + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() +} + +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) + nbGateIn := len(c.input) + + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.input) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { // compute method to allow parallelization across instances + var step fr.Element + + res := make([]fr.Element, degGJ) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) + for d := 1; d < degGJ; d++ { + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) + } + } + + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + } + } + mu.Lock() + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + } + mu.Unlock() + } + + const minBlockSize = 64 + + if sumSize < minBlockSize { + // no parallelization + computeAll(0, sumSize) + } else { + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return gJ +} + +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) + } + c.eq.Fold(challenge) + } else { + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + wI := c.input[inI] + if _, found := visited[in]; !found { + visited[in] = struct{}{} + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) + } + c.manager.memPool.Dump(wI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bls12-377/gkr_test.go b/internal/gkr/bls12-377/gkr_test.go new file mode 100644 index 0000000000..209b77cc0d --- /dev/null +++ b/internal/gkr/bls12-377/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls12-377/sumcheck" + "github.com/consensys/gnark/internal/gkr/bls12-377/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + const testDirPath = "../test_vectors/gkr" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bls12-377/registry.go b/internal/gkr/bls12-377/registry.go new file mode 100644 index 0000000000..251a0a5dfd --- /dev/null +++ b/internal/gkr/bls12-377/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bls12-377/sumcheck/sumcheck.go b/internal/gkr/bls12-377/sumcheck/sumcheck.go new file mode 100644 index 0000000000..3a0c516cc9 --- /dev/null +++ b/internal/gkr/bls12-377/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := range claims.VarsNum() { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bls12-377/sumcheck/sumcheck_test.go b/internal/gkr/bls12-377/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..24011c9552 --- /dev/null +++ b/internal/gkr/bls12-377/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bls12-377/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bls12-377/test_vector_utils/test_vector_utils.go b/internal/gkr/bls12-377/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..da958ce237 --- /dev/null +++ b/internal/gkr/bls12-377/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go new file mode 100644 index 0000000000..ce8281d8e3 --- /dev/null +++ b/internal/gkr/bls12-381/gkr.go @@ -0,0 +1,930 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls12-381/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of the polynomial + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +// CombinedSum returns ∑ᵢ aⁱ yᵢ +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { + // the eq terms ( E ) + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the w(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) + manager *claimsManager + + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) + + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) +} + +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables ( E ) + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + // E := eq(x₀, -) + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() +} + +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) + nbGateIn := len(c.input) + + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.input) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { // compute method to allow parallelization across instances + var step fr.Element + + res := make([]fr.Element, degGJ) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) + for d := 1; d < degGJ; d++ { + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) + } + } + + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + } + } + mu.Lock() + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + } + mu.Unlock() + } + + const minBlockSize = 64 + + if sumSize < minBlockSize { + // no parallelization + computeAll(0, sumSize) + } else { + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return gJ +} + +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) + } + c.eq.Fold(challenge) + } else { + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + wI := c.input[inI] + if _, found := visited[in]; !found { + visited[in] = struct{}{} + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) + } + c.manager.memPool.Dump(wI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bls12-381/gkr_test.go b/internal/gkr/bls12-381/gkr_test.go new file mode 100644 index 0000000000..dd2bebb645 --- /dev/null +++ b/internal/gkr/bls12-381/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls12-381/sumcheck" + "github.com/consensys/gnark/internal/gkr/bls12-381/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + const testDirPath = "../test_vectors/gkr" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bls12-381/registry.go b/internal/gkr/bls12-381/registry.go new file mode 100644 index 0000000000..484b1939f0 --- /dev/null +++ b/internal/gkr/bls12-381/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bls12-381/sumcheck/sumcheck.go b/internal/gkr/bls12-381/sumcheck/sumcheck.go new file mode 100644 index 0000000000..800a67938e --- /dev/null +++ b/internal/gkr/bls12-381/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := range claims.VarsNum() { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bls12-381/sumcheck/sumcheck_test.go b/internal/gkr/bls12-381/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..3d8c096e8c --- /dev/null +++ b/internal/gkr/bls12-381/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bls12-381/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bls12-381/test_vector_utils/test_vector_utils.go b/internal/gkr/bls12-381/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..b1a74e1bae --- /dev/null +++ b/internal/gkr/bls12-381/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go new file mode 100644 index 0000000000..5a27e68471 --- /dev/null +++ b/internal/gkr/bls24-315/gkr.go @@ -0,0 +1,930 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls24-315/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of the polynomial + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +// CombinedSum returns ∑ᵢ aⁱ yᵢ +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { + // the eq terms ( E ) + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the w(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) + manager *claimsManager + + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) + + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) +} + +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables ( E ) + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + // E := eq(x₀, -) + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() +} + +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) + nbGateIn := len(c.input) + + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.input) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { // compute method to allow parallelization across instances + var step fr.Element + + res := make([]fr.Element, degGJ) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) + for d := 1; d < degGJ; d++ { + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) + } + } + + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + } + } + mu.Lock() + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + } + mu.Unlock() + } + + const minBlockSize = 64 + + if sumSize < minBlockSize { + // no parallelization + computeAll(0, sumSize) + } else { + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return gJ +} + +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) + } + c.eq.Fold(challenge) + } else { + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + wI := c.input[inI] + if _, found := visited[in]; !found { + visited[in] = struct{}{} + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) + } + c.manager.memPool.Dump(wI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bls24-315/gkr_test.go b/internal/gkr/bls24-315/gkr_test.go new file mode 100644 index 0000000000..458ab5ed30 --- /dev/null +++ b/internal/gkr/bls24-315/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls24-315/sumcheck" + "github.com/consensys/gnark/internal/gkr/bls24-315/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + const testDirPath = "../test_vectors/gkr" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bls24-315/registry.go b/internal/gkr/bls24-315/registry.go new file mode 100644 index 0000000000..f5cae19de7 --- /dev/null +++ b/internal/gkr/bls24-315/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bls24-315/sumcheck/sumcheck.go b/internal/gkr/bls24-315/sumcheck/sumcheck.go new file mode 100644 index 0000000000..d2ca7f2d5c --- /dev/null +++ b/internal/gkr/bls24-315/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := range claims.VarsNum() { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bls24-315/sumcheck/sumcheck_test.go b/internal/gkr/bls24-315/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..d7219b08e1 --- /dev/null +++ b/internal/gkr/bls24-315/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bls24-315/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bls24-315/test_vector_utils/test_vector_utils.go b/internal/gkr/bls24-315/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..59836542d5 --- /dev/null +++ b/internal/gkr/bls24-315/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go new file mode 100644 index 0000000000..252c1c7c97 --- /dev/null +++ b/internal/gkr/bls24-317/gkr.go @@ -0,0 +1,930 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls24-317/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of the polynomial + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +// CombinedSum returns ∑ᵢ aⁱ yᵢ +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { + // the eq terms ( E ) + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the w(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) + manager *claimsManager + + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) + + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) +} + +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables ( E ) + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + // E := eq(x₀, -) + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() +} + +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) + nbGateIn := len(c.input) + + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.input) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { // compute method to allow parallelization across instances + var step fr.Element + + res := make([]fr.Element, degGJ) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) + for d := 1; d < degGJ; d++ { + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) + } + } + + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + } + } + mu.Lock() + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + } + mu.Unlock() + } + + const minBlockSize = 64 + + if sumSize < minBlockSize { + // no parallelization + computeAll(0, sumSize) + } else { + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return gJ +} + +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) + } + c.eq.Fold(challenge) + } else { + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + wI := c.input[inI] + if _, found := visited[in]; !found { + visited[in] = struct{}{} + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) + } + c.manager.memPool.Dump(wI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bls24-317/gkr_test.go b/internal/gkr/bls24-317/gkr_test.go new file mode 100644 index 0000000000..d4749815b0 --- /dev/null +++ b/internal/gkr/bls24-317/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bls24-317/sumcheck" + "github.com/consensys/gnark/internal/gkr/bls24-317/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + const testDirPath = "../test_vectors/gkr" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bls24-317/registry.go b/internal/gkr/bls24-317/registry.go new file mode 100644 index 0000000000..d0c68beea6 --- /dev/null +++ b/internal/gkr/bls24-317/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bls24-317/sumcheck/sumcheck.go b/internal/gkr/bls24-317/sumcheck/sumcheck.go new file mode 100644 index 0000000000..aeb31069b3 --- /dev/null +++ b/internal/gkr/bls24-317/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := range claims.VarsNum() { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bls24-317/sumcheck/sumcheck_test.go b/internal/gkr/bls24-317/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..58c43f491e --- /dev/null +++ b/internal/gkr/bls24-317/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bls24-317/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bls24-317/test_vector_utils/test_vector_utils.go b/internal/gkr/bls24-317/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..eef6e7dea9 --- /dev/null +++ b/internal/gkr/bls24-317/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go new file mode 100644 index 0000000000..8f78c2ee9e --- /dev/null +++ b/internal/gkr/bn254/gkr.go @@ -0,0 +1,930 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bn254/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of the polynomial + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +// CombinedSum returns ∑ᵢ aⁱ yᵢ +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { + // the eq terms ( E ) + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the w(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) + manager *claimsManager + + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) + + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) +} + +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables ( E ) + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + // E := eq(x₀, -) + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() +} + +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) + nbGateIn := len(c.input) + + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.input) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { // compute method to allow parallelization across instances + var step fr.Element + + res := make([]fr.Element, degGJ) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) + for d := 1; d < degGJ; d++ { + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) + } + } + + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + } + } + mu.Lock() + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + } + mu.Unlock() + } + + const minBlockSize = 64 + + if sumSize < minBlockSize { + // no parallelization + computeAll(0, sumSize) + } else { + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return gJ +} + +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) + } + c.eq.Fold(challenge) + } else { + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + wI := c.input[inI] + if _, found := visited[in]; !found { + visited[in] = struct{}{} + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) + } + c.manager.memPool.Dump(wI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bn254/gkr_test.go b/internal/gkr/bn254/gkr_test.go new file mode 100644 index 0000000000..095c78218d --- /dev/null +++ b/internal/gkr/bn254/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bn254/sumcheck" + "github.com/consensys/gnark/internal/gkr/bn254/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + const testDirPath = "../test_vectors/gkr" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bn254/registry.go b/internal/gkr/bn254/registry.go new file mode 100644 index 0000000000..be935ba7b5 --- /dev/null +++ b/internal/gkr/bn254/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bn254/sumcheck/sumcheck.go b/internal/gkr/bn254/sumcheck/sumcheck.go new file mode 100644 index 0000000000..fc57a1f3cd --- /dev/null +++ b/internal/gkr/bn254/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := range claims.VarsNum() { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bn254/sumcheck/sumcheck_test.go b/internal/gkr/bn254/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..9c3c6c5dc2 --- /dev/null +++ b/internal/gkr/bn254/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bn254/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bn254/test_vector_utils/test_vector_utils.go b/internal/gkr/bn254/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..93679ae858 --- /dev/null +++ b/internal/gkr/bn254/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go new file mode 100644 index 0000000000..6028a22ed4 --- /dev/null +++ b/internal/gkr/bw6-633/gkr.go @@ -0,0 +1,930 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bw6-633/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of the polynomial + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +// CombinedSum returns ∑ᵢ aⁱ yᵢ +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { + // the eq terms ( E ) + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the w(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) + manager *claimsManager + + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) + + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) +} + +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables ( E ) + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + // E := eq(x₀, -) + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() +} + +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) + nbGateIn := len(c.input) + + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.input) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { // compute method to allow parallelization across instances + var step fr.Element + + res := make([]fr.Element, degGJ) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) + for d := 1; d < degGJ; d++ { + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) + } + } + + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + } + } + mu.Lock() + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + } + mu.Unlock() + } + + const minBlockSize = 64 + + if sumSize < minBlockSize { + // no parallelization + computeAll(0, sumSize) + } else { + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return gJ +} + +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) + } + c.eq.Fold(challenge) + } else { + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + wI := c.input[inI] + if _, found := visited[in]; !found { + visited[in] = struct{}{} + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) + } + c.manager.memPool.Dump(wI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bw6-633/gkr_test.go b/internal/gkr/bw6-633/gkr_test.go new file mode 100644 index 0000000000..cead952cbd --- /dev/null +++ b/internal/gkr/bw6-633/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bw6-633/sumcheck" + "github.com/consensys/gnark/internal/gkr/bw6-633/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + const testDirPath = "../test_vectors/gkr" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bw6-633/registry.go b/internal/gkr/bw6-633/registry.go new file mode 100644 index 0000000000..d8bda624f1 --- /dev/null +++ b/internal/gkr/bw6-633/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bw6-633/sumcheck/sumcheck.go b/internal/gkr/bw6-633/sumcheck/sumcheck.go new file mode 100644 index 0000000000..71c472de96 --- /dev/null +++ b/internal/gkr/bw6-633/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := range claims.VarsNum() { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bw6-633/sumcheck/sumcheck_test.go b/internal/gkr/bw6-633/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..357e169b4f --- /dev/null +++ b/internal/gkr/bw6-633/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bw6-633/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bw6-633/test_vector_utils/test_vector_utils.go b/internal/gkr/bw6-633/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..ea84f4f255 --- /dev/null +++ b/internal/gkr/bw6-633/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go new file mode 100644 index 0000000000..dc08567116 --- /dev/null +++ b/internal/gkr/bw6-761/gkr.go @@ -0,0 +1,930 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bw6-761/sumcheck" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...fr.Element) fr.Element + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of the polynomial + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +// CombinedSum returns ∑ᵢ aⁱ yᵢ +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff, purportedValue fr.Element, inputEvaluationsNoRedundancy []fr.Element) error { + // the eq terms ( E ) + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the w(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]fr.Element // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []fr.Element // yᵢ = w(xᵢ) + manager *claimsManager + + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) + + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) +} + +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables ( E ) + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + // E := eq(x₀, -) + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() +} + +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) + nbGateIn := len(c.input) + + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.input) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + + gJ := make([]fr.Element, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { // compute method to allow parallelization across instances + var step fr.Element + + res := make([]fr.Element, degGJ) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]fr.Element, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) + for d := 1; d < degGJ; d++ { + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) + } + } + + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + } + } + mu.Lock() + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + } + mu.Unlock() + } + + const minBlockSize = 64 + + if sumSize < minBlockSize { + // no parallelization + computeAll(0, sumSize) + } else { + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return gJ +} + +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge fr.Element) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) + } + c.eq.Fold(challenge) + } else { + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) []fr.Element { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + wI := c.input[inI] + if _, found := visited[in]; !found { + visited[in] = struct{}{} + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) + } + c.manager.memPool.Dump(wI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []fr.Element) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]fr.Element, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]fr.Element, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []fr.Element) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/bw6-761/gkr_test.go b/internal/gkr/bw6-761/gkr_test.go new file mode 100644 index 0000000000..99ea1ff5d7 --- /dev/null +++ b/internal/gkr/bw6-761/gkr_test.go @@ -0,0 +1,829 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/bw6-761/sumcheck" + "github.com/consensys/gnark/internal/gkr/bw6-761/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" + "time" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleAddGateTwoInstances(t *testing.T) { + testSingleAddGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleAddGate(t *testing.T) { + testManyInstances(t, 2, testSingleAddGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + var o settings + pool := polynomial.NewPool(256, 1<<11) + workers := utils.NewWorkerPool() + o.pool = &pool + o.workers = workers + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, o) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandomSlice(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleAddGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Add2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: GetGate(Identity), + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate("mimc"), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/Consensys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: GetGate(Mul2), + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandomSlice(slice []fr.Element) { + for i := range slice { + slice[i].MustSetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(testCase.Hash)) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + const testDirPath = "../test_vectors/gkr" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof; len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof, xSeen.FinalEvalProof); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func benchmarkGkrMiMC(b *testing.B, nbInstances, mimcDepth int) { + fmt.Println("creating circuit structure") + c := mimcCircuit(mimcDepth) + + in0 := make([]fr.Element, nbInstances) + in1 := make([]fr.Element, nbInstances) + setRandomSlice(in0) + setRandomSlice(in1) + + fmt.Println("evaluating circuit") + start := time.Now().UnixMicro() + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + solved := time.Now().UnixMicro() - start + fmt.Println("solved in", solved, "μs") + + //b.ResetTimer() + fmt.Println("constructing proof") + start = time.Now().UnixMicro() + _, err := Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) + proved := time.Now().UnixMicro() - start + fmt.Println("proved in", proved, "μs") + assert.NoError(b, err) +} + +func BenchmarkGkrMimc19(b *testing.B) { + benchmarkGkrMiMC(b, 1<<19, 91) +} + +func BenchmarkGkrMimc17(b *testing.B) { + benchmarkGkrMiMC(b, 1<<17, 91) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = GetGate(c[i].Gate) + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC GateName = "mimc" + SelectInput3 GateName = "select-input-3" +) + +func init() { + if err := RegisterGate(MiMC, mimcRound, 2, WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := RegisterGate(SelectInput3, func(input ...fr.Element) fr.Element { + return input[2] + }, 3, WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash hash.Hash + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func TestRegisterGateDegreeDetection(t *testing.T) { + testGate := func(name GateName, f func(...fr.Element) fr.Element, nbIn, degree int) { + t.Run(string(name), func(t *testing.T) { + name = name + "-register-gate-test" + + assert.NoError(t, RegisterGate(name, f, nbIn, WithDegree(degree)), "given degree must be accepted") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree-1)), "lower degree must be rejected") + + assert.Error(t, RegisterGate(name, f, nbIn, WithDegree(degree+1)), "higher degree must be rejected") + + assert.NoError(t, RegisterGate(name, f, nbIn), "no degree must be accepted") + + assert.Equal(t, degree, GetGate(name).Degree(), "degree must be detected correctly") + }) + } + + testGate("select", func(x ...fr.Element) fr.Element { + return x[0] + }, 3, 1) + + testGate("add2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + res.Add(&res, &x[2]) + return res + }, 3, 1) + + testGate("mul2", func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, 2) + + testGate("mimc", mimcRound, 2, 7) + + testGate("sub2PlusOne", func(x ...fr.Element) fr.Element { + var res fr.Element + res. + SetOne(). + Add(&res, &x[0]). + Sub(&res, &x[1]) + return res + }, 2, 1) + + // zero polynomial must not be accepted + t.Run("zero", func(t *testing.T) { + const gateName GateName = "zero-register-gate-test" + expectedError := fmt.Errorf("for gate %s: %v", gateName, errZeroFunction) + zeroGate := func(x ...fr.Element) fr.Element { + var res fr.Element + return res + } + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1)) + + assert.Equal(t, expectedError, RegisterGate(gateName, zeroGate, 1, WithDegree(2))) + }) +} + +func TestIsAdditive(t *testing.T) { + + // f: x,y -> x² + xy + f := func(x ...fr.Element) fr.Element { + if len(x) != 2 { + panic("bivariate input needed") + } + var res fr.Element + res.Add(&x[0], &x[1]) + res.Mul(&res, &x[0]) + return res + } + + // g: x,y -> x² + 3y + g := func(x ...fr.Element) fr.Element { + var res, y3 fr.Element + res.Square(&x[0]) + y3.Mul(&x[1], &three) + res.Add(&res, &y3) + return res + } + + // h: x -> 2x + // but it edits it input + h := func(x ...fr.Element) fr.Element { + x[0].Double(&x[0]) + return x[0] + } + + assert.False(t, GateFunction(f).isAdditive(1, 2)) + assert.False(t, GateFunction(f).isAdditive(0, 2)) + + assert.False(t, GateFunction(g).isAdditive(0, 2)) + assert.True(t, GateFunction(g).isAdditive(1, 2)) + + assert.True(t, GateFunction(h).isAdditive(0, 1)) +} diff --git a/internal/gkr/bw6-761/registry.go b/internal/gkr/bw6-761/registry.go new file mode 100644 index 0000000000..ed7bb6819a --- /dev/null +++ b/internal/gkr/bw6-761/registry.go @@ -0,0 +1,320 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(fr.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]fr.Element, nbIn) + consts := make(fr.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + domain := fft.NewDomain(degreeBound) + // evaluate p on the unit circle (first filling p with evaluations rather than coefficients) + x := fr.One() + for i := range p { + fIn[0] = x + for j := range consts { + fIn[j+1].Mul(&x, &consts[j]) + } + p[i] = f(fIn...) + + x.Mul(&x, &domain.Generator) + } + + // obtain p's coefficients + domain.FFTInverse(p, fft.DIF) + fft.BitReverse(p) + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...fr.Element) fr.Element { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...fr.Element) fr.Element { + var res fr.Element + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/bw6-761/sumcheck/sumcheck.go b/internal/gkr/bw6-761/sumcheck/sumcheck.go new file mode 100644 index 0000000000..ddcc4d0057 --- /dev/null +++ b/internal/gkr/bw6-761/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) []fr.Element //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []fr.Element `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := range claims.VarsNum() { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/bw6-761/sumcheck/sumcheck_test.go b/internal/gkr/bw6-761/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..1ca6bbb57c --- /dev/null +++ b/internal/gkr/bw6-761/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/bw6-761/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) []fr.Element { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof []fr.Element) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/bw6-761/test_vector_utils/test_vector_utils.go b/internal/gkr/bw6-761/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..c3f063ff58 --- /dev/null +++ b/internal/gkr/bw6-761/test_vector_utils/test_vector_utils.go @@ -0,0 +1,216 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + "hash" + "reflect" + "strings" +) + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go new file mode 100644 index 0000000000..8efda8508f --- /dev/null +++ b/internal/gkr/small_rational/gkr.go @@ -0,0 +1,930 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "math/big" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// GateFunction a polynomial defining a gate. It may modify its input. The changes will be ignored. +type GateFunction func(...small_rational.SmallRational) small_rational.SmallRational + +// A Gate is a low-degree multivariate polynomial +type Gate struct { + Evaluate GateFunction // Evaluate the polynomial function defining the gate + nbIn int // number of inputs + degree int // total degree of the polynomial + solvableVar int // if there is a solvable variable, its index, -1 otherwise +} + +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 +func (g *Gate) Degree() int { + return g.degree +} + +// SolvableVar returns I such that x_I can always be determined from {xᵢ} - x_I and f(x...). If there is no such variable, it returns -1. +func (g *Gate) SolvableVar() int { + return g.solvableVar +} + +// NbIn returns the number of inputs to the gate (its fan-in) +func (g *Gate) NbIn() int { + return g.nbIn +} + +type Wire struct { + Gate *Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + +func (c Circuit) maxGateDegree() int { + res := 1 + for i := range c { + if !c[i].IsInput() { + res = max(res, c[i].Gate.Degree()) + } + } + return res +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]small_rational.SmallRational // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []small_rational.SmallRational // yᵢ = w(xᵢ), allegedly + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +// CombinedSum returns ∑ᵢ aⁱ yᵢ +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a small_rational.SmallRational) small_rational.SmallRational { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff, purportedValue small_rational.SmallRational, inputEvaluationsNoRedundancy []small_rational.SmallRational) error { + // the eq terms ( E ) + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the w(...) term + var gateEvaluation small_rational.SmallRational + if e.wire.IsInput() { // just compute w(r) + if len(inputEvaluationsNoRedundancy) != 0 { + return errors.New("final evaluation proof not needed for input wire") + } + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { // proof contains the evaluations of the inputs, but avoids repetition in case multiple inputs come from the same wire + inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return errors.New("incompatible evaluations") +} + +// eqTimesGateEvalSumcheckClaims is a claim for sumcheck (prover side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the proving of multiple evaluations of the same wire. +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire // the wire for which we are making the claim, with value w + evaluationPoints [][]small_rational.SmallRational // xᵢ: the points at which the prover has made claims about the evaluation of w + claimedEvaluations []small_rational.SmallRational // yᵢ = w(xᵢ) + manager *claimsManager + + input []polynomial.MultiLin // input[i](h₁, ..., hₘ₋ⱼ) = wᵢ(r₁, r₂, ..., rⱼ₋₁, h₁, ..., hₘ₋ⱼ) + + eq polynomial.MultiLin // E := ∑ᵢ cⁱ eq(xᵢ, -) +} + +// Combine the multiple claims into one claim using a random combination (combinationCoeff or c). +// From the original multiple claims of w(xᵢ) = yᵢ, we get a single claim +// ∑ᵢ,ₕ cⁱ eq(xᵢ, h) w(h) = ∑ᵢ cⁱ yᵢ, where h iterates over the hypercube (circuit instances) and +// i iterates over the claims. +// Equivalently, we could say ∑ᵢ cⁱ yᵢ = ∑ₕ,ᵢ cⁱ eq(xᵢ, h) w(h) = ∑ₕ w(h) ∑ᵢ cⁱ eq(xᵢ, h). +// Thus if we initially compute E := ∑ᵢ cⁱ eq(xᵢ, -), our claim will find the simpler form +// ∑ᵢ cⁱ yᵢ = ∑ₕ w(h) E(h), where the sum-checked polynomial is of degree deg(g) + 1, +// and deg(g) is the total degree of the polynomial defining the gate g of which w is the output. +// The output of Combine is the first sumcheck claim, i.e. ∑₍ₕ₁,ₕ₂,...₎ w(X, h₁, h₂, ...) E(X, h₁, h₂, ...).. +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational.SmallRational) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables ( E ) + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + // E := eq(x₀, -) + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + // E += cⁱ eq(xᵢ, -) + for k := 1; k < claimsNum; k++ { + newEq[0].Set(&aI) + + c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + return c.computeGJ() +} + +// eqAcc sets m to an eq table at q and then adds it to e. +// m <- eq(q, -). +// e <- e + m +func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []small_rational.SmallRational) { + n := len(q) + + //At the end of each iteration, m(h₁, ..., hₙ) = eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) + for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ + // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ + const threshold = 1 << 6 + k := 1 << i + if k < threshold { + for j := 0; j < k; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + } else { + c.manager.workers.Submit(k, func(start, end int) { + for j := start; j < end; j++ { + j0 := j << (n - i) // bᵢ₊₁ = 0 + j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 + + m[j1].Mul(&q[i], &m[j0]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 1) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ + m[j0].Sub(&m[j0], &m[j1]) // eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) eq(qᵢ₊₁, 0) = eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) + } + }, 1024).Wait() + } + + } + c.manager.workers.Submit(len(e), func(start, end int) { + for i := start; i < end; i++ { + e[i].Add(&e[i], &m[i]) + } + }, 512).Wait() +} + +// computeGJ: gⱼ = ∑_{0≤h<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, h...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., Xⱼ, h...) g( w₀(r₁, ..., Xⱼ, h...), ... ). +// the polynomial is represented by the evaluations gⱼ(1), gⱼ(2), ..., gⱼ(deg(gⱼ)). +// The value gⱼ(0) is inferred from the equation gⱼ(0) + gⱼ(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(gⱼ) + nbGateIn := len(c.input) + + // Both E and wᵢ (the input wires and the eq table) are multilinear, thus + // they are linear in Xⱼ. + // So for f ∈ { E(r₁, ..., Xⱼ, h...) } ∪ {wᵢ(r₁, ..., Xⱼ, h...) }, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables. + // ml are such multilinear polynomials the evaluations of which over different values of Xⱼ are computed in this stepwise manner. + ml := make([]polynomial.MultiLin, nbGateIn+1) + ml[0] = c.eq + copy(ml[1:], c.input) + + sumSize := len(c.eq) / 2 // the range of h, over which we sum + + // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called + + gJ := make([]small_rational.SmallRational, degGJ) + var mu sync.Mutex + computeAll := func(start, end int) { // compute method to allow parallelization across instances + var step small_rational.SmallRational + + res := make([]small_rational.SmallRational, degGJ) + // evaluations of ml, laid out as: + // ml[0](1, h...), ml[1](1, h...), ..., ml[len(ml)-1](1, h...), + // ml[0](2, h...), ml[1](2, h...), ..., ml[len(ml)-1](2, h...), + // ... + // ml[0](degGJ, h...), ml[2](degGJ, h...), ..., ml[len(ml)-1](degGJ, h...) + // Thus the contribution of the + mlEvals := make([]small_rational.SmallRational, degGJ*len(ml)) + + for h := start; h < end; h++ { // h counts across instances + + evalAt1Index := sumSize + h + for k := range ml { + // d = 0 + mlEvals[k].Set(&ml[k][evalAt1Index]) // evaluation at Xⱼ = 1. Can be taken directly from the table. + step.Sub(&mlEvals[k], &ml[k][h]) // step = ml[k](1) - ml[k](0) + for d := 1; d < degGJ; d++ { + mlEvals[d*len(ml)+k].Add(&mlEvals[(d-1)*len(ml)+k], &step) + } + } + + eIndex := 0 + nextEIndex := len(ml) + for d := range degGJ { + summand := c.wire.Gate.Evaluate(mlEvals[eIndex+1 : nextEIndex]...) + summand.Mul(&summand, &mlEvals[eIndex]) + res[d].Add(&res[d], &summand) // collect contributions into the sum from start to end + eIndex, nextEIndex = nextEIndex, nextEIndex+len(ml) + } + } + mu.Lock() + for i := range gJ { + gJ[i].Add(&gJ[i], &res[i]) // collect into the complete sum + } + mu.Unlock() + } + + const minBlockSize = 64 + + if sumSize < minBlockSize { + // no parallelization + computeAll(0, sumSize) + } else { + c.manager.workers.Submit(sumSize, computeAll, minBlockSize).Wait() + } + + return gJ +} + +// Next first folds the input and E polynomials at the given verifier challenge then computes the new gⱼ. +// Thus, j <- j+1 and rⱼ = challenge. +func (c *eqTimesGateEvalSumcheckClaims) Next(challenge small_rational.SmallRational) polynomial.Polynomial { + const minBlockSize = 512 + n := len(c.eq) / 2 + if n < minBlockSize { + // no parallelization + for i := 0; i < len(c.input); i++ { + c.input[i].Fold(challenge) + } + c.eq.Fold(challenge) + } else { + wgs := make([]*sync.WaitGroup, len(c.input)) + for i := 0; i < len(c.input); i++ { + wgs[i] = c.manager.workers.Submit(n, c.input[i].FoldParallel(challenge), minBlockSize) + } + c.manager.workers.Submit(n, c.eq.FoldParallel(challenge), minBlockSize).Wait() + for _, wg := range wgs { + wg.Wait() + } + } + + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +// ProveFinalEval provides the values wᵢ(r₁, ..., rₙ) +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational { + + //defer the proof, return list of claims + evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) + visited := make(map[*Wire]struct{}, len(c.input)) // we don't double report wires, in case a gate takes the same wire as multiple input variables. + visited[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + wI := c.input[inI] + if _, found := visited[in]; !found { + visited[in] = struct{}{} + wI.Fold(r[len(r)-1]) // We already have wᵢ(r₁, ..., rₙ₋₁, hₙ) in a table. Only one more fold required. + c.manager.add(in, r, wI[0]) + evaluations = append(evaluations, wI[0]) + } + c.manager.memPool.Dump(wI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool + workers *utils.WorkerPool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = o.pool + claims.workers = o.workers + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]small_rational.SmallRational, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []small_rational.SmallRational, evaluation small_rational.SmallRational) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.input = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.input = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.input[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int + workers *utils.WorkerPool +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func WithWorkers(workers *utils.WorkerPool) Option { + return func(options *settings) { + options.workers = workers + } +} + +// MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement +func (c Circuit) MemoryRequirements(nbInstances int) []int { + res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} + + if res[0] > res[1] { // make sure it's sorted + res[0], res[1] = res[1], res[0] + if res[1] > res[2] { + res[1], res[2] = res[2], res[1] + } + } + + return res +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_rational.SmallRational, error) { + res := make([]small_rational.SmallRational, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []small_rational.SmallRational) [][]byte { + baseChallenge := make([][]byte, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j].Marshal()) + } + } + return baseChallenge +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []small_rational.SmallRational{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + defer o.workers.Stop() + + claims := newClaimsManager(c, assignment, o) + + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return errors.New("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return errors.New("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { // incorporate prover claims about w's input into the transcript + baseChallenge = getBaseChallenge(wire, proof[i].FinalEvalProof) + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + idGate := GetGate("identity") + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = idGate + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// TopologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func TopologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := TopologicalSort(c) + nbInstances := a.NumInstances() + maxNbIns := 0 + + for _, w := range sortedWires { + maxNbIns = max(maxNbIns, len(w.Inputs)) + if a[w] == nil { + a[w] = make([]small_rational.SmallRational, nbInstances) + } + } + + // TODO: Parallelize, if needed + ins := make([]small_rational.SmallRational, maxNbIns) + for i := range nbInstances { + for _, w := range sortedWires { + if !w.IsInput() { + for inI, in := range w.Inputs { + ins[inI] = a[in][i] + } + a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) + } + } + } + + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} + +// SerializeToBigInts flattens a proof object into the given slice of big.Ints +// useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this +func (p Proof) SerializeToBigInts(outs []*big.Int) { + offset := 0 + for i := range p { + for _, poly := range p[i].PartialSumPolys { + frToBigInts(outs[offset:], poly) + offset += len(poly) + } + if p[i].FinalEvalProof != nil { + frToBigInts(outs[offset:], p[i].FinalEvalProof) + offset += len(p[i].FinalEvalProof) + } + } +} + +func frToBigInts(dst []*big.Int, src []small_rational.SmallRational) { + for i := range src { + src[i].BigInt(dst[i]) + } +} diff --git a/internal/gkr/small_rational/registry.go b/internal/gkr/small_rational/registry.go new file mode 100644 index 0000000000..b48f179c20 --- /dev/null +++ b/internal/gkr/small_rational/registry.go @@ -0,0 +1,374 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "errors" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "slices" + "sync" +) + +type GateName string + +var ( + gates = make(map[GateName]*Gate) + gatesLock sync.Mutex +) + +type registerGateSettings struct { + solvableVar int + noSolvableVarVerification bool + noDegreeVerification bool + degree int +} + +type RegisterGateOption func(*registerGateSettings) + +// WithSolvableVar gives the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will return an error if it cannot verify that this claim is correct. +func WithSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = solvableVar + } +} + +// WithUnverifiedSolvableVar sets the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not verify that the given index is correct. +func WithUnverifiedSolvableVar(solvableVar int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noSolvableVarVerification = true + settings.solvableVar = solvableVar + } +} + +// WithNoSolvableVar sets the gate as having no variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// RegisterGate will not check the correctness of this claim. +func WithNoSolvableVar() RegisterGateOption { + return func(settings *registerGateSettings) { + settings.solvableVar = -1 + settings.noSolvableVarVerification = true + } +} + +// WithUnverifiedDegree sets the degree of the gate. RegisterGate will not verify that the given degree is correct. +func WithUnverifiedDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.noDegreeVerification = true + settings.degree = degree + } +} + +// WithDegree sets the degree of the gate. RegisterGate will return an error if the degree is not correct. +func WithDegree(degree int) RegisterGateOption { + return func(settings *registerGateSettings) { + settings.degree = degree + } +} + +// isAdditive returns whether x_i occurs only in a monomial of total degree 1 in f +func (f GateFunction) isAdditive(i, nbIn int) bool { + // fix all variables except the i-th one at random points + // pick random value x1 for the i-th variable + // check if f(-, 0, -) + f(-, 2*x1, -) = 2*f(-, x1, -) + x := make(small_rational.Vector, nbIn) + x.MustSetRandom() + x0 := x[i] + x[i].SetZero() + in := slices.Clone(x) + y0 := f(in...) + + x[i] = x0 + copy(in, x) + y1 := f(in...) + + x[i].Double(&x[i]) + copy(in, x) + y2 := f(in...) + + y2.Sub(&y2, &y1) + y1.Sub(&y1, &y0) + + if !y2.Equal(&y1) { + return false // not linear + } + + // check if the coefficient of x_i is nonzero and independent of the other variables (so that we know it is ALWAYS nonzero) + if y1.IsZero() { // f(-, x1, -) = f(-, 0, -), so the coefficient of x_i is 0 + return false + } + + // compute the slope with another assignment for the other variables + x.MustSetRandom() + x[i].SetZero() + copy(in, x) + y0 = f(in...) + + x[i] = x0 + copy(in, x) + y1 = f(in...) + + y1.Sub(&y1, &y0) + + return y1.Equal(&y2) +} + +// fitPoly tries to fit a polynomial of degree less than degreeBound to f. +// degreeBound must be a power of 2. +// It returns the polynomial if successful, nil otherwise +func (f GateFunction) fitPoly(nbIn int, degreeBound uint64) polynomial.Polynomial { + // turn f univariate by defining p(x) as f(x, rx, ..., sx) + // where r, s, ... are random constants + fIn := make([]small_rational.SmallRational, nbIn) + consts := make(small_rational.Vector, nbIn-1) + consts.MustSetRandom() + + p := make(polynomial.Polynomial, degreeBound) + x := make(small_rational.Vector, degreeBound) + x.MustSetRandom() + for i := range x { + fIn[0] = x[i] + for j := range consts { + fIn[j+1].Mul(&x[i], &consts[j]) + } + p[i] = f(fIn...) + } + + // obtain p's coefficients + p, err := interpolate(x, p) + if err != nil { + panic(err) + } + + // check if p is equal to f. This not being the case means that f is of a degree higher than degreeBound + fIn[0].MustSetRandom() + for i := range consts { + fIn[i+1].Mul(&fIn[0], &consts[i]) + } + pAt := p.Eval(&fIn[0]) + fAt := f(fIn...) + if !pAt.Equal(&fAt) { + return nil + } + + // trim p + lastNonZero := len(p) - 1 + for lastNonZero >= 0 && p[lastNonZero].IsZero() { + lastNonZero-- + } + return p[:lastNonZero+1] +} + +type errorString string + +func (e errorString) Error() string { + return string(e) +} + +const errZeroFunction = errorString("detected a zero function") + +// FindDegree returns the degree of the gate function, or -1 if it fails. +// Failure could be due to the degree being higher than max or the function not being a polynomial at all. +func (f GateFunction) FindDegree(max, nbIn int) (int, error) { + bound := uint64(max) + 1 + for degreeBound := uint64(4); degreeBound <= bound; degreeBound *= 8 { + if p := f.fitPoly(nbIn, degreeBound); p != nil { + if len(p) == 0 { + return -1, errZeroFunction + } + return len(p) - 1, nil + } + } + return -1, fmt.Errorf("could not find a degree: tried up to %d", max) +} + +func (f GateFunction) VerifyDegree(claimedDegree, nbIn int) error { + if p := f.fitPoly(nbIn, ecc.NextPowerOfTwo(uint64(claimedDegree)+1)); p == nil { + return fmt.Errorf("detected a higher degree than %d", claimedDegree) + } else if len(p) == 0 { + return errZeroFunction + } else if len(p)-1 != claimedDegree { + return fmt.Errorf("detected degree %d, claimed %d", len(p)-1, claimedDegree) + } + return nil +} + +// FindSolvableVar returns the index of a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns -1 if it fails to find one. +// nbIn is the number of inputs to the gate +func (f GateFunction) FindSolvableVar(nbIn int) int { + for i := range nbIn { + if f.isAdditive(i, nbIn) { + return i + } + } + return -1 +} + +// IsVarSolvable returns whether claimedSolvableVar is a variable whose value can be uniquely determined from that of the other variables along with the gate's output. +// It returns false if it fails to verify this claim. +// nbIn is the number of inputs to the gate. +func (f GateFunction) IsVarSolvable(claimedSolvableVar, nbIn int) bool { + return f.isAdditive(claimedSolvableVar, nbIn) +} + +// RegisterGate creates a gate object and stores it in the gates registry. +// name is a human-readable name for the gate. +// f is the polynomial function defining the gate. +// nbIn is the number of inputs to the gate. +func RegisterGate(name GateName, f GateFunction, nbIn int, options ...RegisterGateOption) error { + s := registerGateSettings{degree: -1, solvableVar: -1} + for _, option := range options { + option(&s) + } + + if s.degree == -1 { // find a degree + if s.noDegreeVerification { + panic("invalid settings") + } + const maxAutoDegreeBound = 32 + var err error + if s.degree, err = f.FindDegree(maxAutoDegreeBound, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } else { + if !s.noDegreeVerification { // check that the given degree is correct + if err := f.VerifyDegree(s.degree, nbIn); err != nil { + return fmt.Errorf("for gate %s: %v", name, err) + } + } + } + + if s.solvableVar == -1 { + if !s.noSolvableVarVerification { // find a solvable variable + s.solvableVar = f.FindSolvableVar(nbIn) + } + } else { + // solvable variable given + if !s.noSolvableVarVerification && !f.IsVarSolvable(s.solvableVar, nbIn) { + return fmt.Errorf("cannot verify the solvability of variable %d in gate %s", s.solvableVar, name) + } + } + + gatesLock.Lock() + defer gatesLock.Unlock() + gates[name] = &Gate{Evaluate: f, nbIn: nbIn, degree: s.degree, solvableVar: s.solvableVar} + return nil +} + +func GetGate(name GateName) *Gate { + gatesLock.Lock() + defer gatesLock.Unlock() + return gates[name] +} + +// interpolate fits a polynomial of degree len(X) - 1 = len(Y) - 1 to the points (X[i], Y[i]) +// Note that the runtime is O(len(X)³) +func interpolate(X, Y []small_rational.SmallRational) (polynomial.Polynomial, error) { + if len(X) != len(Y) { + return nil, errors.New("X and Y must have the same length") + } + + // solve the system of equations by Gaussian elimination + augmentedRows := make([][]small_rational.SmallRational, len(X)) // the last column is the Y values + for i := range augmentedRows { + augmentedRows[i] = make([]small_rational.SmallRational, len(X)+1) + augmentedRows[i][0].SetOne() + augmentedRows[i][1].Set(&X[i]) + for j := 2; j < len(augmentedRows[i])-1; j++ { + augmentedRows[i][j].Mul(&augmentedRows[i][j-1], &X[i]) + } + augmentedRows[i][len(augmentedRows[i])-1].Set(&Y[i]) + } + + // make the upper triangle + for i := range len(augmentedRows) - 1 { + // use row i to eliminate the ith element in all rows below + var negInv small_rational.SmallRational + if augmentedRows[i][i].IsZero() { + return nil, errors.New("singular matrix") + } + negInv.Inverse(&augmentedRows[i][i]) + negInv.Neg(&negInv) + for j := i + 1; j < len(augmentedRows); j++ { + var c small_rational.SmallRational + c.Mul(&augmentedRows[j][i], &negInv) + // augmentedRows[j][i].SetZero() omitted + for k := i + 1; k < len(augmentedRows[i]); k++ { + var t small_rational.SmallRational + t.Mul(&augmentedRows[i][k], &c) + augmentedRows[j][k].Add(&augmentedRows[j][k], &t) + } + } + } + + // back substitution + res := make(polynomial.Polynomial, len(X)) + for i := len(augmentedRows) - 1; i >= 0; i-- { + res[i] = augmentedRows[i][len(augmentedRows[i])-1] + for j := i + 1; j < len(augmentedRows[i])-1; j++ { + var t small_rational.SmallRational + t.Mul(&res[j], &augmentedRows[i][j]) + res[i].Sub(&res[i], &t) + } + res[i].Div(&res[i], &augmentedRows[i][i]) + } + + return res, nil +} + +const ( + Identity GateName = "identity" // Identity gate: x -> x + Add2 GateName = "add2" // Add2 gate: (x, y) -> x + y + Sub2 GateName = "sub2" // Sub2 gate: (x, y) -> x - y + Neg GateName = "neg" // Neg gate: x -> -x + Mul2 GateName = "mul2" // Mul2 gate: (x, y) -> x * y +) + +func init() { + // register some basic gates + + if err := RegisterGate(Identity, func(x ...small_rational.SmallRational) small_rational.SmallRational { + return x[0] + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Add2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Add(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Sub2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Sub(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Neg, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Neg(&x[0]) + return res + }, 1, WithUnverifiedDegree(1), WithUnverifiedSolvableVar(0)); err != nil { + panic(err) + } + + if err := RegisterGate(Mul2, func(x ...small_rational.SmallRational) small_rational.SmallRational { + var res small_rational.SmallRational + res.Mul(&x[0], &x[1]) + return res + }, 2, WithUnverifiedDegree(2), WithNoSolvableVar()); err != nil { + panic(err) + } +} diff --git a/internal/gkr/small_rational/sumcheck/sumcheck.go b/internal/gkr/small_rational/sumcheck/sumcheck.go new file mode 100644 index 0000000000..b0d233b1dd --- /dev/null +++ b/internal/gkr/small_rational/sumcheck/sumcheck.go @@ -0,0 +1,170 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "errors" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a small_rational.SmallRational) small_rational.SmallRational // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof []small_rational.SmallRational `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []small_rational.SmallRational, remainingChallengeNames *[]string) (small_rational.SmallRational, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return small_rational.SmallRational{}, err + } + } + var res small_rational.SmallRational + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff small_rational.SmallRational + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]small_rational.SmallRational, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff small_rational.SmallRational + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]small_rational.SmallRational, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := range claims.VarsNum() { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return errors.New("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/gkr/small_rational/sumcheck/sumcheck_test.go b/internal/gkr/small_rational/sumcheck/sumcheck_test.go new file mode 100644 index 0000000000..43cc89393e --- /dev/null +++ b/internal/gkr/small_rational/sumcheck/sumcheck_test.go @@ -0,0 +1,149 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package sumcheck + +import ( + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/small_rational/test_vector_utils" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) []small_rational.SmallRational { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []small_rational.SmallRational{sum} +} + +func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum small_rational.SmallRational +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof []small_rational.SmallRational) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs small_rational.SmallRational) small_rational.SmallRational { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/gkr/small_rational/test_vector_utils/test_vector_utils.go b/internal/gkr/small_rational/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..3102a2133d --- /dev/null +++ b/internal/gkr/small_rational/test_vector_utils/test_vector_utils.go @@ -0,0 +1,185 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "hash" + "reflect" +) + +func ToElement(i int64) *small_rational.SmallRational { + var res small_rational.SmallRational + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res small_rational.SmallRational + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return small_rational.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return small_rational.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []small_rational.SmallRational + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return small_rational.Bytes +} + +func (h *ListHash) BlockSize() int { + return small_rational.Bytes +} + +func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *small_rational.SmallRational) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(small_rational.SmallRational) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/test_vector_utils/test_vector_utils.go b/internal/gkr/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..3102a2133d --- /dev/null +++ b/internal/gkr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,185 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "hash" + "reflect" +) + +func ToElement(i int64) *small_rational.SmallRational { + var res small_rational.SmallRational + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res small_rational.SmallRational + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return small_rational.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return small_rational.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []small_rational.SmallRational + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return small_rational.Bytes +} + +func (h *ListHash) BlockSize() int { + return small_rational.Bytes +} + +func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *small_rational.SmallRational) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(small_rational.SmallRational) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/gkr/test_vectors/gkr/circuits/mimc_five_levels.json b/internal/gkr/test_vectors/gkr/circuits/mimc_five_levels.json new file mode 100644 index 0000000000..3dd74f42b5 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/circuits/mimc_five_levels.json @@ -0,0 +1,36 @@ +[ + [ + { + "gate": "mimc", + "inputs": [[1,0], [5,5]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[2,0], [5,4]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[3,0], [5,3]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[4,0], [5,2]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[5,0], [5,1]] + } + ], + [ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []} + ] +] \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/circuits/single_identity_gate.json b/internal/gkr/test_vectors/gkr/circuits/single_identity_gate.json new file mode 100644 index 0000000000..a44066c7b4 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/circuits/single_identity_gate.json @@ -0,0 +1,10 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/circuits/single_input_two_identity_gates.json b/internal/gkr/test_vectors/gkr/circuits/single_input_two_identity_gates.json new file mode 100644 index 0000000000..6181784fa8 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/circuits/single_input_two_identity_gates.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/circuits/single_input_two_outs.json b/internal/gkr/test_vectors/gkr/circuits/single_input_two_outs.json new file mode 100644 index 0000000000..3a39e5625f --- /dev/null +++ b/internal/gkr/test_vectors/gkr/circuits/single_input_two_outs.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul2", + "inputs": [0, 0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/circuits/single_mimc_gate.json b/internal/gkr/test_vectors/gkr/circuits/single_mimc_gate.json new file mode 100644 index 0000000000..c89e7d52ae --- /dev/null +++ b/internal/gkr/test_vectors/gkr/circuits/single_mimc_gate.json @@ -0,0 +1,7 @@ +[ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + { + "gate": "mimc", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/circuits/single_mul_gate.json b/internal/gkr/test_vectors/gkr/circuits/single_mul_gate.json new file mode 100644 index 0000000000..d009ebe03d --- /dev/null +++ b/internal/gkr/test_vectors/gkr/circuits/single_mul_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul2", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/circuits/two_identity_gates_composed_single_input.json b/internal/gkr/test_vectors/gkr/circuits/two_identity_gates_composed_single_input.json new file mode 100644 index 0000000000..26681c2f89 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/circuits/two_identity_gates_composed_single_input.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [1] + } +] \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/circuits/two_inputs_select-input-3_gate.json b/internal/gkr/test_vectors/gkr/circuits/two_inputs_select-input-3_gate.json new file mode 100644 index 0000000000..cdbdb3b471 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/circuits/two_inputs_select-input-3_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "select-input-3", + "inputs": [0,0,1] + } +] \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go b/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go new file mode 100644 index 0000000000..c1dcb5b6e0 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/gkr-gen-vectors.go @@ -0,0 +1,341 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/bavard" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/small_rational" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" + "hash" + "os" + "path/filepath" + "reflect" +) + +func GenerateVectors() error { + testDirPath, err := filepath.Abs("../../gkr/test_vectors/gkr") + if err != nil { + return err + } + + fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) + + dirEntries, err := os.ReadDir(testDirPath) + if err != nil { + return err + } + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + if !bavard.ShouldGenerate(path) { + continue + } + fmt.Println("\tprocessing", dirEntry.Name()) + if err = run(path); err != nil { + return err + } + } + } + } + + return nil +} + +func run(absPath string) error { + testCase, err := newTestCase(absPath) + if err != nil { + return err + } + + transcriptSetting := fiatshamir.WithHash(testCase.Hash) + + var proof gkr.Proof + proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, transcriptSetting) + if err != nil { + return err + } + + if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { + return err + } + var outBytes []byte + if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { + if err = os.WriteFile(absPath, outBytes, 0); err != nil { + return err + } + } else { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, transcriptSetting) + if err != nil { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(2, 0))) + if err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { + res := make(PrintableProof, len(proof)) + + for i := range proof { + + partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) + for k, partialK := range proof[i].PartialSumPolys { + partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) + } + + res[i] = PrintableSumcheckProof{ + FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), + PartialSumPolys: partialSumPolys, + } + } + return res, nil +} + +type WireInfo struct { + Gate gkr.GateName `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]gkr.Circuit) + +func getCircuit(path string) (gkr.Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit gkr.Circuit) { + circuit = make(gkr.Circuit, len(c)) + for i := range c { + circuit[i].Gate = gkr.GetGate(c[i].Gate) + circuit[i].Inputs = make([]*gkr.Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +func mimcRound(input ...small_rational.SmallRational) (res small_rational.SmallRational) { + var sum small_rational.SmallRational + + sum. + Add(&input[0], &input[1]) //.Add(&sum, &m.ark) TODO: add ark + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +const ( + MiMC gkr.GateName = "mimc" + SelectInput3 gkr.GateName = "select-input-3" +) + +func init() { + if err := gkr.RegisterGate(MiMC, mimcRound, 2, gkr.WithUnverifiedDegree(7)); err != nil { + panic(err) + } + + if err := gkr.RegisterGate(SelectInput3, func(input ...small_rational.SmallRational) small_rational.SmallRational { + return input[2] + }, 3, gkr.WithUnverifiedDegree(1)); err != nil { + panic(err) + } +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (gkr.Proof, error) { + proof := make(gkr.Proof, len(printable)) + for i := range printable { + finalEvalProof := []small_rational.SmallRational(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit gkr.Circuit + Hash hash.Hash + Proof gkr.Proof + FullAssignment gkr.WireAssignment + InOutAssignment gkr.WireAssignment + Info TestCaseInfo // we are generating the test vectors, so we need to keep the circuit instance info to ADD the proof to it and resave it +} + +type TestCaseInfo struct { + Hash test_vector_utils.HashDescription `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit gkr.Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash hash.Hash + if _hash, err = test_vector_utils.HashFromDescription(info.Hash); err != nil { + return nil, err + } + var proof gkr.Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(gkr.WireAssignment) + inOutAssignment := make(gkr.WireAssignment) + + sorted := gkr.TopologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []small_rational.SmallRational + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + info.Output = make([][]interface{}, 0, outI) + + for _, w := range sorted { + if w.IsOutput() { + + info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + Info: info, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} diff --git a/internal/gkr/test_vectors/gkr/mimc_five_levels_two_instances._json b/internal/gkr/test_vectors/gkr/mimc_five_levels_two_instances._json new file mode 100644 index 0000000000..e980cfb0cb --- /dev/null +++ b/internal/gkr/test_vectors/gkr/mimc_five_levels_two_instances._json @@ -0,0 +1,7 @@ +{ + "hash": {"type": "const", "val": -1}, + "circuit": "circuits/mimc_five_levels.json", + "input": [[1, 3], [1, 3], [1, 3], [1, 3], [1, 3], [1, 3]], + "output": [[4, 3]], + "proof": [[{"partialSumPolys":[[3,4]],"finalEvalProof":[3]}],[{"partialSumPolys":null,"finalEvalProof":null}]] +} \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/single_identity_gate_two_instances.json b/internal/gkr/test_vectors/gkr/single_identity_gate_two_instances.json new file mode 100644 index 0000000000..ba28e35961 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/single_identity_gate_two_instances.json @@ -0,0 +1,36 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "circuits/single_identity_gate.json", + "input": [ + [ + 4, + 3 + ] + ], + "output": [ + [ + 4, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 5 + ], + "partialSumPolys": [ + [ + -3, + -8 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/single_input_two_identity_gates_two_instances.json b/internal/gkr/test_vectors/gkr/single_input_two_identity_gates_two_instances.json new file mode 100644 index 0000000000..1451b332c2 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/single_input_two_identity_gates_two_instances.json @@ -0,0 +1,56 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "circuits/single_input_two_identity_gates.json", + "input": [ + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ], + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + 0, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + }, + { + "finalEvalProof": [ + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/single_input_two_outs_two_instances.json b/internal/gkr/test_vectors/gkr/single_input_two_outs_two_instances.json new file mode 100644 index 0000000000..897aea7ee5 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/single_input_two_outs_two_instances.json @@ -0,0 +1,57 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "circuits/single_input_two_outs.json", + "input": [ + [ + 1, + 2 + ] + ], + "output": [ + [ + 1, + 4 + ], + [ + 1, + 2 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + 0, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 0 + ], + "partialSumPolys": [ + [ + -4, + -36, + -112 + ] + ] + }, + { + "finalEvalProof": [ + 0 + ], + "partialSumPolys": [ + [ + -2, + -12 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/single_mimc_gate_four_instances.json b/internal/gkr/test_vectors/gkr/single_mimc_gate_four_instances.json new file mode 100644 index 0000000000..a724ba5a7b --- /dev/null +++ b/internal/gkr/test_vectors/gkr/single_mimc_gate_four_instances.json @@ -0,0 +1,67 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "circuits/single_mimc_gate.json", + "input": [ + [ + 1, + 1, + 2, + 1 + ], + [ + 1, + 2, + 2, + 1 + ] + ], + "output": [ + [ + 128, + 2187, + 16384, + 128 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + -1, + -3 + ], + "partialSumPolys": [ + [ + -32640, + -2239484, + -29360128, + -200000010, + -931628672, + -3373267120, + -10200858624, + -26939400158 + ], + [ + -81920, + -41943040, + -1254113280, + -13421772800, + -83200000000, + -366917713920, + -1281828208640, + -3779571220480 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/single_mimc_gate_two_instances.json b/internal/gkr/test_vectors/gkr/single_mimc_gate_two_instances.json new file mode 100644 index 0000000000..901db48692 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/single_mimc_gate_two_instances.json @@ -0,0 +1,51 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "circuits/single_mimc_gate.json", + "input": [ + [ + 1, + 1 + ], + [ + 1, + 2 + ] + ], + "output": [ + [ + 128, + 2187 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 1, + 0 + ], + "partialSumPolys": [ + [ + -2187, + -65536, + -546875, + -2799360, + -10706059, + -33554432, + -90876411, + -220000000 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/single_mul_gate_two_instances.json b/internal/gkr/test_vectors/gkr/single_mul_gate_two_instances.json new file mode 100644 index 0000000000..b85a6df42c --- /dev/null +++ b/internal/gkr/test_vectors/gkr/single_mul_gate_two_instances.json @@ -0,0 +1,46 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "circuits/single_mul_gate.json", + "input": [ + [ + 4, + 3 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 8, + 9 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 5, + 1 + ], + "partialSumPolys": [ + [ + -9, + -32, + -35 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/two_identity_gates_composed_single_input_two_instances.json b/internal/gkr/test_vectors/gkr/two_identity_gates_composed_single_input_two_instances.json new file mode 100644 index 0000000000..69a2038a75 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/two_identity_gates_composed_single_input_two_instances.json @@ -0,0 +1,47 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "circuits/two_identity_gates_composed_single_input.json", + "input": [ + [ + 2, + 1 + ] + ], + "output": [ + [ + 2, + 1 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 3 + ], + "partialSumPolys": [ + [ + -1, + 0 + ] + ] + }, + { + "finalEvalProof": [ + 3 + ], + "partialSumPolys": [ + [ + -1, + 0 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/gkr/test_vectors/gkr/two_inputs_select-input-3_gate_two_instances.json b/internal/gkr/test_vectors/gkr/two_inputs_select-input-3_gate_two_instances.json new file mode 100644 index 0000000000..2dca0746a2 --- /dev/null +++ b/internal/gkr/test_vectors/gkr/two_inputs_select-input-3_gate_two_instances.json @@ -0,0 +1,45 @@ +{ + "hash": { + "type": "const", + "val": -1 + }, + "circuit": "circuits/two_inputs_select-input-3_gate.json", + "input": [ + [ + 0, + 1 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + -1, + 1 + ], + "partialSumPolys": [ + [ + -3, + -16 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/gkr/test_vectors/main.go b/internal/gkr/test_vectors/main.go new file mode 100644 index 0000000000..a551f0713c --- /dev/null +++ b/internal/gkr/test_vectors/main.go @@ -0,0 +1,17 @@ +package main + +import ( + "github.com/consensys/gnark/internal/gkr/test_vectors/gkr" + "github.com/consensys/gnark/internal/gkr/test_vectors/sumcheck" +) + +func main() { + assertNoError(sumcheck.GenerateVectors()) + assertNoError(gkr.GenerateVectors()) +} + +func assertNoError(err error) { + if err != nil { + panic(err) + } +} diff --git a/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go new file mode 100644 index 0000000000..14542ef228 --- /dev/null +++ b/internal/gkr/test_vectors/sumcheck/sumcheck-gen-vectors.go @@ -0,0 +1,204 @@ +package sumcheck + +import ( + "encoding/json" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark/internal/gkr/small_rational/sumcheck" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "github.com/consensys/gnark/internal/small_rational/test_vector_utils" + "hash" + "math/bits" + "os" + "path/filepath" + "runtime/pprof" +) + +func runMultilin(testCaseInfo *TestCaseInfo) error { + + var poly polynomial.MultiLin + if v, err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); err == nil { + poly = v + } else { + return err + } + + var ( + hsh hash.Hash + err error + ) + + if hsh, err = test_vector_utils.HashFromDescription(testCaseInfo.Hash); err != nil { + return err + } + + proof, err := sumcheck.Prove( + &singleMultilinClaim{poly}, fiatshamir.WithHash(hsh)) + if err != nil { + return err + } + testCaseInfo.Proof = toPrintableProof(proof) + + // Verification + if v, _err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); _err == nil { + poly = v + } else { + return _err + } + var claimedSum small_rational.SmallRational + if _, err = claimedSum.SetInterface(testCaseInfo.ClaimedSum); err != nil { + return err + } + + if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err != nil { + return fmt.Errorf("proof rejected: %v", err) + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(hsh)); err == nil { + return fmt.Errorf("bad proof accepted") + } + + pprof.StopCPUProfile() + //return f.Close() + + return nil +} + +func run(testCaseInfo *TestCaseInfo) error { + switch testCaseInfo.Type { + case "multilin": + return runMultilin(testCaseInfo) + default: + return fmt.Errorf("type \"%s\" unrecognized", testCaseInfo.Type) + } +} + +func GenerateVectors() error { + // read the test vectors file, generate the proof, make sure it verifies, + // and add the proof to the same file + const relPath = "../../gkr/test_vectors/sumcheck/vectors.json" + + var filename string + var err error + if filename, err = filepath.Abs(relPath); err != nil { + return err + } + + var bytes []byte + + if bytes, err = os.ReadFile(filename); err != nil { + return err + } + + var testCasesInfo TestCasesInfo + if err = json.Unmarshal(bytes, &testCasesInfo); err != nil { + return err + } + + failed := false + for name, testCase := range testCasesInfo { + if err = run(testCase); err != nil { + fmt.Println(name, ":", err) + failed = true + } + } + + if failed { + return fmt.Errorf("test case failed") + } + + if bytes, err = json.MarshalIndent(testCasesInfo, "", "\t"); err != nil { + return err + } + + return os.WriteFile(filename, bytes, 0) +} + +type TestCasesInfo map[string]*TestCaseInfo + +type TestCaseInfo struct { + Type string `json:"type"` + Hash test_vector_utils.HashDescription `json:"hash"` + Values []interface{} `json:"values"` + Description string `json:"description"` + Proof PrintableProof `json:"proof"` + ClaimedSum interface{} `json:"claimedSum"` +} + +type PrintableProof struct { + PartialSumPolys [][]interface{} `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` +} + +func toPrintableProof(proof sumcheck.Proof) (printable PrintableProof) { + if proof.FinalEvalProof != nil { + panic("null expected") + } + printable.FinalEvalProof = struct{}{} + printable.PartialSumPolys = test_vector_utils.ElementSliceSliceToInterfaceSliceSlice(proof.PartialSumPolys) + return +} + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval([]small_rational.SmallRational) []small_rational.SmallRational { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []small_rational.SmallRational{sum} +} + +func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum small_rational.SmallRational +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, _ small_rational.SmallRational, purportedValue small_rational.SmallRational, _ []small_rational.SmallRational) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(small_rational.SmallRational) small_rational.SmallRational { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} diff --git a/internal/gkr/test_vectors/sumcheck/vectors.json b/internal/gkr/test_vectors/sumcheck/vectors.json new file mode 100644 index 0000000000..64b8e3fb2d --- /dev/null +++ b/internal/gkr/test_vectors/sumcheck/vectors.json @@ -0,0 +1,56 @@ +{ + "linear_univariate_single_claim": { + "type": "multilin", + "hash": { + "type": "const", + "val": -1 + }, + "values": [ + 1, + 3 + ], + "description": "X ↦ 2X + 1", + "proof": { + "partialSumPolys": [ + [ + 3 + ] + ], + "finalEvalProof": {} + }, + "claimedSum": 4 + }, + "trilinear_single_claim": { + "type": "multilin", + "hash": { + "type": "const", + "val": -1 + }, + "values": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8 + ], + "description": "X₁, X₂, X₃ ↦ 1 + 4X₁ + 2X₂ + X₃", + "proof": { + "partialSumPolys": [ + [ + 26 + ], + [ + -1 + ], + [ + -4 + ] + ], + "finalEvalProof": {} + }, + "claimedSum": 36 + } +} \ No newline at end of file diff --git a/internal/small_rational/polynomial/doc.go b/internal/small_rational/polynomial/doc.go new file mode 100644 index 0000000000..95ba2f135f --- /dev/null +++ b/internal/small_rational/polynomial/doc.go @@ -0,0 +1,5 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Package polynomial provides polynomial methods and commitment schemes. +package polynomial diff --git a/internal/small_rational/polynomial/multilin.go b/internal/small_rational/polynomial/multilin.go new file mode 100644 index 0000000000..6bb2d8916e --- /dev/null +++ b/internal/small_rational/polynomial/multilin.go @@ -0,0 +1,176 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +package polynomial + +import ( + "github.com/consensys/gnark-crypto/utils" + "github.com/consensys/gnark/internal/small_rational" + "math/bits" +) + +// MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial +// The variables are X₁ through Xₙ where n = log(len(.)) +// .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ) +// It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial +type MultiLin []small_rational.SmallRational + +// Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r +func (m *MultiLin) Fold(r small_rational.SmallRational) { + mid := len(*m) / 2 + + bottom, top := (*m)[:mid], (*m)[mid:] + + var t small_rational.SmallRational // no need to update the top part + + // updating bookkeeping table + // knowing that the polynomial f ∈ (k[X₂, ..., Xₙ])[X₁] is linear, we would get f(r) = f(0) + r(f(1) - f(0)) + // the following loop computes the evaluations of f(r) accordingly: + // f(r, b₂, ..., bₙ) = f(0, b₂, ..., bₙ) + r(f(1, b₂, ..., bₙ) - f(0, b₂, ..., bₙ)) + for i := 0; i < mid; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + t.Sub(&top[i], &bottom[i]) + t.Mul(&t, &r) + bottom[i].Add(&bottom[i], &t) + } + + *m = (*m)[:mid] +} + +func (m *MultiLin) FoldParallel(r small_rational.SmallRational) utils.Task { + mid := len(*m) / 2 + bottom, top := (*m)[:mid], (*m)[mid:] + + *m = bottom + + return func(start, end int) { + var t small_rational.SmallRational // no need to update the top part + for i := start; i < end; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + t.Sub(&top[i], &bottom[i]) + t.Mul(&t, &r) + bottom[i].Add(&bottom[i], &t) + } + } +} + +func (m MultiLin) Sum() small_rational.SmallRational { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + +// Evaluate extrapolate the value of the multilinear polynomial corresponding to m +// on the given coordinates +func (m MultiLin) Evaluate(coordinates []small_rational.SmallRational, p *Pool) small_rational.SmallRational { + // Folding is a mutating operation + bkCopy := _clone(m, p) + + // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) + for _, r := range coordinates { + bkCopy.Fold(r) + } + + result := bkCopy[0] + + _dump(bkCopy, p) + return result +} + +// Clone creates a deep copy of a bookkeeping table. +// Both multilinear interpolation and sumcheck require folding an underlying +// array, but folding changes the array. To do both one requires a deep copy +// of the bookkeeping table. +func (m MultiLin) Clone() MultiLin { + res := make(MultiLin, len(m)) + copy(res, m) + return res +} + +// Add two bookKeepingTables +func (m *MultiLin) Add(left, right MultiLin) { + size := len(left) + // Check that left and right have the same size + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") + } + + // Add elementwise + for i := 0; i < size; i++ { + (*m)[i].Add(&left[i], &right[i]) + } +} + +// EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) +// where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates +// +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// +// In other words the polynomial evaluated here is the multilinear extrapolation of +// one that evaluates to q' == h' for vectors q', h' of binary values +func EvalEq(q, h []small_rational.SmallRational) small_rational.SmallRational { + var res, nxt, one, sum small_rational.SmallRational + one.SetOne() + for i := 0; i < len(q); i++ { + nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ + nxt.Double(&nxt) // nxt <- 2 * qᵢ * hᵢ + nxt.Add(&nxt, &one) // nxt <- 1 + 2 * qᵢ * hᵢ + sum.Add(&q[i], &h[i]) // sum <- qᵢ + hᵢ TODO: Why not subtract one by one from nxt? More parallel? + + if i == 0 { + res.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + } else { + nxt.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + res.Mul(&res, &nxt) // res <- res * nxt + } + } + return res +} + +// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] +func (m *MultiLin) Eq(q []small_rational.SmallRational) { + n := len(q) + + if len(*m) != 1<= 0; i-- { + res.Mul(&res, v) + res.Add(&res, &(*p)[i]) + } + + return res +} + +// Clone returns a copy of the polynomial +func (p *Polynomial) Clone() Polynomial { + _p := make(Polynomial, len(*p)) + copy(_p, *p) + return _p +} + +// Set to another polynomial +func (p *Polynomial) Set(p1 Polynomial) { + if len(*p) != len(p1) { + *p = p1.Clone() + return + } + + for i := 0; i < len(p1); i++ { + (*p)[i].Set(&p1[i]) + } +} + +// AddConstantInPlace adds a constant to the polynomial, modifying p +func (p *Polynomial) AddConstantInPlace(c *small_rational.SmallRational) { + for i := 0; i < len(*p); i++ { + (*p)[i].Add(&(*p)[i], c) + } +} + +// SubConstantInPlace subs a constant to the polynomial, modifying p +func (p *Polynomial) SubConstantInPlace(c *small_rational.SmallRational) { + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&(*p)[i], c) + } +} + +// ScaleInPlace multiplies p by v, modifying p +func (p *Polynomial) ScaleInPlace(c *small_rational.SmallRational) { + for i := 0; i < len(*p); i++ { + (*p)[i].Mul(&(*p)[i], c) + } +} + +// Scale multiplies p0 by v, storing the result in p +func (p *Polynomial) Scale(c *small_rational.SmallRational, p0 Polynomial) { + if len(*p) != len(p0) { + *p = make(Polynomial, len(p0)) + } + for i := 0; i < len(p0); i++ { + (*p)[i].Mul(c, &p0[i]) + } +} + +// Add adds p1 to p2 +// This function allocates a new slice unless p == p1 or p == p2 +func (p *Polynomial) Add(p1, p2 Polynomial) *Polynomial { + + bigger := p1 + smaller := p2 + if len(bigger) < len(smaller) { + bigger, smaller = smaller, bigger + } + + if len(*p) == len(bigger) && (&(*p)[0] == &bigger[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &smaller[i]) + } + return p + } + + if len(*p) == len(smaller) && (&(*p)[0] == &smaller[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &bigger[i]) + } + *p = append(*p, bigger[len(smaller):]...) + return p + } + + res := make(Polynomial, len(bigger)) + copy(res, bigger) + for i := 0; i < len(smaller); i++ { + res[i].Add(&res[i], &smaller[i]) + } + *p = res + return p +} + +// Sub subtracts p2 from p1 +// TODO make interface more consistent with Add +func (p *Polynomial) Sub(p1, p2 Polynomial) *Polynomial { + if len(p1) != len(p2) || len(p2) != len(*p) { + return nil + } + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&p1[i], &p2[i]) + } + return p +} + +// Equal checks equality between two polynomials +func (p *Polynomial) Equal(p1 Polynomial) bool { + if (*p == nil) != (p1 == nil) { + return false + } + + if len(*p) != len(p1) { + return false + } + + for i := range p1 { + if !(*p)[i].Equal(&p1[i]) { + return false + } + } + + return true +} + +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() + } +} + +func (p Polynomial) Text(base int) string { + + var builder strings.Builder + + first := true + for d := len(p) - 1; d >= 0; d-- { + if p[d].IsZero() { + continue + } + + pD := p[d] + pDText := pD.Text(base) + + initialLen := builder.Len() + + if pDText[0] == '-' { + pDText = pDText[1:] + if first { + builder.WriteString("-") + } else { + builder.WriteString(" - ") + } + } else if !first { + builder.WriteString(" + ") + } + + first = false + + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) + } + + if builder.Len()-initialLen > 10 { + builder.WriteString("×") + } + + if d != 0 { + builder.WriteString("X") + } + if d > 1 { + builder.WriteString( + utils.ToSuperscript(strconv.Itoa(d)), + ) + } + + } + + if first { + return "0" + } + + return builder.String() +} + +// InterpolateOnRange maps vector v to polynomial f +// such that f(i) = v[i] for 0 ≤ i < len(v). +// len(f) = len(v) and deg(f) ≤ len(v) - 1 +func InterpolateOnRange(v []small_rational.SmallRational) Polynomial { + nEvals := uint8(len(v)) + if int(nEvals) != len(v) { + panic("interpolation method too inefficient for nEvals > 255") + } + lagrange := getLagrangeBasis(nEvals) + + var res Polynomial + res.Scale(&v[0], lagrange[0]) + + temp := make(Polynomial, nEvals) + + for i := uint8(1); i < nEvals; i++ { + temp.Scale(&v[i], lagrange[i]) + res.Add(res, temp) + } + + return res +} + +// lagrange bases used by InterpolateOnRange +var lagrangeBasis sync.Map + +func getLagrangeBasis(domainSize uint8) []Polynomial { + if res, ok := lagrangeBasis.Load(domainSize); ok { + return res.([]Polynomial) + } + + // not found. compute + var res []Polynomial + if domainSize >= 2 { + res = computeLagrangeBasis(domainSize) + } else if domainSize == 1 { + res = []Polynomial{make(Polynomial, 1)} + res[0][0].SetOne() + } + lagrangeBasis.Store(domainSize, res) + + return res +} + +// computeLagrangeBasis precomputes in explicit coefficient form for each 0 ≤ l < domainSize the polynomial +// pₗ := X (X-1) ... (X-l-1) (X-l+1) ... (X - domainSize + 1) / ( l (l-1) ... 2 (-1) ... (l - domainSize +1) ) +// Note that pₗ(l) = 1 and pₗ(n) = 0 if 0 ≤ l < domainSize, n ≠ l +func computeLagrangeBasis(domainSize uint8) []Polynomial { + + constTerms := make([]small_rational.SmallRational, domainSize) + for i := uint8(0); i < domainSize; i++ { + constTerms[i].SetInt64(-int64(i)) + } + + res := make([]Polynomial, domainSize) + multScratch := make(Polynomial, domainSize-1) + + // compute pₗ + for l := uint8(0); l < domainSize; l++ { + + // TODO @Tabaie Optimize this with some trees? O(log(domainSize)) polynomial mults instead of O(domainSize)? Then again it would be fewer big poly mults vs many small poly mults + d := uint8(0) //d is the current degree of res + for i := uint8(0); i < domainSize; i++ { + if i == l { + continue + } + if d == 0 { + res[l] = make(Polynomial, domainSize) + res[l][domainSize-2] = constTerms[i] + res[l][domainSize-1].SetOne() + } else { + current := res[l][domainSize-d-2:] + timesConst := multScratch[domainSize-d-2:] + + timesConst.Scale(&constTerms[i], current[1:]) //TODO: Directly double and add since constTerms are tiny? (even less than 4 bits) + nonLeading := current[0 : d+1] + + nonLeading.Add(nonLeading, timesConst) + + } + d++ + } + + } + + // We have pₗ(i≠l)=0. Now scale so that pₗ(l)=1 + // Replace the constTerms with norms + for l := uint8(0); l < domainSize; l++ { + constTerms[l].Neg(&constTerms[l]) + constTerms[l] = res[l].Eval(&constTerms[l]) + } + constTerms = small_rational.BatchInvert(constTerms) + for l := uint8(0); l < domainSize; l++ { + res[l].ScaleInPlace(&constTerms[l]) + } + + return res +} diff --git a/internal/small_rational/polynomial/pool.go b/internal/small_rational/polynomial/pool.go new file mode 100644 index 0000000000..bc855ef5d4 --- /dev/null +++ b/internal/small_rational/polynomial/pool.go @@ -0,0 +1,29 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +package polynomial + +import ( + "github.com/consensys/gnark/internal/small_rational" +) + +// Do as little as possible to instantiate the interface +type Pool struct { +} + +func NewPool(...int) (pool Pool) { + return Pool{} +} + +func (p *Pool) Make(n int) []small_rational.SmallRational { + return make([]small_rational.SmallRational, n) +} + +func (p *Pool) Dump(...[]small_rational.SmallRational) { +} + +func (p *Pool) Clone(slice []small_rational.SmallRational) []small_rational.SmallRational { + res := p.Make(len(slice)) + copy(res, slice) + return res +} diff --git a/internal/small_rational/small-rational.go b/internal/small_rational/small-rational.go new file mode 100644 index 0000000000..6dbd87f1df --- /dev/null +++ b/internal/small_rational/small-rational.go @@ -0,0 +1,459 @@ +package small_rational + +import ( + "crypto/rand" + "fmt" + "math/big" + "strconv" + "strings" +) + +const Bytes = 64 + +// SmallRational implements the rational field, used to generate field agnostic test vectors. +// It is not optimized for performance, so it is best used sparingly. +type SmallRational struct { + text string //For debugging purposes + numerator big.Int + denominator big.Int // By convention, denominator == 0 also indicates zero +} + +var smallPrimes = []*big.Int{ + big.NewInt(2), big.NewInt(3), big.NewInt(5), + big.NewInt(7), big.NewInt(11), big.NewInt(13), +} + +func bigDivides(p, a *big.Int) bool { + var remainder big.Int + remainder.Mod(a, p) + return remainder.BitLen() == 0 +} + +func (z *SmallRational) UpdateText() { + z.text = z.Text(10) +} + +func (z *SmallRational) simplify() { + + if z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 { + return + } + + var num, den big.Int + + num.Set(&z.numerator) + den.Set(&z.denominator) + + for _, p := range smallPrimes { + for bigDivides(p, &num) && bigDivides(p, &den) { + num.Div(&num, p) + den.Div(&den, p) + } + } + + if bigDivides(&den, &num) { + num.Div(&num, &den) + den.SetInt64(1) + } + + z.numerator = num + z.denominator = den + +} +func (z *SmallRational) Square(x *SmallRational) *SmallRational { + var num, den big.Int + num.Mul(&x.numerator, &x.numerator) + den.Mul(&x.denominator, &x.denominator) + + z.numerator = num + z.denominator = den + + z.UpdateText() + + return z +} + +func (z *SmallRational) String() string { + z.text = z.Text(10) + return z.text +} + +func (z *SmallRational) Add(x, y *SmallRational) *SmallRational { + if x.denominator.BitLen() == 0 { + *z = *y + } else if y.denominator.BitLen() == 0 { + *z = *x + } else { + //TODO: Exploit cases where one denom divides the other + var numDen, denNum big.Int + numDen.Mul(&x.numerator, &y.denominator) + denNum.Mul(&x.denominator, &y.numerator) + + numDen.Add(&denNum, &numDen) + z.numerator = numDen //to avoid shallow copy problems + + denNum.Mul(&x.denominator, &y.denominator) + z.denominator = denNum + z.simplify() + } + + z.UpdateText() + + return z +} + +func (z *SmallRational) IsZero() bool { + return z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 +} + +func (z *SmallRational) Inverse(x *SmallRational) *SmallRational { + if x.IsZero() { + *z = *x + } else { + *z = SmallRational{numerator: x.denominator, denominator: x.numerator} + z.UpdateText() + } + + return z +} + +func (z *SmallRational) Neg(x *SmallRational) *SmallRational { + z.numerator.Neg(&x.numerator) + z.denominator = x.denominator + + if x.text == "" { + x.UpdateText() + } + + if x.text[0] == '-' { + z.text = x.text[1:] + } else { + z.text = "-" + x.text + } + + return z +} + +func (z *SmallRational) Double(x *SmallRational) *SmallRational { + + var y big.Int + + if x.denominator.Bit(0) == 0 { + z.numerator = x.numerator + y.Rsh(&x.denominator, 1) + z.denominator = y + } else { + y.Lsh(&x.numerator, 1) + z.numerator = y + z.denominator = x.denominator + } + + z.UpdateText() + + return z +} + +func (z *SmallRational) Sign() int { + return z.numerator.Sign() * z.denominator.Sign() +} + +func (z *SmallRational) MarshalJSON() ([]byte, error) { + return []byte(z.String()), nil +} + +func (z *SmallRational) UnmarshalJson(data []byte) error { + _, err := z.SetInterface(string(data)) + return err +} + +func (z *SmallRational) Equal(x *SmallRational) bool { + return z.Cmp(x) == 0 +} + +func (z *SmallRational) Sub(x, y *SmallRational) *SmallRational { + var yNeg SmallRational + yNeg.Neg(y) + z.Add(x, &yNeg) + + z.UpdateText() + return z +} + +func (z *SmallRational) Cmp(x *SmallRational) int { + zSign, xSign := z.Sign(), x.Sign() + + if zSign > xSign { + return 1 + } + if zSign < xSign { + return -1 + } + + var Z, X big.Int + Z.Mul(&z.numerator, &x.denominator) + X.Mul(&x.numerator, &z.denominator) + + Z.Abs(&Z) + X.Abs(&X) + + return Z.Cmp(&X) * zSign + +} + +func BatchInvert(a []SmallRational) []SmallRational { + res := make([]SmallRational, len(a)) + for i := range a { + res[i].Inverse(&a[i]) + } + return res +} + +func (z *SmallRational) Mul(x, y *SmallRational) *SmallRational { + var num, den big.Int + + num.Mul(&x.numerator, &y.numerator) + den.Mul(&x.denominator, &y.denominator) + + z.numerator = num + z.denominator = den + + z.simplify() + z.UpdateText() + return z +} + +func (z *SmallRational) Div(x, y *SmallRational) *SmallRational { + var num, den big.Int + + num.Mul(&x.numerator, &y.denominator) + den.Mul(&x.denominator, &y.numerator) + + z.numerator = num + z.denominator = den + + z.simplify() + z.UpdateText() + return z +} + +func (z *SmallRational) Halve() *SmallRational { + if z.numerator.Bit(0) == 0 { + z.numerator.Rsh(&z.numerator, 1) + } else { + z.denominator.Lsh(&z.denominator, 1) + } + + z.simplify() + z.UpdateText() + return z +} + +func (z *SmallRational) SetOne() *SmallRational { + return z.SetInt64(1) +} + +func (z *SmallRational) SetZero() *SmallRational { + return z.SetInt64(0) +} + +func (z *SmallRational) SetInt64(i int64) *SmallRational { + z.numerator = *big.NewInt(i) + z.denominator = *big.NewInt(1) + z.text = strconv.FormatInt(i, 10) + return z +} + +func (z *SmallRational) SetRandom() (*SmallRational, error) { + + bytes := make([]byte, 1) + n, err := rand.Read(bytes) + if err != nil { + return nil, err + } + if n != len(bytes) { + return nil, fmt.Errorf("%d bytes read instead of %d", n, len(bytes)) + } + + z.numerator = *big.NewInt(int64(bytes[0]%16) - 8) + z.denominator = *big.NewInt(int64((bytes[0]) / 16)) + + z.simplify() + z.UpdateText() + + return z, nil +} + +func (z *SmallRational) MustSetRandom() *SmallRational { + if _, err := z.SetRandom(); err != nil { + panic(err) + } + return z +} + +func (z *SmallRational) SetUint64(i uint64) { + var num big.Int + num.SetUint64(i) + z.numerator = num + z.denominator = *big.NewInt(1) + z.text = strconv.FormatUint(i, 10) +} + +func (z *SmallRational) IsOne() bool { + return z.numerator.Cmp(&z.denominator) == 0 && z.denominator.BitLen() != 0 +} + +func (z *SmallRational) Text(base int) string { + + if z.denominator.BitLen() == 0 { + return "0" + } + + if z.denominator.Sign() < 0 { + var num, den big.Int + num.Neg(&z.numerator) + den.Neg(&z.denominator) + z.numerator = num + z.denominator = den + } + + if bigDivides(&z.denominator, &z.numerator) { + var num big.Int + num.Div(&z.numerator, &z.denominator) + z.numerator = num + z.denominator = *big.NewInt(1) + } + + numerator := z.numerator.Text(base) + + if z.denominator.IsInt64() && z.denominator.Int64() == 1 { + return numerator + } + + return numerator + "/" + z.denominator.Text(base) +} + +func (z *SmallRational) Set(x *SmallRational) *SmallRational { + *z = *x // shallow copy is safe because ops are never in place + return z +} + +func (z *SmallRational) SetInterface(x interface{}) (*SmallRational, error) { + + switch v := x.(type) { + case *SmallRational: + *z = *v + case SmallRational: + *z = v + case int64: + z.SetInt64(v) + case int: + z.SetInt64(int64(v)) + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + case string: + z.text = v + sep := strings.Split(v, "/") + switch len(sep) { + case 1: + if asInt, err := strconv.Atoi(sep[0]); err == nil { + z.SetInt64(int64(asInt)) + } else { + return nil, err + } + case 2: + var err error + var num, denom int + num, err = strconv.Atoi(sep[0]) + if err != nil { + return nil, err + } + denom, err = strconv.Atoi(sep[1]) + if err != nil { + return nil, err + } + z.numerator = *big.NewInt(int64(num)) + z.denominator = *big.NewInt(int64(denom)) + default: + return nil, fmt.Errorf("cannot parse \"%s\"", v) + } + default: + return nil, fmt.Errorf("cannot parse %T", x) + } + + return z, nil +} + +func bigIntToBytesSigned(dst []byte, src big.Int) { + src.FillBytes(dst[1:]) + dst[0] = 0 + if src.Sign() < 0 { + dst[0] = 255 + } +} + +func (z *SmallRational) Bytes() [Bytes]byte { + var res [Bytes]byte + bigIntToBytesSigned(res[:Bytes/2], z.numerator) + bigIntToBytesSigned(res[Bytes/2:], z.denominator) + return res +} + +func (z *SmallRational) Marshal() []byte { + res := z.Bytes() + return res[:] +} + +func bytesToBigIntSigned(src []byte) big.Int { + var res big.Int + res.SetBytes(src[1:]) + if src[0] != 0 { + res.Neg(&res) + } + return res +} + +// BigInt returns sets dst to the value of z if it is an integer. +// if z is not an integer, nil is returned. +// if the given dst is nil, the address of the numerator is returned. +// if the given dst is non-nil, it is returned. +func (z *SmallRational) BigInt(dst *big.Int) *big.Int { + if z.denominator.Cmp(big.NewInt(1)) != 0 { + return nil + } + if dst == nil { + return &z.numerator + } + dst.Set(&z.numerator) + return dst +} + +func (z *SmallRational) SetBytes(b []byte) { + if len(b) > Bytes/2 { + z.numerator = bytesToBigIntSigned(b[:Bytes/2]) + z.denominator = bytesToBigIntSigned(b[Bytes/2:]) + } else { + z.numerator.SetBytes(b) + z.denominator.SetInt64(1) + } + z.simplify() + z.UpdateText() +} + +func One() SmallRational { + res := SmallRational{ + text: "1", + } + res.numerator.SetInt64(1) + res.denominator.SetInt64(1) + return res +} + +func Modulus() *big.Int { + res := big.NewInt(1) + res.Lsh(res, 64) + return res +} diff --git a/internal/small_rational/small_rational_test.go b/internal/small_rational/small_rational_test.go new file mode 100644 index 0000000000..6d7733ea76 --- /dev/null +++ b/internal/small_rational/small_rational_test.go @@ -0,0 +1,115 @@ +package small_rational + +import ( + "github.com/stretchr/testify/assert" + "math/big" + "testing" +) + +func TestBigDivides(t *testing.T) { + assert.True(t, bigDivides(big.NewInt(-1), big.NewInt(4))) + assert.False(t, bigDivides(big.NewInt(-3), big.NewInt(4))) +} + +func TestCmp(t *testing.T) { + + cases := make([]SmallRational, 36) + + for i := int64(0); i < 9; i++ { + if i%2 == 0 { + cases[4*i].numerator.SetInt64((i - 4) / 2) + cases[4*i].denominator.SetInt64(1) + } else { + cases[4*i].numerator.SetInt64(i - 4) + cases[4*i].denominator.SetInt64(2) + } + + cases[4*i+1].numerator.Neg(&cases[4*i].numerator) + cases[4*i+1].denominator.Neg(&cases[4*i].denominator) + + cases[4*i+2].numerator.Lsh(&cases[4*i].numerator, 1) + cases[4*i+2].denominator.Lsh(&cases[4*i].denominator, 1) + + cases[4*i+3].numerator.Neg(&cases[4*i+2].numerator) + cases[4*i+3].denominator.Neg(&cases[4*i+2].denominator) + } + + for i := range cases { + for j := range cases { + I, J := i/4, j/4 + var expectedCmp int + cmp := cases[i].Cmp(&cases[j]) + if I < J { + expectedCmp = -1 + } else if I == J { + expectedCmp = 0 + } else { + expectedCmp = 1 + } + assert.Equal(t, expectedCmp, cmp, "comparing index %d, index %d", i, j) + } + } + + zeroIndex := len(cases) / 8 + var weirdZero SmallRational + for i := range cases { + I := i / 4 + var expectedCmp int + cmp := cases[i].Cmp(&weirdZero) + cmpNeg := weirdZero.Cmp(&cases[i]) + if I < zeroIndex { + expectedCmp = -1 + } else if I == zeroIndex { + expectedCmp = 0 + } else { + expectedCmp = 1 + } + + assert.Equal(t, expectedCmp, cmp, "comparing index %d, 0/0", i) + assert.Equal(t, -expectedCmp, cmpNeg, "comparing 0/0, index %d", i) + } +} + +func TestDouble(t *testing.T) { + values := []interface{}{1, 2, 3, 4, 5, "2/3", "3/2", "-3/-2"} + valsDoubled := []interface{}{2, 4, 6, 8, 10, "-4/-3", 3, 3} + + for i := range values { + var v, vDoubled, vDoubledExpected SmallRational + _, err := v.SetInterface(values[i]) + assert.NoError(t, err) + _, err = vDoubledExpected.SetInterface(valsDoubled[i]) + assert.NoError(t, err) + vDoubled.Double(&v) + assert.True(t, vDoubled.Equal(&vDoubledExpected), + "mismatch at %d: expected 2×%s = %s, saw %s", i, v.text, vDoubledExpected.text, vDoubled.text) + + } +} + +func TestOperandConstancy(t *testing.T) { + var p0, p, pPure SmallRational + p0.SetInt64(1) + p.SetInt64(-3) + pPure.SetInt64(-3) + + res := p + res.Add(&res, &p0) + assert.True(t, p.Equal(&pPure)) +} + +func TestSquare(t *testing.T) { + var two, four, x SmallRational + two.SetInt64(2) + four.SetInt64(4) + + x.Square(&two) + + assert.True(t, x.Equal(&four), "expected 4, saw %s", x.Text(10)) +} + +func TestSetBytes(t *testing.T) { + var c SmallRational + c.SetBytes([]byte("firstChallenge.0")) + +} diff --git a/internal/small_rational/test_vector_utils/test_vector_utils.go b/internal/small_rational/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..9459281e09 --- /dev/null +++ b/internal/small_rational/test_vector_utils/test_vector_utils.go @@ -0,0 +1,186 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + + "hash" + "reflect" +) + +func ToElement(i int64) *small_rational.SmallRational { + var res small_rational.SmallRational + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res small_rational.SmallRational + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return small_rational.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return small_rational.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []small_rational.SmallRational + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return small_rational.Bytes +} + +func (h *ListHash) BlockSize() int { + return small_rational.Bytes +} + +func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *small_rational.SmallRational) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(small_rational.SmallRational) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/small_rational/vector.go b/internal/small_rational/vector.go new file mode 100644 index 0000000000..07fcc3afff --- /dev/null +++ b/internal/small_rational/vector.go @@ -0,0 +1,9 @@ +package small_rational + +type Vector []SmallRational + +func (v Vector) MustSetRandom() { + for i := range v { + v[i].MustSetRandom() + } +} diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go index b8c4d9951c..be2f48492b 100644 --- a/std/gkr/api_test.go +++ b/std/gkr/api_test.go @@ -8,6 +8,8 @@ import ( "testing" "time" + gcHash "github.com/consensys/gnark-crypto/hash" + bls12377 "github.com/consensys/gnark/constraint/bls12-377" bls12381 "github.com/consensys/gnark/constraint/bls12-381" bls24315 "github.com/consensys/gnark/constraint/bls24-315" @@ -21,12 +23,11 @@ import ( "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" - bn254MiMC "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" + gkr "github.com/consensys/gnark/internal/gkr/bn254" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" test_vector_utils "github.com/consensys/gnark/std/internal/test_vectors_utils" @@ -386,9 +387,7 @@ func (c *benchMiMCMerkleTreeCircuit) Define(api frontend.API) error { } func registerMiMC() { - bn254.RegisterHashBuilder("mimc", func() hash.Hash { - return bn254MiMC.NewMiMC() - }) + bn254.RegisterHashBuilder("mimc", gcHash.MIMC_BN254.New) stdHash.Register("mimc", func(api frontend.API) (stdHash.FieldHasher, error) { m, err := mimc.NewMiMC(api) return &m, err diff --git a/std/gkr/bn254_wrapper_api.go b/std/gkr/bn254_wrapper_api.go index 0538b1d3e2..a45109006f 100644 --- a/std/gkr/bn254_wrapper_api.go +++ b/std/gkr/bn254_wrapper_api.go @@ -4,8 +4,8 @@ import ( "errors" "fmt" "github.com/consensys/gnark-crypto/ecc/bn254/fr" - "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/gkr/bn254" "github.com/consensys/gnark/internal/utils" ) diff --git a/std/gkr/compile.go b/std/gkr/compile.go index 3d683fc9e3..4a86ac3784 100644 --- a/std/gkr/compile.go +++ b/std/gkr/compile.go @@ -2,6 +2,7 @@ package gkr import ( "errors" + "fmt" "math/bits" "github.com/consensys/gnark/constraint" @@ -108,9 +109,14 @@ func (api *API) Solve(parentApi frontend.API) (Solution, error) { for i := range circuit { v := &circuit[i] - if v.IsInput() { + in, out := v.IsInput(), v.IsOutput() + if in && out { + return Solution{}, fmt.Errorf("unused input (variable #%d)", i) + } + + if in { solveHintNIn += nbInstances - len(v.Dependencies) - } else if v.IsOutput() { + } else if out { solveHintNOut += nbInstances } } @@ -152,8 +158,8 @@ func (api *API) Solve(parentApi frontend.API) (Solution, error) { } // Export returns the values of an output variable across all instances -func (s Solution) Export(v frontend.Variable) []frontend.Variable { - return utils.Map(s.permutations.SortedInstances, utils.SliceAt(s.assignments[v.(constraint.GkrVariable)])) +func (s Solution) Export(v constraint.GkrVariable) []frontend.Variable { + return utils.Map(s.permutations.SortedInstances, utils.SliceAt(s.assignments[v])) } // Verify encodes the verification circuitry for the GKR circuit diff --git a/std/gkr/example_test.go b/std/gkr/example_test.go new file mode 100644 index 0000000000..d61adae39e --- /dev/null +++ b/std/gkr/example_test.go @@ -0,0 +1,276 @@ +package gkr_test + +import ( + "encoding/binary" + "errors" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + gcHash "github.com/consensys/gnark-crypto/hash" + bw6761 "github.com/consensys/gnark/constraint/bw6-761" + "github.com/consensys/gnark/frontend" + gkrBw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" + "github.com/consensys/gnark/std/gkr" + stdHash "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/consensys/gnark/test" +) + +func Example() { + // This example computes the double of multiple BLS12-377 G1 points, which can be computed natively over BW6-761. + // This means that the imported fr and fp packages are the same, being from BW6-761 and BLS12-377 respectively. TODO @Tabaie delete if no longer have fp imported + // It is based on the function DoubleAssign() of type G1Jac in gnark-crypto v0.17.0. + // github.com/consensys/gnark-crypto/ecc/bls12-377 + const ( + gateNamePrefix = "bls12-377-jac-double-" + fsHashName = "mimc" + ) + + // Every gate needs to be defined over a concrete field, used by the GKR prover, + // and over a frontend.API, used by the in-SNARK GKR verifier. + // + // Note that the SNARK prover will need both of these: + // The GKR prover will provide a proof as private input to the SNARK prover, + // wherein the embedded GKR verifier will verify it, establishing the correctness + // of our claimed values. + + // This function will contain the concrete implementations of the gates. + // The SNARK implementations will be defined in the Define() method of the circuit. + + // combine the operations that define the first value assigned to variable S + // input = [X, YY, XX, YYYY] + // S = 2 * [(X + YY)² - XX - YYYY] + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"s", func(input ...fr.Element) (S fr.Element) { + S. + Add(&input[0], &input[1]). // 409: S.Add(&p.X, &YY) + Square(&S). // 410: S.Square(&S). + Sub(&S, &input[2]). // 411: Sub(&S, &XX). + Sub(&S, &input[3]). // 412: Sub(&S, &YYYY). + Double(&S) // 413: Double(&S) + + return + }, 4)) + + // combine the operations that define the assignment to p.Z + // input = [p.Z, p.Y, YY, ZZ] + // p.Z = (p.Z + p.Y)² - YY - ZZ + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"z", func(input ...fr.Element) (Z fr.Element) { + Z.Add(&input[0], &input[1]) // 415: p.Z.Add(&p.Z, &p.Y). + Z.Square(&Z) // 416: p.Z.Square(&p.Z). + Z.Sub(&Z, &input[2]) // 417: Sub(&p.Z, &YY). + Z.Sub(&Z, &input[3]) // 418: Sub(&p.Z, &ZZ) + return + }, 4)) + + // combine the operations that define the assignment to p.X + // input = [XX, S] + // p.X = 9XX² - 2S + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"x", func(input ...fr.Element) (X fr.Element) { + var M, T fr.Element + M.Double(&input[0]).Add(&M, &input[0]) // 414: M.Double(&XX).Add(&M, &XX) + T.Square(&M) // 419: T.Square(&M) + X = T // 420: p.X = T + T.Double(&input[1]) // 421: T.Double(&S) + X.Sub(&X, &T) // 422: p.X.Sub(&p.X, &T) + return + }, 2)) + + // combine the operations that define the assignment to p.Y + // input = [S, p.X, XX, YYYY] + // p.Y = (S - p.X) * 3 * XX - 8 * YYYY + assertNoError(gkrBw6761.RegisterGate(gateNamePrefix+"y", func(input ...fr.Element) (Y fr.Element) { + Y.Double(&input[2]).Add(&Y, &input[2]) // 414: M.Double(&XX).Add(&M, &XX) + input[2] = Y + + Y.Sub(&input[0], &input[1]). // 423: p.Y.Sub(&S, &p.X). + Mul(&Y, &input[2]) // 424: Mul(&p.Y, &M). + input[3].Double(&input[3]).Double(&input[3]).Double(&input[3]) // 425: YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) + Y.Sub(&Y, &input[3]) // 426: p.Y.Sub(&p.Y, &YYYY) + + return + }, 4)) + + // we have a lot of squaring operations, which we'd rather look at as single-input + assertNoError(gkrBw6761.RegisterGate("square", func(input ...fr.Element) (res fr.Element) { + res.Square(&input[0]) + return + }, 1)) + + const nbInstances = 2 + // create instances + assignment := exampleCircuit{ + X: make([]frontend.Variable, nbInstances), + Y: make([]frontend.Variable, nbInstances), + Z: make([]frontend.Variable, nbInstances), + XOut: make([]frontend.Variable, nbInstances), + YOut: make([]frontend.Variable, nbInstances), + ZOut: make([]frontend.Variable, nbInstances), + } + + for i := range nbInstances { + // create a "random" point + var b [8]byte + binary.BigEndian.PutUint64(b[:], uint64(i)) + a, err := bls12377.HashToG1(b[:], nil) + assertNoError(err) + var p bls12377.G1Jac + p.FromAffine(&a) + + assignment.X[i] = p.X + assignment.Y[i] = p.Y + assignment.Z[i] = p.Z + + p.DoubleAssign() + assignment.XOut[i] = p.X + assignment.YOut[i] = p.Y + assignment.ZOut[i] = p.Z + } + + circuit := exampleCircuit{ + X: make([]frontend.Variable, nbInstances), + Y: make([]frontend.Variable, nbInstances), + Z: make([]frontend.Variable, nbInstances), + XOut: make([]frontend.Variable, nbInstances), + YOut: make([]frontend.Variable, nbInstances), + ZOut: make([]frontend.Variable, nbInstances), + gateNamePrefix: gateNamePrefix, + fsHashName: fsHashName, + } + + // register the hash function used for verifying the GKR proof (prover side) + bw6761.RegisterHashBuilder(fsHashName, gcHash.MIMC_BW6_761.New) + + assertNoError(test.IsSolved(&circuit, &assignment, ecc.BW6_761.ScalarField())) + + // Output: +} + +type exampleCircuit struct { + X, Y, Z []frontend.Variable // Jacobian coordinates for each point (input) + XOut, YOut, ZOut []frontend.Variable // Jacobian coordinates for the double of each point (expected output) + gateNamePrefix gkr.GateName + fsHashName string // name of the hash function used for Fiat-Shamir in the GKR verifier +} + +func (c *exampleCircuit) Define(api frontend.API) error { + if len(c.X) != len(c.Y) || len(c.X) != len(c.Z) || len(c.X) != len(c.XOut) || len(c.X) != len(c.YOut) || len(c.X) != len(c.ZOut) { + return errors.New("all inputs/outputs must have the same length (i.e. the number of instances)") + } + + gkrApi := gkr.NewApi() + + assertNoError(gkr.RegisterGate("square", func(api gkr.GateAPI, input ...frontend.Variable) (res frontend.Variable) { + return api.Mul(input[0], input[0]) + }, 1)) + + // define the GKR circuit + + // create GKR circuit variables based on the given assignments + X, err := gkrApi.Import(c.X) + if err != nil { + return err + } + + Y, err := gkrApi.Import(c.Y) + if err != nil { + return err + } + + Z, err := gkrApi.Import(c.Z) + if err != nil { + return err + } + + XX := gkrApi.NamedGate("square", X) // 405: XX.Square(&p.X) TODO See if anything changes (perf-wise) if we use gkrApi.Mul(X, X) instead + YY := gkrApi.NamedGate("square", Y) // 406: YY.Square(&p.Y) + YYYY := gkrApi.NamedGate("square", YY) // 407: YYYY.Square(&YY) + ZZ := gkrApi.NamedGate("square", Z) // 408: ZZ.Square(&p.Z) + + // define the SNARK version of the custom gates, similarly to the ones in Example + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"s", func(api gkr.GateAPI, input ...frontend.Variable) (S frontend.Variable) { + S = api.Add(input[0], input[1]) // 409: S.Add(&p.X, &YY) + S = api.Mul(S, S) // 410: S.Square(&S). + S = api.Sub(S, input[2], input[3]) // 411: Sub(&S, &XX). + // 412: Sub(&S, &YYYY). + return api.Add(S, S) // 413: Double(&S) + }, 4)) + S := gkrApi.NamedGate(c.gateNamePrefix+"s", X, YY, XX, YYYY) // 409 - 413 + + // 414: M.Double(&XX).Add(&M, &XX) + // Note (but don't explicitly compute) that M = 3XX + + // combine the operations that define the assignment to p.Z + // input = [p.Z, p.Y, YY, ZZ] + // Z = (p.Z + p.Y)² - YY - ZZ + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"z", func(api gkr.GateAPI, input ...frontend.Variable) (Z frontend.Variable) { + Z = api.Add(input[0], input[1]) // 415: p.Z.Add(&p.Z, &p.Y). + Z = api.Mul(Z, Z) // 416: p.Z.Square(&p.Z). + Z = api.Sub(Z, input[2], input[3]) // 417: Sub(&p.Z, &YY). + // 418: Sub(&p.Z, &ZZ). + return + }, 4)) + Z = gkrApi.NamedGate(c.gateNamePrefix+"z", Z, Y, YY, ZZ) // 415 - 418 + + // combine the operations that define the assignment to p.X + // input = [XX, S] + // p.X = 9XX² - 2S + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"x", func(api gkr.GateAPI, input ...frontend.Variable) (X frontend.Variable) { + M := api.Mul(input[0], 3) // 414: M.Double(&XX).Add(&M, &XX) + T := api.Mul(M, M) // 419: T.Square(&M) + X = api.Sub(T, api.Mul(input[1], 2)) // 420: p.X = T + // 421: T.Double(&S) + // 422: p.X.Sub(&p.X, &T) + return + }, 2)) + X = gkrApi.NamedGate(c.gateNamePrefix+"x", XX, S) // 419-422 + + // combine the operations that define the assignment to p.Y + // input = [S, p.X, XX, YYYY] + // p.Y = (S - p.X) * 3 * XX - 8 * YYYY + assertNoError(gkr.RegisterGate(c.gateNamePrefix+"y", func(api gkr.GateAPI, input ...frontend.Variable) (Y frontend.Variable) { + Y = api.Sub(input[0], input[1]) // 423: p.Y.Sub(&S, &p.X). + Y = api.Mul(Y, input[2], 3) // 414: M.Double(&XX).Add(&M, &XX) + // 424:Mul(&p.Y, &M) + Y = api.Sub(Y, api.Mul(input[3], 8)) // 425: YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) + // 426: p.Y.Sub(&p.Y, &YYYY) + + return + }, 4)) + Y = gkrApi.NamedGate(c.gateNamePrefix+"y", S, X, XX, YYYY) // 423 - 426 + + // have to duplicate X for it to be considered an output variable + // TODO remove once https://github.com/Consensys/gnark/issues/1452 is addressed + X = gkrApi.NamedGate("identity", X) + + // register the hash function used for verification (fiat shamir) + stdHash.Register(c.fsHashName, func(api frontend.API) (stdHash.FieldHasher, error) { + m, err := mimc.NewMiMC(api) + return &m, err + }) + + // solve and prove the circuit + solution, err := gkrApi.Solve(api) + if err != nil { + return err + } + + // check the output + + XOut := solution.Export(X) + YOut := solution.Export(Y) + ZOut := solution.Export(Z) + for i := range XOut { + api.AssertIsEqual(XOut[i], c.XOut[i]) + api.AssertIsEqual(YOut[i], c.YOut[i]) + api.AssertIsEqual(ZOut[i], c.ZOut[i]) + } + + // verify the proof + return solution.Verify(c.fsHashName) +} + +func assertNoError(err error) { + if err != nil { + panic(err) + } +} diff --git a/std/gkr/gkr.go b/std/gkr/gkr.go index 5a47a2d9c2..fa6c6020c6 100644 --- a/std/gkr/gkr.go +++ b/std/gkr/gkr.go @@ -3,11 +3,12 @@ package gkr import ( "errors" "fmt" + "strconv" + "github.com/consensys/gnark/frontend" fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/polynomial" "github.com/consensys/gnark/std/sumcheck" - "strconv" ) // @tabaie TODO: Contains many things copy-pasted from gnark-crypto. Generify somehow? @@ -54,11 +55,11 @@ type GateFunction func(GateAPI, ...frontend.Variable) frontend.Variable type Gate struct { Evaluate GateFunction // Evaluate the polynomial function defining the gate nbIn int // number of inputs - degree int // total degree of f + degree int // total degree of the polynomial solvableVar int // if there is a variable whose value can be uniquely determined from the value of the gate and the other inputs, its index, -1 otherwise } -// Degree returns the total degree of the gate's polynomial i.e. Degree(xy²) = 3 +// Degree returns the total degree of the gate's polynomial e.g. Degree(xy²) = 3 func (g *Gate) Degree() int { return g.degree } @@ -108,11 +109,45 @@ func (w Wire) noProof() bool { return w.IsInput() && w.NbClaims() == 1 } +// unhashedFinalEvalProofElemIndex returns the index of a +// value in the final evaluation proof whose hashing can +// safely be skipped, due to its solvability. +// If no such value exists, it returns -1. +func (w Wire) unhashedFinalEvalProofElemIndex() int { + if w.Gate.SolvableVar() == -1 { + return -1 + } + indexInProof := 0 + visited := make(map[*Wire]struct{}, len(w.Inputs)) + for i := range w.Inputs { // it is possible in case of repeated values that this optimization + // goes to waste: for example if g := x^2 + y + z, given the input (w', w', w"). + // only y is recorded as a solvable variable, but it is already excluded from hashing because + // it is getting a repeated input. + // If we had recorded ALL solvable vars, we could have also skipped the hashing of z. + // But it is rather strange for a user to define a circuit that way. + + if _, ok := visited[w.Inputs[i]]; ok { + continue + } + + if i == w.Gate.SolvableVar() { + return indexInProof + } + + visited[w.Inputs[i]] = struct{}{} + indexInProof++ + } + return -1 +} + // WireAssignment is assignment of values to the same wire across many instances of the circuit type WireAssignment map[*Wire]polynomial.MultiLin type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) +// eqTimesGateEvalSumcheckLazyClaims is a lazy claim for sumcheck (verifier side). +// eqTimesGateEval is a polynomial consisting of ∑ᵢ cⁱ eq(-, xᵢ) w(-). +// Its purpose is to batch the checking of multiple evaluations of the same wire. type eqTimesGateEvalSumcheckLazyClaims struct { wire *Wire evaluationPoints [][]frontend.Variable @@ -120,10 +155,20 @@ type eqTimesGateEvalSumcheckLazyClaims struct { manager *claimsManager // WARNING: Circular references } -func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(api frontend.API, r []frontend.Variable, combinationCoeff, purportedValue frontend.Variable, proof interface{}) error { - inputEvaluationsNoRedundancy := proof.([]frontend.Variable) - - // the eq terms +// VerifyFinalEval finalizes the verification of w. +// The prover's claims w(xᵢ) = yᵢ have already been reduced to verifying +// ∑ cⁱ eq(xᵢ, r) w(r) = purportedValue. ( c is combinationCoeff ) +// Both purportedValue and the vector r have been randomized during the sumcheck protocol. +// By taking the w term out of the sum we get the equivalent claim that +// for E := ∑ eq(xᵢ, r), it must be that E w(r) = purportedValue. +// If w is an input wire, the verifier can directly check its evaluation at r. +// Otherwise, the prover makes claims about the evaluation of w's input wires, +// wᵢ, at r, to be verified later. +// The claims are communicated through the proof parameter. +// The verifier checks here if the claimed evaluations of wᵢ(r) are consistent with +// the main claim, by checking E w(wᵢ(r)...) = purportedValue. +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(api frontend.API, r []frontend.Variable, combinationCoeff, purportedValue frontend.Variable, inputEvaluationsNoRedundancy []frontend.Variable) error { + // the eq terms ( E ) numClaims := len(e.evaluationPoints) evaluation := polynomial.EvalEq(api, e.evaluationPoints[numClaims-1], r) for i := numClaims - 2; i >= 0; i-- { @@ -343,6 +388,19 @@ func getChallenges(transcript *fiatshamir.Transcript, names []string) (challenge return } +// getBaseChallenge returns parts of the prover's final evaluation claims +// that need to be incorporated in the Fiat-Shamir transcript. +func getBaseChallenge(wire *Wire, finalEvalProof []frontend.Variable) []frontend.Variable { + baseChallenge := make([]frontend.Variable, 0, len(finalEvalProof)) + skipHashingOf := wire.unhashedFinalEvalProofElemIndex() + for j := range finalEvalProof { + if j != skipHashingOf { + baseChallenge = append(baseChallenge, finalEvalProof[j]) + } + } + return baseChallenge +} + // Verify the consistency of the claimed output with the claimed input // Unlike in Prove, the assignment argument need not be complete func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { @@ -369,11 +427,10 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, } proofW := proof[i] - finalEvalProof := proofW.FinalEvalProof.([]frontend.Variable) claim := claims.getLazyClaim(wire) if wire.noProof() { // input wires with one claim only // make sure the proof is empty - if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + if len(proofW.FinalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { return errors.New("no proof allowed for input wire with a single claim") } @@ -385,7 +442,7 @@ func Verify(api frontend.API, c Circuit, assignment WireAssignment, proof Proof, } else if err = sumcheck.Verify( api, claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), ); err == nil { - baseChallenge = finalEvalProof + baseChallenge = getBaseChallenge(wire, proofW.FinalEvalProof) } else { return err } @@ -509,7 +566,7 @@ func (p Proof) Serialize() []frontend.Variable { for j := range p[i].PartialSumPolys { size += len(p[i].PartialSumPolys[j]) } - size += len(p[i].FinalEvalProof.([]frontend.Variable)) + size += len(p[i].FinalEvalProof) } res := make([]frontend.Variable, 0, size) @@ -517,7 +574,7 @@ func (p Proof) Serialize() []frontend.Variable { for j := range p[i].PartialSumPolys { res = append(res, p[i].PartialSumPolys[j]...) } - res = append(res, p[i].FinalEvalProof.([]frontend.Variable)...) + res = append(res, p[i].FinalEvalProof...) } if len(res) != size { panic("bug") // TODO: Remove diff --git a/std/gkr/hints.go b/std/gkr/hints.go index 8cfdaa6421..63b57b3402 100644 --- a/std/gkr/hints.go +++ b/std/gkr/hints.go @@ -2,6 +2,7 @@ package gkr import ( "errors" + "fmt" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/constraint" bls12377 "github.com/consensys/gnark/constraint/bls12-377" @@ -12,6 +13,7 @@ import ( bw6633 "github.com/consensys/gnark/constraint/bw6-633" bw6761 "github.com/consensys/gnark/constraint/bw6-761" "github.com/consensys/gnark/constraint/solver" + "hash" "math/big" ) @@ -100,3 +102,52 @@ func ProveHintPlaceholder(hashName string) solver.Hint { return errors.New("unsupported modulus") } } + +func CheckHashHint(hashName string) solver.Hint { + return func(mod *big.Int, ins, outs []*big.Int) error { + if len(ins) != 2 || len(outs) != 1 { + return errors.New("invalid number of inputs/outputs") + } + + var ( + builder func() hash.Hash + err error + ) + if mod.Cmp(ecc.BLS12_377.ScalarField()) == 0 { + builder, err = bls12377.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BLS12_381.ScalarField()) == 0 { + builder, err = bls12381.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BLS24_315.ScalarField()) == 0 { + builder, err = bls24315.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BLS24_317.ScalarField()) == 0 { + builder, err = bls24317.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BN254.ScalarField()) == 0 { + builder, err = bn254.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BW6_633.ScalarField()) == 0 { + builder, err = bw6633.GetHashBuilder(hashName) + } else if mod.Cmp(ecc.BW6_761.ScalarField()) == 0 { + builder, err = bw6761.GetHashBuilder(hashName) + } else { + return errors.New("unsupported modulus") + } + + if err != nil { + return err + } + + toHash := ins[0].Bytes() + expectedHash := ins[1] + + hsh := builder() + hsh.Write(toHash) + hashed := hsh.Sum(nil) + + if hashed := new(big.Int).SetBytes(hashed); hashed.Cmp(expectedHash) != 0 { + return fmt.Errorf("hash mismatch: expected %s, got %s", expectedHash.String(), hashed.String()) + } + + outs[0].SetBytes(hashed) + + return nil + } +} diff --git a/std/gkr/registry.go b/std/gkr/registry.go index 55dfd82f90..0ec8f4c2c0 100644 --- a/std/gkr/registry.go +++ b/std/gkr/registry.go @@ -2,8 +2,9 @@ package gkr import ( "fmt" - "github.com/consensys/gnark/frontend" "sync" + + "github.com/consensys/gnark/frontend" ) type GateName string diff --git a/std/gkr/testing.go b/std/gkr/testing.go index 1c43a65e40..dd99608b1f 100644 --- a/std/gkr/testing.go +++ b/std/gkr/testing.go @@ -3,32 +3,72 @@ package gkr import ( "errors" "fmt" + stdHash "github.com/consensys/gnark/std/hash" "math/big" "sync" "github.com/consensys/gnark-crypto/ecc" frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - gkrBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/gkr" frBls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" - gkrBls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/gkr" frBls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" - gkrBls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/gkr" frBls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" - gkrBls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/gkr" frBn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr" - gkrBn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" frBw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" - gkrBw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/gkr" frBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" - gkrBw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/gkr" hint "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" + gkrBls12377 "github.com/consensys/gnark/internal/gkr/bls12-377" + gkrBls12381 "github.com/consensys/gnark/internal/gkr/bls12-381" + gkrBls24315 "github.com/consensys/gnark/internal/gkr/bls24-315" + gkrBls24317 "github.com/consensys/gnark/internal/gkr/bls24-317" + gkrBn254 "github.com/consensys/gnark/internal/gkr/bn254" + gkrBw6633 "github.com/consensys/gnark/internal/gkr/bw6-633" + gkrBw6761 "github.com/consensys/gnark/internal/gkr/bw6-761" ) +type solveInTestEngineSettings struct { + hashName string +} + +type SolveInTestEngineOption func(*solveInTestEngineSettings) + +func WithHashName(name string) SolveInTestEngineOption { + return func(s *solveInTestEngineSettings) { + s.hashName = name + } +} + // SolveInTestEngine solves the defined circuit directly inside the SNARK circuit. This means that the method does not compute the GKR proof of the circuit and does not embed the GKR proof verifier inside a SNARK. // The output is the values of all variables, across all instances; i.e. indexed variable-first, instance-second. // This method only works under the test engine and should only be called to debug a GKR circuit, as the GKR prover's errors can be obscure. -func (api *API) SolveInTestEngine(parentApi frontend.API) [][]frontend.Variable { +func (api *API) SolveInTestEngine(parentApi frontend.API, options ...SolveInTestEngineOption) [][]frontend.Variable { + var s solveInTestEngineSettings + for _, o := range options { + o(&s) + } + if s.hashName != "" { + // hash something and make sure it gives the same answer both on prover and verifier sides + // TODO @Tabaie If indeed cheap, move this feature to Verify so that it is always run + h, err := stdHash.GetFieldHasher(s.hashName, parentApi) + if err != nil { + panic(err) + } + nbBytes := (parentApi.Compiler().FieldBitLen() + 7) / 8 + toHash := frontend.Variable(0) + for i := range nbBytes { + toHash = parentApi.Add(parentApi.Mul(toHash, 256), i%256) + } + h.Reset() + h.Write(toHash) + hashed := h.Sum() + + hintOut, err := parentApi.Compiler().NewHint(CheckHashHint(s.hashName), 1, toHash, hashed) + if err != nil { + panic(err) + } + parentApi.AssertIsEqual(hintOut[0], hashed) // the hint already checks this + } + res := make([][]frontend.Variable, len(api.toStore.Circuit)) var degreeTestedGates sync.Map for i, w := range api.toStore.Circuit { diff --git a/std/permutation/poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr.go similarity index 95% rename from std/permutation/poseidon2/gkr.go rename to std/permutation/poseidon2/gkr-poseidon2/gkr.go index 98c07b6f8e..ea135fc2c7 100644 --- a/std/permutation/poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr.go @@ -1,24 +1,26 @@ -package poseidon2 +package gkr_poseidon2 import ( "errors" "fmt" - "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2/internal" "hash" "math/big" "sync" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark-crypto/ecc" frBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" mimcBls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" poseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" - gkrPoseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2/gkrgates" "github.com/consensys/gnark/constraint" csBls12377 "github.com/consensys/gnark/constraint/bls12-377" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/gkr" stdHash "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/hash/mimc" + gkrPoseidon2Bls12377 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2/internal/bls12-377" ) // extKeyGate applies the external matrix mul, then adds the round key @@ -156,7 +158,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // poseidon2 parameters roundKeysFr := poseidon2Bls12377.GetDefaultParameters().RoundKeys - gateNamer := gkrPoseidon2Bls12377.RoundGateNamer(poseidon2Bls12377.GetDefaultParameters()) + gateNamer := internal.RoundGateNamer[gkr.GateName](poseidon2Bls12377.GetDefaultParameters()) rF := poseidon2Bls12377.GetDefaultParameters().NbFullRounds rP := poseidon2Bls12377.GetDefaultParameters().NbPartialRounds halfRf := rF / 2 @@ -198,7 +200,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // register and apply external matrix multiplication and round key addition // round dependent due to the round key extKeySBox := func(round, varI int, a, b constraint.GkrVariable) constraint.GkrVariable { - gate := gkr.GateName(gateNamer.Linear(varI, round)) + gate := gateNamer.Linear(varI, round) if err = gkr.RegisterGate(gate, extKeyGate(frToInt(&roundKeysFr[round][varI])), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return -1 } @@ -210,7 +212,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // for the second variable // round independent due to the round key intKeySBox2 := func(round int, a, b constraint.GkrVariable) constraint.GkrVariable { - gate := gkr.GateName(gateNamer.Linear(yI, round)) + gate := gateNamer.Linear(yI, round) if err = gkr.RegisterGate(gate, intKeyGate2(frToInt(&roundKeysFr[round][1])), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return -1 } @@ -234,7 +236,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. // still using the external matrix, since the linear operation still belongs to a full (canonical) round x1 := extKeySBox(halfRf, xI, x, y) - gate := gkr.GateName(gateNamer.Linear(yI, halfRf)) + gate := gateNamer.Linear(yI, halfRf) if err = gkr.RegisterGate(gate, extGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return nil, -1, err } @@ -245,7 +247,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. for i := halfRf + 1; i < halfRf+rP; i++ { x1 := extKeySBox(i, xI, x, y) // the first row of the internal matrix is the same as that of the external matrix - gate := gkr.GateName(gateNamer.Linear(yI, i)) + gate := gateNamer.Linear(yI, i) if err = gkr.RegisterGate(gate, intKeyGate2(zero), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return nil, -1, err } @@ -265,7 +267,7 @@ func defineCircuit(insLeft, insRight []frontend.Variable) (*gkr.API, constraint. } // apply the external matrix one last time to obtain the final value of y - gate := gkr.GateName(gateNamer.Linear(yI, rP+rF)) + gate := gateNamer.Linear(yI, rP+rF) if err = gkr.RegisterGate(gate, extAddGate, 3, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { return nil, -1, err } diff --git a/std/permutation/poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go similarity index 98% rename from std/permutation/poseidon2/gkr_test.go rename to std/permutation/poseidon2/gkr-poseidon2/gkr_test.go index 96a9c2494c..3fd5c42a33 100644 --- a/std/permutation/poseidon2/gkr_test.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go @@ -1,4 +1,4 @@ -package poseidon2 +package gkr_poseidon2 import ( "fmt" diff --git a/std/permutation/poseidon2/gkr-poseidon2/internal/bls12-377/gates.go b/std/permutation/poseidon2/gkr-poseidon2/internal/bls12-377/gates.go new file mode 100644 index 0000000000..67d45d704f --- /dev/null +++ b/std/permutation/poseidon2/gkr-poseidon2/internal/bls12-377/gates.go @@ -0,0 +1,218 @@ +package bls12_377 + +import ( + "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2/internal" + "sync" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + gkr "github.com/consensys/gnark/internal/gkr/bls12-377" +) + +// The GKR gates needed for proving Poseidon2 permutations + +// extKeySBoxGate applies the external matrix mul, then adds the round key, then applies the sBox +// because of its symmetry, we don't need to define distinct x1 and x2 versions of it +func extKeySBoxGate(roundKey *fr.Element) gkr.GateFunction { + return func(x ...fr.Element) fr.Element { + x[0]. + Double(&x[0]). + Add(&x[0], &x[1]). + Add(&x[0], roundKey) + return sBox2(x[0]) + } +} + +// intKeySBoxGate2 applies the second row of internal matrix mul, then adds the round key, then applies the sBox, returning the second element +func intKeySBoxGate2(roundKey *fr.Element) gkr.GateFunction { + return func(x ...fr.Element) fr.Element { + x[0].Add(&x[0], &x[1]) + x[1]. + Double(&x[1]). + Add(&x[1], &x[0]). + Add(&x[1], roundKey) + + return sBox2(x[1]) + } +} + +// extAddGate (x,y,z) -> Ext . (x,y) + z +func extAddGate(x ...fr.Element) fr.Element { + x[0]. + Double(&x[0]). + Add(&x[0], &x[1]). + Add(&x[0], &x[2]) + return x[0] +} + +// sBox2 is Permutation.sBox for t=2 +func sBox2(x fr.Element) fr.Element { + var y fr.Element + y.Square(&x).Square(&y).Square(&y).Square(&y).Mul(&x, &y) + return y +} + +// extKeyGate applies the external matrix mul, then adds the round key, then applies the sBox +// because of its symmetry, we don't need to define distinct x1 and x2 versions of it +func extKeyGate(roundKey *fr.Element) func(...fr.Element) fr.Element { + return func(x ...fr.Element) fr.Element { + x[0]. + Double(&x[0]). + Add(&x[0], &x[1]). + Add(&x[0], roundKey) + return x[0] + } +} + +// for x1, the partial round gates are identical to full round gates +// for x2, the partial round gates are just a linear combination + +// extGate2 applies the external matrix mul, outputting the second element of the result +func extGate2(x ...fr.Element) fr.Element { + x[1]. + Double(&x[1]). + Add(&x[1], &x[0]) + return x[1] +} + +// intGate2 applies the internal matrix mul, returning the second element +func intGate2(x ...fr.Element) fr.Element { + x[0].Add(&x[0], &x[1]) + x[1]. + Double(&x[1]). + Add(&x[1], &x[0]) + return x[1] +} + +// intKeyGate2 applies the second row of internal matrix mul, then adds the round key +func intKeyGate2(roundKey *fr.Element) gkr.GateFunction { + return func(x ...fr.Element) fr.Element { + x[0].Add(&x[0], &x[1]) + x[1]. + Double(&x[1]). + Add(&x[1], &x[0]). + Add(&x[1], roundKey) + + return x[1] + } +} + +// powGate4 x -> x⁴ +func pow4Gate(x ...fr.Element) fr.Element { + x[0].Square(&x[0]).Square(&x[0]) + return x[0] +} + +// pow4TimesGate x,y -> x⁴ * y +func pow4TimesGate(x ...fr.Element) fr.Element { + x[0].Square(&x[0]).Square(&x[0]).Mul(&x[0], &x[1]) + return x[0] +} + +// pow2Gate x -> x² +func pow2Gate(x ...fr.Element) fr.Element { + x[0].Square(&x[0]) + return x[0] +} + +// pow2TimesGate x,y -> x² * y +func pow2TimesGate(x ...fr.Element) fr.Element { + x[0].Square(&x[0]).Mul(&x[0], &x[1]) + return x[0] +} + +var initOnce sync.Once + +// RegisterGkrGates registers the Poseidon2 compression gates for GKR +func RegisterGkrGates() error { + const ( + x = iota + y + ) + var err error + initOnce.Do( + func() { + p := poseidon2.GetDefaultParameters() + halfRf := p.NbFullRounds / 2 + gateNames := internal.RoundGateNamer[gkr.GateName](p) + + if err = gkr.RegisterGate(internal.Pow2GateName, pow2Gate, 1, gkr.WithUnverifiedDegree(2), gkr.WithNoSolvableVar()); err != nil { + return + } + if err = gkr.RegisterGate(internal.Pow4GateName, pow4Gate, 1, gkr.WithUnverifiedDegree(4), gkr.WithNoSolvableVar()); err != nil { + return + } + if err = gkr.RegisterGate(internal.Pow2TimesGateName, pow2TimesGate, 2, gkr.WithUnverifiedDegree(3), gkr.WithNoSolvableVar()); err != nil { + return + } + if err = gkr.RegisterGate(internal.Pow4TimesGateName, pow4TimesGate, 2, gkr.WithUnverifiedDegree(5), gkr.WithNoSolvableVar()); err != nil { + return + } + + extKeySBox := func(round int, varIndex int) error { + if err := gkr.RegisterGate(gateNames.Integrated(varIndex, round), extKeySBoxGate(&p.RoundKeys[round][varIndex]), 2, gkr.WithUnverifiedDegree(poseidon2.DegreeSBox()), gkr.WithNoSolvableVar()); err != nil { + return err + } + + return gkr.RegisterGate(gateNames.Linear(varIndex, round), extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)) + } + + intKeySBox2 := func(round int) error { + if err := gkr.RegisterGate(gateNames.Linear(y, round), intKeyGate2(&p.RoundKeys[round][1]), 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return err + } + return gkr.RegisterGate(gateNames.Integrated(y, round), intKeySBoxGate2(&p.RoundKeys[round][1]), 2, gkr.WithUnverifiedDegree(poseidon2.DegreeSBox()), gkr.WithNoSolvableVar()) + } + + fullRound := func(i int) error { + if err := extKeySBox(i, x); err != nil { + return err + } + return extKeySBox(i, y) + } + + for i := range halfRf { + if err = fullRound(i); err != nil { + return + } + } + + { // i = halfRf: first partial round + if err = extKeySBox(halfRf, x); err != nil { + return + } + if err = gkr.RegisterGate(gateNames.Linear(y, halfRf), extGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return + } + } + + for i := halfRf + 1; i < halfRf+p.NbPartialRounds; i++ { + if err = extKeySBox(i, x); err != nil { // for x1, intKeySBox is identical to extKeySBox + return + } + if err = gkr.RegisterGate(gateNames.Linear(y, i), intGate2, 2, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)); err != nil { + return + } + } + + { + i := halfRf + p.NbPartialRounds + if err = extKeySBox(i, x); err != nil { + return + } + if err = intKeySBox2(i); err != nil { + return + } + } + + for i := halfRf + p.NbPartialRounds + 1; i < p.NbPartialRounds+p.NbFullRounds; i++ { + if err = fullRound(i); err != nil { + return + } + } + + err = gkr.RegisterGate(gateNames.Linear(y, p.NbPartialRounds+p.NbFullRounds), extAddGate, 3, gkr.WithUnverifiedDegree(1), gkr.WithUnverifiedSolvableVar(0)) + }, + ) + return err +} diff --git a/std/permutation/poseidon2/gkr-poseidon2/internal/commons.go b/std/permutation/poseidon2/gkr-poseidon2/internal/commons.go new file mode 100644 index 0000000000..34ad6316e6 --- /dev/null +++ b/std/permutation/poseidon2/gkr-poseidon2/internal/commons.go @@ -0,0 +1,29 @@ +package internal + +import ( + "fmt" +) + +const ( + Pow2GateName = "pow2" + Pow4GateName = "pow4" + Pow2TimesGateName = "pow2Times" + Pow4TimesGateName = "pow4Times" +) + +type roundGateNamer[T ~string] string + +// RoundGateNamer returns an object that returns standardized names for gates in the GKR circuit +func RoundGateNamer[T ~string](p fmt.Stringer) roundGateNamer[T] { + return roundGateNamer[T](p.String()) +} + +// Linear is the name of a gate where a polynomial of total degree 1 is applied to the input +func (n roundGateNamer[T]) Linear(varIndex, round int) T { + return T(fmt.Sprintf("x%d-l-op-round=%d;%s", varIndex, round, n)) +} + +// Integrated is the name of a gate where a polynomial of total degree 1 is applied to the input, followed by an S-box +func (n roundGateNamer[T]) Integrated(varIndex, round int) T { + return T(fmt.Sprintf("x%d-i-op-round=%d;%s", varIndex, round, n)) +} diff --git a/std/sumcheck/sumcheck.go b/std/sumcheck/sumcheck.go index ad96621c96..ec3f724130 100644 --- a/std/sumcheck/sumcheck.go +++ b/std/sumcheck/sumcheck.go @@ -15,13 +15,13 @@ type LazyClaims interface { VarsNum() int // VarsNum = n CombinedSum(api frontend.API, a frontend.Variable) frontend.Variable // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ Degree(i int) int //Degree of the total claim in the i'th variable - VerifyFinalEval(api frontend.API, r []frontend.Variable, combinationCoeff, purportedValue frontend.Variable, proof interface{}) error + VerifyFinalEval(api frontend.API, r []frontend.Variable, combinationCoeff, purportedValue frontend.Variable, proof []frontend.Variable) error } // Proof of a multi-sumcheck statement. type Proof struct { PartialSumPolys []polynomial.Polynomial - FinalEvalProof interface{} + FinalEvalProof []frontend.Variable } func setupTranscript(api frontend.API, claimsNum int, varsNum int, settings *fiatshamir.Settings) ([]string, error) { diff --git a/test_vector_utils/test_vector_utils.go b/test_vector_utils/test_vector_utils.go new file mode 100644 index 0000000000..3102a2133d --- /dev/null +++ b/test_vector_utils/test_vector_utils.go @@ -0,0 +1,185 @@ +// Copyright 2020-2025 Consensys Software Inc. +// Licensed under the Apache License, Version 2.0. See the LICENSE file for details. + +// Code generated by gnark DO NOT EDIT + +package test_vector_utils + +import ( + "fmt" + "github.com/consensys/gnark/internal/small_rational" + "github.com/consensys/gnark/internal/small_rational/polynomial" + "hash" + "reflect" +) + +func ToElement(i int64) *small_rational.SmallRational { + var res small_rational.SmallRational + res.SetInt64(i) + return &res +} + +type HashDescription map[string]interface{} + +func HashFromDescription(d HashDescription) (hash.Hash, error) { + if _type, ok := d["type"]; ok { + switch _type { + case "const": + startState := int64(d["val"].(float64)) + return &MessageCounter{startState: startState, step: 0, state: startState}, nil + default: + return nil, fmt.Errorf("unknown fake hash type \"%s\"", _type) + } + } + return nil, fmt.Errorf("hash description missing type") +} + +type MessageCounter struct { + startState int64 + state int64 + step int64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 + m.state += int64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 + resI := m.state + int64(inputBlockSize)*m.step + var res small_rational.SmallRational + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return small_rational.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return small_rational.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: int64(startState), state: int64(startState), step: int64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []small_rational.SmallRational + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return small_rational.Bytes +} + +func (h *ListHash) BlockSize() int { + return small_rational.Bytes +} + +func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *small_rational.SmallRational) interface{} { + if i := x.BigInt(nil); i != nil { + return i + } + return x.Text(10) +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(small_rational.SmallRational) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +}