Skip to content

Commit

Permalink
vochain fork: refactor ZkCircuits handling
Browse files Browse the repository at this point in the history
* circuit package now has a Global() circuit, and all packages (including TransactionHandler) use that
* circuit/config.go: rename `dev` -> `v0.0.1` and add voceremony as `v1.0.0`
* circuit/config.go: rename `tag` concept into `version`
* circuit/config.go: now ZkCircuitConfig has Version field, drop app.circuitConfigTag
* circuit/config.go: now PublicSignals is a property of each circuit (previously hardcoded in prover)
* prover: use PubSignals from circuit.Global() instead of hardcoded indexes
* api: /chain/info now returns circuitVersion (instead of misspelt cicuitConfigurationTag)
* apiclient: small fix, LoadZkCircuit once on NewHTTPclient instead of every Vote
* testsuite: mount zkCircuits cache dir in test container as well
* vochain/app.go: SetZkCircuit during beginBlock
* add config/forks.go
* circuit: add DownloadArtifacts funcs
  * DownloadArtifactsForChainID
  * DownloadDefaultArtifacts
* NewBaseApplication now calls circuit.DownloadDefaultArtifacts instead of transactionHandler.LoadZkCircuit
* newTendermint now calls circuit.DownloadArtifactsForChainID
  • Loading branch information
altergui committed Dec 5, 2023
1 parent e90beb4 commit 7e686dd
Show file tree
Hide file tree
Showing 20 changed files with 290 additions and 191 deletions.
28 changes: 14 additions & 14 deletions api/api_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,20 +200,20 @@ type GenericTransactionWithInfo struct {
}

type ChainInfo struct {
ID string `json:"chainId" example:"azeno"`
BlockTime [5]uint64 `json:"blockTime" example:"12000,11580,11000,11100,11100"`
ElectionCount uint64 `json:"electionCount" example:"120"`
OrganizationCount uint64 `json:"organizationCount" example:"20"`
GenesisTime time.Time `json:"genesisTime" format:"date-time" example:"2022-11-17T18:00:57.379551614Z"`
Height uint32 `json:"height" example:"5467"`
Syncing bool `json:"syncing" example:"true"`
Timestamp int64 `json:"blockTimestamp" swaggertype:"string" format:"date-time" example:"2022-11-17T18:00:57.379551614Z"`
TransactionCount uint64 `json:"transactionCount" example:"554"`
ValidatorCount uint32 `json:"validatorCount" example:"5"`
VoteCount uint64 `json:"voteCount" example:"432"`
CircuitConfigurationTag string `json:"cicuitConfigurationTag" example:"dev"`
MaxCensusSize uint64 `json:"maxCensusSize" example:"50000"`
NetworkCapacity uint64 `json:"networkCapacity" example:"2000"`
ID string `json:"chainId" example:"azeno"`
BlockTime [5]uint64 `json:"blockTime" example:"12000,11580,11000,11100,11100"`
ElectionCount uint64 `json:"electionCount" example:"120"`
OrganizationCount uint64 `json:"organizationCount" example:"20"`
GenesisTime time.Time `json:"genesisTime" format:"date-time" example:"2022-11-17T18:00:57.379551614Z"`
Height uint32 `json:"height" example:"5467"`
Syncing bool `json:"syncing" example:"true"`
Timestamp int64 `json:"blockTimestamp" swaggertype:"string" format:"date-time" example:"2022-11-17T18:00:57.379551614Z"`
TransactionCount uint64 `json:"transactionCount" example:"554"`
ValidatorCount uint32 `json:"validatorCount" example:"5"`
VoteCount uint64 `json:"voteCount" example:"432"`
CircuitVersion string `json:"circuitVersion" example:"v1.0.0"`
MaxCensusSize uint64 `json:"maxCensusSize" example:"50000"`
NetworkCapacity uint64 `json:"networkCapacity" example:"2000"`
}

type Account struct {
Expand Down
9 changes: 2 additions & 7 deletions api/censuses.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,9 @@ func (a *API) censusCreateHandler(msg *apirest.APIdata, ctx *httprouter.HTTPCont
return ErrCensusTypeUnknown
}

// get census max levels from vochain app if available
maxLevels := circuit.CircuitsConfigurations[circuit.DefaultCircuitConfigurationTag].Levels
if a.vocapp != nil {
maxLevels = a.vocapp.TransactionHandler.ZkCircuit.Config.Levels
}

// census max levels is limited by global ZkCircuit Levels
censusID := util.RandomBytes(32)
_, err = a.censusdb.New(censusID, censusType, "", &token, maxLevels)
_, err = a.censusdb.New(censusID, censusType, "", &token, circuit.Global().Config.Levels)
if err != nil {
return err
}
Expand Down
36 changes: 17 additions & 19 deletions api/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,20 +301,20 @@ func (a *API) chainInfoHandler(_ *apirest.APIdata, ctx *httprouter.HTTPContext)
}

data, err := json.Marshal(&ChainInfo{
ID: a.vocapp.ChainID(),
BlockTime: a.vocinfo.BlockTimes(),
ElectionCount: a.indexer.CountTotalProcesses(),
OrganizationCount: a.indexer.CountTotalEntities(),
Height: a.vocapp.Height(),
Syncing: a.vocapp.IsSynchronizing(),
TransactionCount: transactionCount,
ValidatorCount: uint32(len(validators)),
Timestamp: a.vocapp.Timestamp(),
VoteCount: voteCount,
GenesisTime: a.vocapp.Genesis().GenesisTime,
CircuitConfigurationTag: a.vocapp.CircuitConfigurationTag(),
MaxCensusSize: maxCensusSize,
NetworkCapacity: networkCapacity,
ID: a.vocapp.ChainID(),
BlockTime: a.vocinfo.BlockTimes(),
ElectionCount: a.indexer.CountTotalProcesses(),
OrganizationCount: a.indexer.CountTotalEntities(),
Height: a.vocapp.Height(),
Syncing: a.vocapp.IsSynchronizing(),
TransactionCount: transactionCount,
ValidatorCount: uint32(len(validators)),
Timestamp: a.vocapp.Timestamp(),
VoteCount: voteCount,
GenesisTime: a.vocapp.Genesis().GenesisTime,
CircuitVersion: circuit.Version(),
MaxCensusSize: maxCensusSize,
NetworkCapacity: networkCapacity,
})
if err != nil {
return err
Expand All @@ -329,13 +329,11 @@ func (a *API) chainInfoHandler(_ *apirest.APIdata, ctx *httprouter.HTTPContext)
// @Tags Chain
// @Accept json
// @Produce json
// @Success 200 {object} circuit.ZkCircuitConfig
// @Success 200 {object} circuit.Config
// @Router /chain/info/circuit [get]
func (a *API) chainCircuitInfoHandler(_ *apirest.APIdata, ctx *httprouter.HTTPContext) error {
// Get current circuit tag
circuitConfig := circuit.GetCircuitConfiguration(a.vocapp.CircuitConfigurationTag())
// Encode the circuit configuration to JSON
data, err := json.Marshal(circuitConfig)
// Encode the current circuit configuration to JSON
data, err := json.Marshal(circuit.Global().Config)
if err != nil {
return err
}
Expand Down
9 changes: 6 additions & 3 deletions apiclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type HTTPclient struct {
addr *url.URL
account *ethereum.SignKeys
chainID string
circuit *circuit.ZkCircuitConfig
circuit *circuit.ZkCircuit
retries int
}

Expand Down Expand Up @@ -72,8 +72,11 @@ func NewHTTPclient(addr *url.URL, bearerToken *uuid.UUID) (*HTTPclient, error) {
return nil, fmt.Errorf("cannot get chain ID from API server")
}
c.chainID = info.ID
// Get the default circuit config
c.circuit = circuit.GetCircuitConfiguration(info.CircuitConfigurationTag)

c.circuit, err = circuit.LoadVersion(info.CircuitVersion)
if err != nil {
return nil, fmt.Errorf("error loading circuit: %w", err)
}
return c, nil
}

Expand Down
10 changes: 2 additions & 8 deletions apiclient/vote.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package apiclient

import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -102,14 +101,9 @@ func (cl *HTTPclient) Vote(v *VoteData) (types.HexBytes, error) {
if err != nil {
return nil, fmt.Errorf("error encoding inputs: %w", err)
}
// load the correct circuit from the ApiClient configuration
currentCircuit, err := circuit.LoadZkCircuit(context.Background(), c.circuit)
if err != nil {
return nil, fmt.Errorf("error loading circuit: %w", err)
}
// instance the prover with the circuit config loaded and generate the
// proof for the calculated inputs
proof, err := prover.Prove(currentCircuit.ProvingKey, currentCircuit.Wasm, inputs)
proof, err := prover.Prove(c.circuit.ProvingKey, c.circuit.Wasm, inputs)
if err != nil {
return nil, fmt.Errorf("could not generate anonymous proof: %w", err)
}
Expand All @@ -119,7 +113,7 @@ func (cl *HTTPclient) Vote(v *VoteData) (types.HexBytes, error) {
return nil, err
}
// include vote nullifier and the encoded proof in a VoteEnvelope
nullifier, err := proof.Nullifier()
nullifier, err := proof.ExtractPubSignal("nullifier")
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion benchmark/zk_census_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func genProofZk(b *testing.B, electionID []byte, acc *ethereum.SignKeys, censusD
"nullifier", nullifier.String())

// Get artifacts of the current circuit
currentCircuit, err := circuit.LoadZkCircuit(context.Background(), zkCircuitTest)
currentCircuit, err := circuit.LoadConfig(context.Background(), zkCircuitTest)
qt.Assert(b, err, qt.IsNil)
// Calculate the proof for the current apiclient circuit config and the
// inputs encoded.
Expand Down
27 changes: 27 additions & 0 deletions config/forks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package config

// ForksCfg allows applying softforks at specified heights
type ForksCfg struct {
VoceremonyForkBlock uint32
}

// Forks is a map of chainIDs
var Forks = map[string]*ForksCfg{
"vocdoni/DEV/29": {
VoceremonyForkBlock: 216600, // estimated 2023-12-05T09:49:02.224626473Z
},
"vocdoni/STAGE/9": {
VoceremonyForkBlock: 247000, // estimated 2023-12-11T08:47:56.552083308Z
},
"vocdoni/LTS/1.2": {
VoceremonyForkBlock: 393000, // estimated 2023-12-11T11:51:47.046130989Z
},
}

// ForksForChainID returns the ForksCfg of chainID, if found, or an empty ForksCfg otherwise
func ForksForChainID(chainID string) *ForksCfg {
if cfg, found := Forks[chainID]; found {
return cfg
}
return &ForksCfg{}
}
112 changes: 100 additions & 12 deletions crypto/zk/circuit/circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ import (
"net/url"
"os"
"path/filepath"
"sync"
"time"

"go.vocdoni.io/dvote/config"
"go.vocdoni.io/dvote/log"
)

var downloadCircuitsTimeout = time.Minute * 5
const downloadCircuitsTimeout = time.Minute * 5

// BaseDir is where the artifact cache is expected to be found.
// If the artifacts are not found there, they will be downloaded and stored.
Expand All @@ -31,29 +33,102 @@ var BaseDir = func() string {
return filepath.Join(home, ".cache", "vocdoni", "zkCircuits")
}()

// Global circuit
var (
mtx sync.Mutex

globalCircuit = &ZkCircuit{
Config: CircuitsConfigurations[DefaultZkCircuitVersion],
}
)

// ZkCircuit struct wraps the circuit configuration and contains the file
// content of the circuit artifacts (provingKey, verificationKey and wasm)
type ZkCircuit struct {
ProvingKey []byte
VerificationKey []byte
Wasm []byte
Config *ZkCircuitConfig
Config *Config
}

// Global returns the global ZkCircuit
func Global() *ZkCircuit {
mtx.Lock()
defer mtx.Unlock()
return globalCircuit
}

// SetGlobal will LoadVersion into the global ZkCircuit
//
// If current version is already equal to the passed version, and the artifacts are loaded into memory,
// it returns immediately
func SetGlobal(version string) error {
mtx.Lock()
defer mtx.Unlock()
if globalCircuit.Version() == version && globalCircuit.IsLoaded() {
return nil
}
circuit, err := LoadVersion(version)
if err != nil {
return fmt.Errorf("could not load zk verification keys: %w", err)
}
globalCircuit = circuit
return nil
}

// Version returns the version of the global ZkCircuit
func Version() string {
return Global().Version()
}

// IsLoaded returns true if all needed keys (Proving, Verification and Wasm) are loaded into memory
func IsLoaded() bool {
return Global().IsLoaded()
}

// Init will load (or download) the default circuit artifacts into memory, ready to be used globally.
func Init() error {
return SetGlobal(DefaultZkCircuitVersion)
}

// LoadZkCircuitByTag gets the circuit configuration associated to the provided
// tag or gets the default one and load its artifacts to prepare the circuit to
// be used.
func LoadZkCircuitByTag(configTag string) (*ZkCircuit, error) {
circuitConf := GetCircuitConfiguration(configTag)
// DownloadDefaultArtifacts ensures the default circuit is cached locally
func DownloadDefaultArtifacts() error {
_, err := LoadVersion(DefaultZkCircuitVersion)
if err != nil {
return fmt.Errorf("could not load zk verification keys: %w", err)
}
return nil
}

// DownloadArtifactsForChainID ensures all circuits needed for chainID are cached locally
func DownloadArtifactsForChainID(chainID string) error {
if config.ForksForChainID(chainID).VoceremonyForkBlock > 0 {
_, err := LoadVersion(PreVoceremonyForkZkCircuitVersion)
if err != nil {
return fmt.Errorf("could not load zk verification keys: %w", err)
}
}
return DownloadDefaultArtifacts()
}

// LoadVersion loads the circuit artifacts based on the version provided.
// First, tries to load the artifacts from local storage, if they are not
// available, tries to download from their remote location.
//
// Stores the loaded circuit in the global variable, and returns it as well
func LoadVersion(version string) (*ZkCircuit, error) {
circuitConf := GetCircuitConfiguration(version)
ctx, cancel := context.WithTimeout(context.Background(), downloadCircuitsTimeout)
defer cancel()
return LoadZkCircuit(ctx, circuitConf)
return LoadConfig(ctx, circuitConf)
}

// LoadZkCircuit load the circuit artifacts based on the configuration provided.
// LoadConfig loads the circuit artifacts based on the configuration provided.
// First, tries to load the artifacts from local storage, if they are not
// available, tries to download from their remote location.
func LoadZkCircuit(ctx context.Context, config *ZkCircuitConfig) (*ZkCircuit, error) {
//
// Stores the loaded circuit in the global variable, and returns it as well
func LoadConfig(ctx context.Context, config *Config) (*ZkCircuit, error) {
circuit := &ZkCircuit{Config: config}
// load the artifacts of the provided circuit from the local storage
if err := circuit.LoadLocal(); err == nil {
Expand All @@ -77,15 +152,28 @@ func LoadZkCircuit(ctx context.Context, config *ZkCircuitConfig) (*ZkCircuit, er
if !correct {
return nil, fmt.Errorf("hashes from downloaded artifacts don't match the expected ones")
}
globalCircuit = circuit
return circuit, nil
}

// Version returns the version of the ZkCircuit
func (circuit *ZkCircuit) Version() string {
return circuit.Config.Version
}

// IsLoaded returns true if all needed keys (Proving, Verification and Wasm) are loaded into memory
func (circuit *ZkCircuit) IsLoaded() bool {
return (circuit.ProvingKey != nil &&
circuit.VerificationKey != nil &&
circuit.Wasm != nil)
}

// LoadLocal tries to read the content of current circuit artifacts from its
// local path (provingKey, verificationKey and wasm). If any of the read
// operations fails, returns an error.
func (circuit *ZkCircuit) LoadLocal() error {
var err error
log.Debugw("loading circuit locally...", "BaseDir", BaseDir)
log.Debugw("loading circuit locally...", "BaseDir", BaseDir, "version", circuit.Config.Version)
files := map[string][]byte{
circuit.Config.ProvingKeyFilename: nil,
circuit.Config.VerificationKeyFilename: nil,
Expand All @@ -112,7 +200,7 @@ func (circuit *ZkCircuit) LoadLocal() error {
// remote location. If any of the downloads fails, returns an error.
func (circuit *ZkCircuit) LoadRemote(ctx context.Context) error {
log.Debugw("circuit not downloaded yet, downloading...",
"BaseDir", BaseDir)
"BaseDir", BaseDir, "version", circuit.Config.Version)
baseUri, err := url.Parse(circuit.Config.URI)
if err != nil {
return err
Expand Down
Loading

0 comments on commit 7e686dd

Please sign in to comment.