From 7a97f7a94c2ebf678c9a556b5c050f5e90efb5ce Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 31 May 2024 16:36:41 +0400 Subject: [PATCH 01/62] sql: improve database schema handling Use consistent package/folder structure for local.sql and state.sql databases. When a new database is created, use schema script instead of running all the migrations. Also, check that there's no schema drift by diffing database schema against the schema script after running migrations. --- activation/activation_test.go | 3 +- activation/certifier_test.go | 7 +- activation/e2e/activation_test.go | 4 +- activation/e2e/certifier_client_test.go | 4 +- activation/e2e/nipost_test.go | 6 +- activation/e2e/validation_test.go | 6 +- activation/handler_test.go | 3 +- activation/handler_v1_test.go | 3 +- activation/handler_v2_test.go | 3 +- activation/poetdb.go | 5 +- activation/poetdb_test.go | 10 +- activation/post_supervisor_test.go | 4 +- activation/post_test.go | 4 +- activation/validation_test.go | 6 +- activation/verify_state_test.go | 4 +- api/grpcserver/activation_service_test.go | 3 +- api/grpcserver/admin_service.go | 6 +- api/grpcserver/admin_service_test.go | 11 +- api/grpcserver/debug_service.go | 6 +- api/grpcserver/grpcserver_test.go | 34 +-- api/grpcserver/http_server_test.go | 4 +- api/grpcserver/mesh_service_test.go | 7 +- api/grpcserver/post_service_test.go | 6 +- api/grpcserver/transaction_service.go | 6 +- api/grpcserver/transaction_service_test.go | 7 +- api/grpcserver/v2alpha1/account_test.go | 4 +- api/grpcserver/v2alpha1/activation_test.go | 8 +- api/grpcserver/v2alpha1/layer_test.go | 8 +- api/grpcserver/v2alpha1/reward_test.go | 6 +- api/grpcserver/v2alpha1/transaction_test.go | 13 +- atxsdata/warmup.go | 3 +- atxsdata/warmup_test.go | 9 +- beacon/beacon_test.go | 15 +- blocks/certifier.go | 5 +- blocks/certifier_test.go | 5 +- blocks/generator.go | 6 +- blocks/generator_test.go | 5 +- blocks/handler.go | 12 +- blocks/handler_test.go | 4 +- blocks/utils.go | 6 +- blocks/utils_test.go | 4 +- checkpoint/recovery.go | 17 +- checkpoint/recovery_test.go | 36 ++-- checkpoint/runner.go | 6 +- checkpoint/runner_test.go | 13 +- cmd/activeset/activeset.go | 4 +- cmd/bootstrapper/generator_test.go | 5 +- cmd/bootstrapper/server_test.go | 4 +- cmd/merge-nodes/internal/errors.go | 5 +- cmd/merge-nodes/internal/merge_action.go | 22 +- cmd/merge-nodes/internal/merge_action_test.go | 23 +- datastore/store_test.go | 23 +- fetch/fetch_test.go | 8 +- fetch/handler_test.go | 5 +- fetch/mesh_data_test.go | 4 +- fetch/p2p_test.go | 9 +- genvm/core/context_test.go | 20 +- genvm/core/staged_cache_test.go | 6 +- genvm/templates/vault/vault_test.go | 6 +- genvm/vm.go | 5 +- genvm/vm_test.go | 8 +- go.mod | 1 + go.sum | 2 + hare3/eligibility/oracle_test.go | 5 +- hare3/hare.go | 5 +- hare3/hare_test.go | 8 +- malfeasance/handler_test.go | 21 +- mesh/executor_test.go | 5 +- mesh/mesh.go | 5 +- mesh/mesh_test.go | 5 +- miner/active_set_generator_test.go | 5 +- miner/proposal_builder_test.go | 6 +- node/node.go | 24 +-- node/node_version_check_test.go | 9 +- proposals/handler.go | 5 +- proposals/handler_test.go | 5 +- prune/prune.go | 6 +- prune/prune_test.go | 3 +- sql/accounts/accounts_test.go | 13 +- sql/activesets/activesets_test.go | 5 +- sql/atxs/atxs_test.go | 53 ++--- sql/ballots/ballots_test.go | 17 +- sql/beacons/beacons_test.go | 7 +- sql/blocks/blocks_test.go | 27 +-- sql/certificates/certs_test.go | 11 +- sql/database.go | 198 ++++++++---------- sql/database_test.go | 59 +++--- sql/identities/identities_test.go | 7 +- sql/layers/layers_test.go | 17 +- sql/localsql/local.go | 38 ---- sql/localsql/local_test.go | 41 ---- sql/localsql/localsql.go | 59 ++++++ sql/localsql/localsql_test.go | 74 +++++++ .../schema/migrations}/0001_initial.sql | 0 .../migrations}/0002_extend_initial_post.sql | 0 .../0003_add_nipost_builder_state.sql | 0 .../schema/migrations}/0004_atx_sync.sql | 0 .../schema/migrations}/0005_fast_startup.sql | 0 .../migrations}/0006_prepared_activeset.sql | 0 .../migrations}/0007_malfeasance_sync.sql | 0 .../schema/migrations}/0008_next.sql | 0 sql/localsql/schema/schema.sql | 82 ++++++++ sql/metrics/prometheus.go | 5 +- sql/migrations.go | 48 +++-- sql/migrations_test.go | 26 --- sql/poets/poets_test.go | 9 +- sql/recovery/recovery_test.go | 4 +- sql/rewards/rewards_test.go | 31 ++- sql/schema.go | 188 +++++++++++++++++ .../schema/migrations}/0001_initial.sql | 0 .../schema/migrations}/0002_v1.0.3.sql | 0 .../schema/migrations}/0003_v1.1.5.sql | 0 .../schema/migrations}/0004_v1.1.7.sql | 0 .../schema/migrations}/0005_v1.2.0.sql | 0 .../schema/migrations}/0006_v1.2.2.sql | 0 .../schema/migrations}/0007_v1.3.0.sql | 0 .../schema/migrations}/0008_rewards.sql | 0 .../migrations}/0009_prune_activesets.sql | 0 .../schema/migrations}/0010_rowid.sql | 0 .../migrations}/0011_atxs_extra_index.sql | 0 .../schema/migrations}/0012_atx_validity.sql | 0 .../migrations}/0013_atx_coinbase_index.sql | 0 .../migrations}/0014_remove_proposals.sql | 0 .../schema/migrations}/0015_nonce_index.sql | 0 .../schema/migrations}/0016_atx_blob.sql | 0 .../0017_atxs_prev_id_nonce_placeholder.sql | 0 .../migrations}/0018_atx_blob_version.sql | 0 sql/statesql/schema/schema.sql | 151 +++++++++++++ sql/statesql/statesql.go | 58 +++++ sql/statesql/statesql_test.go | 72 +++++++ sql/transactions/iterator_test.go | 5 +- sql/transactions/transactions_test.go | 31 +-- syncer/atxsync/atxsync.go | 6 +- syncer/atxsync/atxsync_test.go | 4 +- syncer/atxsync/syncer_test.go | 6 +- syncer/find_fork_test.go | 8 +- syncer/malsync/syncer_test.go | 6 +- syncer/syncer_test.go | 4 +- .../distributed_post_verification_test.go | 6 +- tortoise/model/core.go | 3 +- tortoise/replay/replay_test.go | 4 +- tortoise/sim/utils.go | 8 +- tortoise/threshold_test.go | 4 +- tortoise/tortoise_test.go | 3 +- txs/cache.go | 27 +-- txs/cache_test.go | 15 +- txs/conservative_state.go | 6 +- txs/conservative_state_test.go | 6 +- 148 files changed, 1353 insertions(+), 723 deletions(-) delete mode 100644 sql/localsql/local.go delete mode 100644 sql/localsql/local_test.go create mode 100644 sql/localsql/localsql.go create mode 100644 sql/localsql/localsql_test.go rename sql/{migrations/local => localsql/schema/migrations}/0001_initial.sql (100%) rename sql/{migrations/local => localsql/schema/migrations}/0002_extend_initial_post.sql (100%) rename sql/{migrations/local => localsql/schema/migrations}/0003_add_nipost_builder_state.sql (100%) rename sql/{migrations/local => localsql/schema/migrations}/0004_atx_sync.sql (100%) rename sql/{migrations/local => localsql/schema/migrations}/0005_fast_startup.sql (100%) rename sql/{migrations/local => localsql/schema/migrations}/0006_prepared_activeset.sql (100%) rename sql/{migrations/local => localsql/schema/migrations}/0007_malfeasance_sync.sql (100%) rename sql/{migrations/local => localsql/schema/migrations}/0008_next.sql (100%) create mode 100755 sql/localsql/schema/schema.sql delete mode 100644 sql/migrations_test.go create mode 100644 sql/schema.go rename sql/{migrations/state => statesql/schema/migrations}/0001_initial.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0002_v1.0.3.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0003_v1.1.5.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0004_v1.1.7.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0005_v1.2.0.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0006_v1.2.2.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0007_v1.3.0.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0008_rewards.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0009_prune_activesets.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0010_rowid.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0011_atxs_extra_index.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0012_atx_validity.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0013_atx_coinbase_index.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0014_remove_proposals.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0015_nonce_index.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0016_atx_blob.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0017_atxs_prev_id_nonce_placeholder.sql (100%) rename sql/{migrations/state => statesql/schema/migrations}/0018_atx_blob_version.sql (100%) create mode 100755 sql/statesql/schema/schema.sql create mode 100644 sql/statesql/statesql.go create mode 100644 sql/statesql/statesql_test.go diff --git a/activation/activation_test.go b/activation/activation_test.go index 72001808cd..f43bdaa6eb 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -34,6 +34,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" sqlmocks "github.com/spacemeshos/go-spacemesh/sql/mocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) // ========== Vars / Consts ========== @@ -78,7 +79,7 @@ func newTestBuilder(tb testing.TB, numSigners int, opts ...BuilderOption) *testA ctrl := gomock.NewController(tb) tab := &testAtxBuilder{ - db: sql.InMemory(), + db: statesql.InMemory(), localDb: localsql.InMemory(sql.WithConnections(numSigners)), goldenATXID: types.ATXID(types.HexToHash32("77777")), diff --git a/activation/certifier_test.go b/activation/certifier_test.go index cc329737da..78c4ca405b 100644 --- a/activation/certifier_test.go +++ b/activation/certifier_test.go @@ -16,6 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql" certdb "github.com/spacemeshos/go-spacemesh/sql/localsql/certifier" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestPersistsCerts(t *testing.T) { @@ -113,7 +114,7 @@ func TestObtainingPost(t *testing.T) { id := types.RandomNodeID() t.Run("no POST or ATX", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() certifier := NewCertifierClient(db, localDb, zaptest.NewLogger(t)) @@ -121,7 +122,7 @@ func TestObtainingPost(t *testing.T) { require.ErrorContains(t, err, "PoST not found") }) t.Run("initial POST available", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() post := nipost.Post{ @@ -142,7 +143,7 @@ func TestObtainingPost(t *testing.T) { require.Equal(t, post, *got) }) t.Run("initial POST unavailable but ATX exists", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() atx := newInitialATXv1(t, types.RandomATXID()) diff --git a/activation/e2e/activation_test.go b/activation/e2e/activation_test.go index 3333014e9e..d3329e2204 100644 --- a/activation/e2e/activation_test.go +++ b/activation/e2e/activation_test.go @@ -28,9 +28,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -52,7 +52,7 @@ func Test_BuilderWithMultipleClients(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := activation.DefaultPostConfig() - db := sql.InMemory() + db := statesql.InMemory() localDB := localsql.InMemory() syncer := activation.NewMocksyncer(ctrl) diff --git a/activation/e2e/certifier_client_test.go b/activation/e2e/certifier_client_test.go index 26da2c7df4..6ccc006476 100644 --- a/activation/e2e/certifier_client_test.go +++ b/activation/e2e/certifier_client_test.go @@ -25,9 +25,9 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestCertification(t *testing.T) { @@ -36,7 +36,7 @@ func TestCertification(t *testing.T) { logger := zaptest.NewLogger(t) cfg := activation.DefaultPostConfig() - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() syncer := activation.NewMocksyncer(gomock.NewController(t)) diff --git a/activation/e2e/nipost_test.go b/activation/e2e/nipost_test.go index fb195bc308..c470f75dd3 100644 --- a/activation/e2e/nipost_test.go +++ b/activation/e2e/nipost_test.go @@ -26,9 +26,9 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -134,7 +134,7 @@ func TestNIPostBuilderWithClients(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := activation.DefaultPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) localDb := localsql.InMemory() @@ -263,7 +263,7 @@ func Test_NIPostBuilderWithMultipleClients(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := activation.DefaultPostConfig() - db := sql.InMemory() + db := statesql.InMemory() syncer := activation.NewMocksyncer(ctrl) syncer.EXPECT().RegisterForATXSynced().AnyTimes().DoAndReturn(func() <-chan struct{} { diff --git a/activation/e2e/validation_test.go b/activation/e2e/validation_test.go index ab17b39a82..8ea6d620a9 100644 --- a/activation/e2e/validation_test.go +++ b/activation/e2e/validation_test.go @@ -17,8 +17,8 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestValidator_Validate(t *testing.T) { @@ -30,7 +30,7 @@ func TestValidator_Validate(t *testing.T) { logger := zaptest.NewLogger(t) goldenATX := types.ATXID{2, 3, 4} cfg := activation.DefaultPostConfig() - db := sql.InMemory() + db := statesql.InMemory() validator := activation.NewMocknipostValidator(gomock.NewController(t)) syncer := activation.NewMocksyncer(gomock.NewController(t)) @@ -67,7 +67,7 @@ func TestValidator_Validate(t *testing.T) { WithPhaseShift(poetCfg.PhaseShift), WithCycleGap(poetCfg.CycleGap), ) - poetDb := activation.NewPoetDb(sql.InMemory(), logger.Named("poetDb")) + poetDb := activation.NewPoetDb(statesql.InMemory(), logger.Named("poetDb")) client, err := activation.NewPoetClient( poetDb, types.PoetServer{Address: poetProver.RestURL().String()}, diff --git a/activation/handler_test.go b/activation/handler_test.go index 6176a6a087..4b51384de9 100644 --- a/activation/handler_test.go +++ b/activation/handler_test.go @@ -33,6 +33,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -194,7 +195,7 @@ func newTestHandlerMocks(tb testing.TB, golden types.ATXID) handlerMocks { func newTestHandler(tb testing.TB, goldenATXID types.ATXID, opts ...HandlerOption) *testHandler { lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), lg) + cdb := datastore.NewCachedDB(statesql.InMemory(), lg) edVerifier := signing.NewEdVerifier() mocks := newTestHandlerMocks(tb, goldenATXID) diff --git a/activation/handler_v1_test.go b/activation/handler_v1_test.go index c96f239907..91acacc4f6 100644 --- a/activation/handler_v1_test.go +++ b/activation/handler_v1_test.go @@ -24,6 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type v1TestHandler struct { @@ -34,7 +35,7 @@ type v1TestHandler struct { func newV1TestHandler(tb testing.TB, goldenATXID types.ATXID) *v1TestHandler { lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), lg) + cdb := datastore.NewCachedDB(statesql.InMemory(), lg) mocks := newTestHandlerMocks(tb, goldenATXID) return &v1TestHandler{ HandlerV1: &HandlerV1{ diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 708e271910..3297b3676f 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -21,6 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type v2TestHandler struct { @@ -31,7 +32,7 @@ type v2TestHandler struct { func newV2TestHandler(tb testing.TB, golden types.ATXID) *v2TestHandler { lg := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), lg) + cdb := datastore.NewCachedDB(statesql.InMemory(), lg) mocks := newTestHandlerMocks(tb, golden) return &v2TestHandler{ HandlerV2: &HandlerV2{ diff --git a/activation/poetdb.go b/activation/poetdb.go index a82cda5e28..bc5671854d 100644 --- a/activation/poetdb.go +++ b/activation/poetdb.go @@ -17,18 +17,19 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/poets" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) var ErrObjectExists = sql.ErrObjectExists // PoetDb is a database for PoET proofs. type PoetDb struct { - sqlDB *sql.Database + sqlDB *statesql.Database logger *zap.Logger } // NewPoetDb returns a new PoET handler. -func NewPoetDb(db *sql.Database, log *zap.Logger) *PoetDb { +func NewPoetDb(db *statesql.Database, log *zap.Logger) *PoetDb { return &PoetDb{sqlDB: db, logger: log} } diff --git a/activation/poetdb_test.go b/activation/poetdb_test.go index 1b028c1da8..26e89aac0a 100644 --- a/activation/poetdb_test.go +++ b/activation/poetdb_test.go @@ -16,7 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) var ( @@ -63,7 +63,7 @@ func getPoetProof(t *testing.T) types.PoetProofMessage { func TestPoetDbHappyFlow(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) r.NoError(poetDb.Validate(msg.Statement[:], msg.PoetProof, msg.PoetServiceID, msg.RoundID, types.EmptyEdSignature)) ref, err := msg.Ref() @@ -83,7 +83,7 @@ func TestPoetDbHappyFlow(t *testing.T) { func TestPoetDbInvalidPoetProof(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) msg.PoetProof.Root = []byte("some other root") err := poetDb.Validate(msg.Statement[:], msg.PoetProof, msg.PoetServiceID, msg.RoundID, types.EmptyEdSignature) @@ -99,7 +99,7 @@ func TestPoetDbInvalidPoetProof(t *testing.T) { func TestPoetDbInvalidPoetStatement(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) msg.Statement = types.CalcHash32([]byte("some other statement")) err := poetDb.Validate(msg.Statement[:], msg.PoetProof, msg.PoetServiceID, msg.RoundID, types.EmptyEdSignature) @@ -115,7 +115,7 @@ func TestPoetDbInvalidPoetStatement(t *testing.T) { func TestPoetDbNonExistingKeys(t *testing.T) { r := require.New(t) msg := getPoetProof(t) - poetDb := NewPoetDb(sql.InMemory(), zaptest.NewLogger(t)) + poetDb := NewPoetDb(statesql.InMemory(), zaptest.NewLogger(t)) _, err := poetDb.GetProofRef(msg.PoetServiceID, "0") r.EqualError( diff --git a/activation/post_supervisor_test.go b/activation/post_supervisor_test.go index f26021a364..7c21cbf4c3 100644 --- a/activation/post_supervisor_test.go +++ b/activation/post_supervisor_test.go @@ -24,7 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func closedChan() <-chan struct{} { @@ -56,7 +56,7 @@ func newPostManager(t *testing.T, cfg PostConfig, opts PostSetupOpts) *PostSetup close(ch) return ch }) - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() mgr, err := NewPostSetupManager(cfg, zaptest.NewLogger(t), db, atxsdata, types.RandomATXID(), syncer, validator) require.NoError(t, err) diff --git a/activation/post_test.go b/activation/post_test.go index c0273a6369..6b40665a3e 100644 --- a/activation/post_test.go +++ b/activation/post_test.go @@ -17,8 +17,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestPostSetupManager(t *testing.T) { @@ -365,7 +365,7 @@ func newTestPostManager(tb testing.TB) *testPostManager { syncer.EXPECT().RegisterForATXSynced().AnyTimes().Return(synced) logger := zaptest.NewLogger(tb) - cdb := datastore.NewCachedDB(sql.InMemory(), logger) + cdb := datastore.NewCachedDB(statesql.InMemory(), logger) mgr, err := NewPostSetupManager(DefaultPostConfig(), logger, cdb, atxsdata.New(), goldenATXID, syncer, validator) require.NoError(tb, err) diff --git a/activation/validation_test.go b/activation/validation_test.go index 13af35c306..730c4308fe 100644 --- a/activation/validation_test.go +++ b/activation/validation_test.go @@ -16,8 +16,8 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_Validation_VRFNonce(t *testing.T) { @@ -476,7 +476,7 @@ func TestValidateMerkleProof(t *testing.T) { } func TestVerifyChainDeps(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() goldenATXID := types.ATXID{2, 3, 4} signer, err := signing.NewEdSigner() @@ -579,7 +579,7 @@ func TestVerifyChainDeps(t *testing.T) { } func TestVerifyChainDepsAfterCheckpoint(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() goldenATXID := types.ATXID{2, 3, 4} signer, err := signing.NewEdSigner() diff --git a/activation/verify_state_test.go b/activation/verify_state_test.go index b0e46a78c8..a2c0229783 100644 --- a/activation/verify_state_test.go +++ b/activation/verify_state_test.go @@ -11,13 +11,13 @@ import ( "github.com/spacemeshos/go-spacemesh/activation/wire" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_CheckPrevATXs(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() logger := zaptest.NewLogger(t) // Arrange diff --git a/api/grpcserver/activation_service_test.go b/api/grpcserver/activation_service_test.go index 2cf5ad0a3c..f79d12576f 100644 --- a/api/grpcserver/activation_service_test.go +++ b/api/grpcserver/activation_service_test.go @@ -17,6 +17,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test_Highest_ReturnsGoldenAtxOnError(t *testing.T) { @@ -137,7 +138,7 @@ func TestGet_IdentityCanceled(t *testing.T) { atxProvider := grpcserver.NewMockatxProvider(ctrl) activationService := grpcserver.NewActivationService(atxProvider, types.ATXID{1}) - smesher, proof := grpcserver.BallotMalfeasance(t, sql.InMemory()) + smesher, proof := grpcserver.BallotMalfeasance(t, statesql.InMemory()) id := types.RandomATXID() atx := types.ActivationTx{ Sequence: rand.Uint64(), diff --git a/api/grpcserver/admin_service.go b/api/grpcserver/admin_service.go index ba2109a587..010dbf8487 100644 --- a/api/grpcserver/admin_service.go +++ b/api/grpcserver/admin_service.go @@ -22,7 +22,7 @@ import ( "github.com/spacemeshos/go-spacemesh/checkpoint" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -32,14 +32,14 @@ const ( // AdminService exposes endpoints for node administration. type AdminService struct { - db *sql.Database + db *statesql.Database dataDir string recover func() p peers } // NewAdminService creates a new admin grpc service. -func NewAdminService(db *sql.Database, dataDir string, p peers) *AdminService { +func NewAdminService(db *statesql.Database, dataDir string, p peers) *AdminService { return &AdminService{ db: db, dataDir: dataDir, diff --git a/api/grpcserver/admin_service_test.go b/api/grpcserver/admin_service_test.go index 46591774b2..f82c7f7f29 100644 --- a/api/grpcserver/admin_service_test.go +++ b/api/grpcserver/admin_service_test.go @@ -15,11 +15,12 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const snapshot uint32 = 15 -func newAtx(tb testing.TB, db *sql.Database) { +func newAtx(tb testing.TB, db *statesql.Database) { atx := &types.ActivationTx{ PublishEpoch: types.EpochID(2), Sequence: 0, @@ -36,7 +37,7 @@ func newAtx(tb testing.TB, db *sql.Database) { require.NoError(tb, atxs.Add(db, atx)) } -func createMesh(tb testing.TB, db *sql.Database) { +func createMesh(tb testing.TB, db *statesql.Database) { for range 10 { newAtx(tb, db) } @@ -52,7 +53,7 @@ func createMesh(tb testing.TB, db *sql.Database) { } func TestAdminService_Checkpoint(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() createMesh(t, db) svc := NewAdminService(db, t.TempDir(), nil) cfg, cleanup := launchServer(t, svc) @@ -89,7 +90,7 @@ func TestAdminService_Checkpoint(t *testing.T) { } func TestAdminService_CheckpointError(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() svc := NewAdminService(db, t.TempDir(), nil) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -106,7 +107,7 @@ func TestAdminService_CheckpointError(t *testing.T) { } func TestAdminService_Recovery(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() recoveryCalled := atomic.Bool{} svc := NewAdminService(db, t.TempDir(), nil) svc.recover = func() { recoveryCalled.Store(true) } diff --git a/api/grpcserver/debug_service.go b/api/grpcserver/debug_service.go index d8731c4920..c9b2637631 100644 --- a/api/grpcserver/debug_service.go +++ b/api/grpcserver/debug_service.go @@ -18,13 +18,13 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) // DebugService exposes global state data, output from the STF. type DebugService struct { - db *sql.Database + db *statesql.Database conState conservativeState netInfo networkInfo oracle oracle @@ -46,7 +46,7 @@ func (d DebugService) String() string { } // NewDebugService creates a new grpc service using config data. -func NewDebugService(db *sql.Database, conState conservativeState, host networkInfo, oracle oracle, +func NewDebugService(db *statesql.Database, conState conservativeState, host networkInfo, oracle oracle, loggers map[string]*zap.AtomicLevel, ) *DebugService { return &DebugService{ diff --git a/api/grpcserver/grpcserver_test.go b/api/grpcserver/grpcserver_test.go index dfdd6c4977..c1dd3d2b55 100644 --- a/api/grpcserver/grpcserver_test.go +++ b/api/grpcserver/grpcserver_test.go @@ -45,11 +45,11 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" pubsubmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" "github.com/spacemeshos/go-spacemesh/txs" ) @@ -722,7 +722,7 @@ func TestMeshService(t *testing.T) { genesis := time.Unix(genTimeUnix, 0) genTime.EXPECT().GenesisTime().Return(genesis) genTime.EXPECT().CurrentLayer().Return(layerCurrent).AnyTimes() - db := datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)) + db := datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)) svc := NewMeshService( db, meshAPIMock, @@ -1266,7 +1266,7 @@ func TestTransactionServiceSubmitUnsync(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil) - svc := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + svc := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -1305,7 +1305,7 @@ func TestTransactionServiceSubmitInvalidTx(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(errors.New("failed validation")) - grpcService := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + grpcService := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, grpcService) t.Cleanup(cleanup) @@ -1338,7 +1338,7 @@ func TestTransactionService_SubmitNoConcurrency(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).Times(numTxs) - grpcService := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + grpcService := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, grpcService) t.Cleanup(cleanup) @@ -1366,7 +1366,7 @@ func TestTransactionService(t *testing.T) { txHandler := NewMocktxValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - grpcService := NewTransactionService(sql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) + grpcService := NewTransactionService(statesql.InMemory(), publisher, meshAPIMock, conStateAPI, syncer, txHandler) cfg, cleanup := launchServer(t, grpcService) t.Cleanup(cleanup) @@ -1680,7 +1680,7 @@ func TestAccountMeshDataStream_comprehensive(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) grpcService := NewMeshService( - datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)), + datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)), meshAPIMock, conStateAPI, genTime, @@ -1862,7 +1862,7 @@ func TestLayerStream_comprehensive(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)) + db := datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)) grpcService := NewMeshService( db, @@ -2008,7 +2008,7 @@ func TestMultiService(t *testing.T) { genTime.EXPECT().GenesisTime().Return(genesis) svc1 := NewNodeService(peerCounter, meshAPIMock, genTime, syncer, "v0.0.0", "cafebabe") svc2 := NewMeshService( - datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)), + datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)), meshAPIMock, conStateAPI, genTime, @@ -2055,7 +2055,7 @@ func TestDebugService(t *testing.T) { ctrl := gomock.NewController(t) netInfo := NewMocknetworkInfo(ctrl) mOracle := NewMockoracle(ctrl) - db := sql.InMemory() + db := statesql.InMemory() testLog := zap.NewAtomicLevel() loggers := map[string]*zap.AtomicLevel{ @@ -2229,7 +2229,7 @@ func TestEventsReceived(t *testing.T) { events.InitializeReporter() t.Cleanup(events.CloseEventReporter) - txService := NewTransactionService(sql.InMemory(), nil, meshAPIMock, conStateAPI, nil, nil) + txService := NewTransactionService(statesql.InMemory(), nil, meshAPIMock, conStateAPI, nil, nil) gsService := NewGlobalStateService(meshAPIMock, conStateAPI) cfg, cleanup := launchServer(t, txService, gsService) t.Cleanup(cleanup) @@ -2279,8 +2279,8 @@ func TestEventsReceived(t *testing.T) { time.Sleep(time.Millisecond * 50) lg := logtest.New(t) - svm := vm.New(sql.InMemory(), vm.WithLogger(lg)) - conState := txs.NewConservativeState(svm, sql.InMemory(), txs.WithLogger(lg.Zap().Named("conState"))) + svm := vm.New(statesql.InMemory(), vm.WithLogger(lg)) + conState := txs.NewConservativeState(svm, statesql.InMemory(), txs.WithLogger(lg.Zap().Named("conState"))) conState.AddToCache(context.Background(), globalTx, time.Now()) weight := new(big.Rat).SetFloat64(18.7) @@ -2343,7 +2343,7 @@ func TestTransactionsRewards(t *testing.T) { req.NoError(err, "stream request returned unexpected error") time.Sleep(50 * time.Millisecond) - svm := vm.New(sql.InMemory(), vm.WithLogger(logtest.New(t))) + svm := vm.New(statesql.InMemory(), vm.WithLogger(logtest.New(t))) _, _, err = svm.Apply(vm.ApplyContext{Layer: types.LayerID(17)}, []types.Transaction{*globalTx}, rewards) req.NoError(err) @@ -2364,7 +2364,7 @@ func TestTransactionsRewards(t *testing.T) { req.NoError(err, "stream request returned unexpected error") time.Sleep(50 * time.Millisecond) - svm := vm.New(sql.InMemory(), vm.WithLogger(logtest.New(t))) + svm := vm.New(statesql.InMemory(), vm.WithLogger(logtest.New(t))) _, _, err = svm.Apply(vm.ApplyContext{Layer: types.LayerID(17)}, []types.Transaction{*globalTx}, rewards) req.NoError(err) @@ -2383,7 +2383,7 @@ func TestVMAccountUpdates(t *testing.T) { events.InitializeReporter() // in memory database doesn't allow reads while writer locked db - db, err := sql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) + db, err := statesql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) require.NoError(t, err) t.Cleanup(func() { db.Close() }) svm := vm.New(db, vm.WithLogger(logtest.New(t))) @@ -2479,7 +2479,7 @@ func createAtxs(tb testing.TB, epoch types.EpochID, atxids []types.ATXID) []*typ func TestMeshService_EpochStream(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), meshAPIMock, diff --git a/api/grpcserver/http_server_test.go b/api/grpcserver/http_server_test.go index 06cb018b53..b31d502b0d 100644 --- a/api/grpcserver/http_server_test.go +++ b/api/grpcserver/http_server_test.go @@ -18,7 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func launchJsonServer(tb testing.TB, services ...ServiceAPI) (Config, func()) { @@ -65,7 +65,7 @@ func TestJsonApi(t *testing.T) { conStateAPI := NewMockconservativeState(ctrl) svc1 := NewNodeService(peerCounter, meshAPIMock, genTime, syncer, version, build) svc2 := NewMeshService( - datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)), + datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)), meshAPIMock, conStateAPI, genTime, diff --git a/api/grpcserver/mesh_service_test.go b/api/grpcserver/mesh_service_test.go index 6acb9465aa..9ce543173f 100644 --- a/api/grpcserver/mesh_service_test.go +++ b/api/grpcserver/mesh_service_test.go @@ -22,6 +22,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -142,7 +143,7 @@ func HareMalfeasance(tb testing.TB, db sql.Executor) (types.NodeID, *wire.Malfea func TestMeshService_MalfeasanceQuery(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), meshAPIMock, @@ -195,7 +196,7 @@ func TestMeshService_MalfeasanceStream(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), meshAPIMock, @@ -301,7 +302,7 @@ func (t *ConStateAPIMockInstrumented) GetLayerStateRoot(types.LayerID) (types.Ha func TestReadLayer(t *testing.T) { ctrl := gomock.NewController(t) genTime := NewMockgenesisTimeAPI(ctrl) - db := sql.InMemory() + db := statesql.InMemory() srv := NewMeshService( datastore.NewCachedDB(db, zaptest.NewLogger(t)), &MeshAPIMockInstrumented{}, diff --git a/api/grpcserver/post_service_test.go b/api/grpcserver/post_service_test.go index f3fddd506b..464ef984a1 100644 --- a/api/grpcserver/post_service_test.go +++ b/api/grpcserver/post_service_test.go @@ -23,7 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func launchPostSupervisor( @@ -58,7 +58,7 @@ func launchPostSupervisor( close(ch) return ch }) - db := sql.InMemory() + db := statesql.InMemory() logger := log.Named("post manager") mgr, err := activation.NewPostSetupManager(postCfg, logger, db, atxsdata.New(), goldenATXID, syncer, validator) require.NoError(tb, err) @@ -102,7 +102,7 @@ func launchPostSupervisorTLS( close(ch) return ch }) - db := sql.InMemory() + db := statesql.InMemory() logger := log.Named("post supervisor") mgr, err := activation.NewPostSetupManager(postCfg, logger, db, atxsdata.New(), goldenATXID, syncer, validator) require.NoError(tb, err) diff --git a/api/grpcserver/transaction_service.go b/api/grpcserver/transaction_service.go index a0713b3b8c..f02b2b7983 100644 --- a/api/grpcserver/transaction_service.go +++ b/api/grpcserver/transaction_service.go @@ -22,13 +22,13 @@ import ( "github.com/spacemeshos/go-spacemesh/events" "github.com/spacemeshos/go-spacemesh/genvm/core" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) // TransactionService exposes transaction data, and a submit tx endpoint. type TransactionService struct { - db *sql.Database + db *statesql.Database publisher pubsub.Publisher // P2P Swarm mesh meshAPI // Mesh conState conservativeState @@ -52,7 +52,7 @@ func (s TransactionService) String() string { // NewTransactionService creates a new grpc service using config data. func NewTransactionService( - db *sql.Database, + db *statesql.Database, publisher pubsub.Publisher, msh meshAPI, conState conservativeState, diff --git a/api/grpcserver/transaction_service_test.go b/api/grpcserver/transaction_service_test.go index 343e30ed58..a2274a5bb1 100644 --- a/api/grpcserver/transaction_service_test.go +++ b/api/grpcserver/transaction_service_test.go @@ -23,12 +23,13 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/sdk/wallet" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/txs" ) func TestTransactionService_StreamResults(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -134,7 +135,7 @@ func TestTransactionService_StreamResults(t *testing.T) { } func BenchmarkStreamResults(b *testing.B) { - db := sql.InMemory() + db := statesql.InMemory() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -216,7 +217,7 @@ func parseOk() parseExpectation { } func TestParseTransactions(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) t.Cleanup(cancel) vminst := vm.New(db) diff --git a/api/grpcserver/v2alpha1/account_test.go b/api/grpcserver/v2alpha1/account_test.go index 3d8684d5bc..fb876851e2 100644 --- a/api/grpcserver/v2alpha1/account_test.go +++ b/api/grpcserver/v2alpha1/account_test.go @@ -14,8 +14,8 @@ import ( "google.golang.org/grpc/status" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testAccount struct { @@ -27,7 +27,7 @@ type testAccount struct { } func TestAccountService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctrl, ctx := gomock.WithContext(context.Background(), t) conState := NewMockaccountConState(ctrl) diff --git a/api/grpcserver/v2alpha1/activation_test.go b/api/grpcserver/v2alpha1/activation_test.go index 851c0bb432..9f118268de 100644 --- a/api/grpcserver/v2alpha1/activation_test.go +++ b/api/grpcserver/v2alpha1/activation_test.go @@ -16,12 +16,12 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestActivationService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewAtxsGenerator() @@ -104,7 +104,7 @@ func TestActivationService_List(t *testing.T) { } func TestActivationStreamService_Stream(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewAtxsGenerator() @@ -213,7 +213,7 @@ func TestActivationStreamService_Stream(t *testing.T) { } func TestActivationService_ActivationsCount(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() genEpoch3 := fixture.NewAtxsGenerator().WithEpochs(3, 1) diff --git a/api/grpcserver/v2alpha1/layer_test.go b/api/grpcserver/v2alpha1/layer_test.go index bfc082b068..9ae7f277af 100644 --- a/api/grpcserver/v2alpha1/layer_test.go +++ b/api/grpcserver/v2alpha1/layer_test.go @@ -16,13 +16,13 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestLayerService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() lrs := make([]layers.Layer, 100) @@ -98,7 +98,7 @@ func TestLayerConvertEventStatus(t *testing.T) { } func TestLayerStreamService_Stream(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() lrs := make([]layers.Layer, 100) @@ -225,7 +225,7 @@ func layerGenWithBlock(withBlock bool) layerGenOpt { } } -func generateLayer(db *sql.Database, id types.LayerID, opts ...layerGenOpt) (*layers.Layer, error) { +func generateLayer(db *statesql.Database, id types.LayerID, opts ...layerGenOpt) (*layers.Layer, error) { g := &layerGenOpts{} for _, opt := range opts { opt(g) diff --git a/api/grpcserver/v2alpha1/reward_test.go b/api/grpcserver/v2alpha1/reward_test.go index 50952097d1..2e16cab291 100644 --- a/api/grpcserver/v2alpha1/reward_test.go +++ b/api/grpcserver/v2alpha1/reward_test.go @@ -15,12 +15,12 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/rewards" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestRewardService_List(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewRewardsGenerator().WithAddresses(100).WithUniqueCoinbase() @@ -103,7 +103,7 @@ func TestRewardService_List(t *testing.T) { } func TestRewardStreamService_Stream(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewRewardsGenerator() diff --git a/api/grpcserver/v2alpha1/transaction_test.go b/api/grpcserver/v2alpha1/transaction_test.go index 62223acd76..2ab968e7ef 100644 --- a/api/grpcserver/v2alpha1/transaction_test.go +++ b/api/grpcserver/v2alpha1/transaction_test.go @@ -25,13 +25,14 @@ import ( pubsubmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/txs" ) func TestTransactionService_List(t *testing.T) { types.SetLayersPerEpoch(5) - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() gen := fixture.NewTransactionResultGenerator().WithAddresses(2) @@ -136,7 +137,7 @@ func TestTransactionService_List(t *testing.T) { func TestTransactionService_EstimateGas(t *testing.T) { types.SetLayersPerEpoch(5) - db := sql.InMemory() + db := statesql.InMemory() vminst := vm.New(db) ctx := context.Background() @@ -199,7 +200,7 @@ func TestTransactionService_EstimateGas(t *testing.T) { func TestTransactionService_ParseTransaction(t *testing.T) { types.SetLayersPerEpoch(5) - db := sql.InMemory() + db := statesql.InMemory() vminst := vm.New(db) ctx := context.Background() @@ -292,7 +293,7 @@ func TestTransactionServiceSubmitUnsync(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil) - svc := NewTransactionService(sql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -335,7 +336,7 @@ func TestTransactionServiceSubmitInvalidTx(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(errors.New("failed validation")) - svc := NewTransactionService(sql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) @@ -372,7 +373,7 @@ func TestTransactionService_SubmitNoConcurrency(t *testing.T) { txHandler := NewMocktransactionValidator(ctrl) txHandler.EXPECT().VerifyAndCacheTx(gomock.Any(), gomock.Any()).Return(nil).Times(numTxs) - svc := NewTransactionService(sql.InMemory(), nil, syncer, txHandler, publisher) + svc := NewTransactionService(statesql.InMemory(), nil, syncer, txHandler, publisher) cfg, cleanup := launchServer(t, svc) t.Cleanup(cleanup) diff --git a/atxsdata/warmup.go b/atxsdata/warmup.go index 84c6850736..a069a2717a 100644 --- a/atxsdata/warmup.go +++ b/atxsdata/warmup.go @@ -8,9 +8,10 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) -func Warm(db *sql.Database, keep types.EpochID) (*Data, error) { +func Warm(db *statesql.Database, keep types.EpochID) (*Data, error) { cache := New() tx, err := db.Tx(context.Background()) if err != nil { diff --git a/atxsdata/warmup_test.go b/atxsdata/warmup_test.go index 2b82935000..3359efe46d 100644 --- a/atxsdata/warmup_test.go +++ b/atxsdata/warmup_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/mocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func gatx( @@ -37,7 +38,7 @@ func gatx( func TestWarmup(t *testing.T) { types.SetLayersPerEpoch(3) t.Run("sanity", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() applied := types.LayerID(10) nonce := types.VRFPostIndex(1) data := []types.ActivationTx{ @@ -60,19 +61,19 @@ func TestWarmup(t *testing.T) { } }) t.Run("no data", func(t *testing.T) { - c, err := Warm(sql.InMemory(), 1) + c, err := Warm(statesql.InMemory(), 1) require.NoError(t, err) require.NotNil(t, c) }) t.Run("closed db", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, db.Close()) c, err := Warm(db, 1) require.Error(t, err) require.Nil(t, c) }) t.Run("db failures", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nonce := types.VRFPostIndex(1) data := gatx(types.ATXID{1, 1}, 1, types.NodeID{1}, nonce) require.NoError(t, atxs.Add(db, &data)) diff --git a/beacon/beacon_test.go b/beacon/beacon_test.go index c72776f17a..ff96b513ed 100644 --- a/beacon/beacon_test.go +++ b/beacon/beacon_test.go @@ -30,6 +30,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -91,7 +92,7 @@ func newTestDriver(tb testing.TB, cfg Config, p pubsub.Publisher, miners int, id tpd.mVerifier.EXPECT().Verify(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(true) - tpd.cdb = datastore.NewCachedDB(sql.InMemory(), lg) + tpd.cdb = datastore.NewCachedDB(statesql.InMemory(), lg) tpd.ProtocolDriver = New(p, signing.NewEdVerifier(), tpd.mVerifier, tpd.cdb, tpd.mClock, WithConfig(cfg), WithLogger(lg), @@ -494,7 +495,7 @@ func TestBeacon_NoRaceOnClose(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), beacons: make(map[types.EpochID]types.Beacon), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), clock: mclock, closed: make(chan struct{}), results: make(chan result.Beacon, 100), @@ -529,7 +530,7 @@ func TestBeacon_BeaconsWithDatabase(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), beacons: make(map[types.EpochID]types.Beacon), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), clock: mclock, } epoch3 := types.EpochID(3) @@ -582,7 +583,7 @@ func TestBeacon_BeaconsWithDatabaseFailure(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), beacons: make(map[types.EpochID]types.Beacon), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), clock: mclock, } epoch := types.EpochID(3) @@ -600,7 +601,7 @@ func TestBeacon_BeaconsCleanupOldEpoch(t *testing.T) { mclock := NewMocklayerClock(gomock.NewController(t)) pd := &ProtocolDriver{ logger: lg.Named("Beacon"), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), beacons: make(map[types.EpochID]types.Beacon), ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight), clock: mclock, @@ -705,7 +706,7 @@ func TestBeacon_ReportBeaconFromBallot(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), config: UnitTestConfig(), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), beacons: make(map[types.EpochID]types.Beacon), ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight), clock: mclock, @@ -741,7 +742,7 @@ func TestBeacon_ReportBeaconFromBallot_SameBallot(t *testing.T) { pd := &ProtocolDriver{ logger: lg.Named("Beacon"), config: UnitTestConfig(), - cdb: datastore.NewCachedDB(sql.InMemory(), lg), + cdb: datastore.NewCachedDB(statesql.InMemory(), lg), beacons: make(map[types.EpochID]types.Beacon), ballotsBeacons: make(map[types.EpochID]map[types.Beacon]*beaconWeight), clock: mclock, diff --git a/blocks/certifier.go b/blocks/certifier.go index d5045149e9..db7b4627f6 100644 --- a/blocks/certifier.go +++ b/blocks/certifier.go @@ -20,6 +20,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -81,7 +82,7 @@ type Certifier struct { stop func() stopped atomic.Bool - db *sql.Database + db *statesql.Database oracle eligibility.Rolacle signers map[types.NodeID]*signing.EdSigner edVerifier *signing.EdVerifier @@ -99,7 +100,7 @@ type Certifier struct { // NewCertifier creates new block certifier. func NewCertifier( - db *sql.Database, + db *statesql.Database, o eligibility.Rolacle, v *signing.EdVerifier, diff --git a/blocks/certifier_test.go b/blocks/certifier_test.go index 8a3cc72120..a20365bbf0 100644 --- a/blocks/certifier_test.go +++ b/blocks/certifier_test.go @@ -20,6 +20,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -27,7 +28,7 @@ const defaultCnt = uint16(2) type testCertifier struct { *Certifier - db *sql.Database + db *statesql.Database mOracle *eligibility.MockRolacle mPub *pubsubmock.MockPublisher mClk *mocks.MocklayerClock @@ -38,7 +39,7 @@ type testCertifier struct { func newTestCertifier(t *testing.T, signers int) *testCertifier { t.Helper() types.SetLayersPerEpoch(3) - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(t) mo := eligibility.NewMockRolacle(ctrl) mp := pubsubmock.NewMockPublisher(ctrl) diff --git a/blocks/generator.go b/blocks/generator.go index 0d30d7f40f..6651923e75 100644 --- a/blocks/generator.go +++ b/blocks/generator.go @@ -17,8 +17,8 @@ import ( "github.com/spacemeshos/go-spacemesh/hare3/eligibility" "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/proposals/store" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -30,7 +30,7 @@ type Generator struct { eg errgroup.Group stop func() - db *sql.Database + db *statesql.Database atxs *atxsdata.Data proposals *store.Store msh meshProvider @@ -84,7 +84,7 @@ func WithHareOutputChan(ch <-chan hare3.ConsensusOutput) GeneratorOpt { // NewGenerator creates new block generator. func NewGenerator( - db *sql.Database, + db *statesql.Database, atxs *atxsdata.Data, proposals *store.Store, exec executor, diff --git a/blocks/generator_test.go b/blocks/generator_test.go index 0d4d206464..0f967c08f8 100644 --- a/blocks/generator_test.go +++ b/blocks/generator_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -73,7 +74,7 @@ func createTestGenerator(t *testing.T) *testGenerator { } tg.mockMesh.EXPECT().ProcessedLayer().Return(types.LayerID(1)).AnyTimes() lg := zaptest.NewLogger(t) - db := sql.InMemory() + db := statesql.InMemory() data := atxsdata.New() proposals := store.New() tg.Generator = NewGenerator( @@ -265,7 +266,7 @@ func Test_StopBeforeStart(t *testing.T) { func genData( t *testing.T, - db *sql.Database, + db *statesql.Database, data *atxsdata.Data, store *store.Store, lid types.LayerID, diff --git a/blocks/handler.go b/blocks/handler.go index 1d6d7b1201..2e749d389b 100644 --- a/blocks/handler.go +++ b/blocks/handler.go @@ -12,8 +12,8 @@ import ( "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -28,7 +28,7 @@ type Handler struct { logger *zap.Logger fetcher system.Fetcher - db *sql.Database + db *statesql.Database tortoise tortoiseProvider mesh meshProvider } @@ -44,7 +44,13 @@ func WithLogger(logger *zap.Logger) Opt { } // NewHandler creates new Handler. -func NewHandler(f system.Fetcher, db *sql.Database, tortoise tortoiseProvider, m meshProvider, opts ...Opt) *Handler { +func NewHandler( + f system.Fetcher, + db *statesql.Database, + tortoise tortoiseProvider, + m meshProvider, + opts ...Opt, +) *Handler { h := &Handler{ logger: zap.NewNop(), fetcher: f, diff --git a/blocks/handler_test.go b/blocks/handler_test.go index 8c44db0272..7ccb207840 100644 --- a/blocks/handler_test.go +++ b/blocks/handler_test.go @@ -13,8 +13,8 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -34,7 +34,7 @@ func createTestHandler(t *testing.T) *testHandler { } th.Handler = NewHandler( th.mockFetcher, - sql.InMemory(), + statesql.InMemory(), th.mockTortoise, th.mockMesh, WithLogger(zaptest.NewLogger(t)), diff --git a/blocks/utils.go b/blocks/utils.go index e6ec1faf50..88de439ab3 100644 --- a/blocks/utils.go +++ b/blocks/utils.go @@ -18,9 +18,9 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/txs" ) @@ -50,7 +50,7 @@ type proposalMetadata struct { func getProposalMetadata( ctx context.Context, logger *zap.Logger, - db *sql.Database, + db *statesql.Database, atxs *atxsdata.Data, cfg Config, lid types.LayerID, @@ -232,7 +232,7 @@ func toUint64Slice(b []byte) []uint64 { func rewardInfoAndHeight( cfg Config, - db *sql.Database, + db *statesql.Database, atxs *atxsdata.Data, props []*types.Proposal, ) (uint64, []types.AnyReward, error) { diff --git a/blocks/utils_test.go b/blocks/utils_test.go index 11776145b5..d7f9b1f3ca 100644 --- a/blocks/utils_test.go +++ b/blocks/utils_test.go @@ -14,9 +14,9 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -159,7 +159,7 @@ func Test_getBlockTXs_expected_order(t *testing.T) { func Test_getProposalMetadata(t *testing.T) { lg := zaptest.NewLogger(t) - db := sql.InMemory() + db := statesql.InMemory() data := atxsdata.New() cfg := Config{OptFilterThreshold: 70} lid := types.LayerID(111) diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 423f09ae08..4a5d5923b7 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -26,6 +26,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/malsync" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/recovery" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const recoveryDir = "recovery" @@ -113,7 +114,7 @@ func Recover( fs afero.Fs, cfg *RecoverConfig, ) (*PreservedData, error) { - db, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + db, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) if err != nil { return nil, fmt.Errorf("open old database: %w", err) } @@ -148,7 +149,7 @@ func Recover( func RecoverWithDb( ctx context.Context, logger log.Log, - db *sql.Database, + db *statesql.Database, localDB *localsql.Database, fs afero.Fs, cfg *RecoverConfig, @@ -180,7 +181,7 @@ type recoveryData struct { func recoverFromLocalFile( ctx context.Context, logger log.Log, - db *sql.Database, + db *statesql.Database, localDB *localsql.Database, fs afero.Fs, cfg *RecoverConfig, @@ -257,7 +258,7 @@ func recoverFromLocalFile( log.String("backup dir", backupDir), ) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) if err != nil { return nil, fmt.Errorf("open sqlite db %w", err) } @@ -366,7 +367,7 @@ func checkpointData(fs afero.Fs, file string, newGenesis types.LayerID) (*recove func collectOwnAtxDeps( logger log.Log, - db *sql.Database, + db *statesql.Database, localDB *localsql.Database, nodeID types.NodeID, goldenATX types.ATXID, @@ -434,7 +435,7 @@ func collectOwnAtxDeps( } func collectDeps( - db *sql.Database, + db *statesql.Database, ref types.ATXID, all map[types.ATXID]struct{}, ) (map[types.ATXID]*AtxDep, map[types.PoetProofRef]*types.PoetProofMessage, error) { @@ -450,7 +451,7 @@ func collectDeps( } func collect( - db *sql.Database, + db *statesql.Database, ref types.ATXID, all map[types.ATXID]struct{}, deps map[types.ATXID]*AtxDep, @@ -505,7 +506,7 @@ func collect( } func poetProofs( - db *sql.Database, + db *statesql.Database, atxIds map[types.ATXID]*AtxDep, ) (map[types.PoetProofRef]*types.PoetProofMessage, error) { proofs := make(map[types.PoetProofRef]*types.PoetProofMessage, len(atxIds)) diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index 6ff15c9cb1..02cf6d2624 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -29,13 +29,13 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/recovery" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -76,7 +76,7 @@ func accountEqual(tb testing.TB, cacct types.AccountSnapshot, acct *types.Accoun } } -func verifyDbContent(tb testing.TB, db *sql.Database) { +func verifyDbContent(tb testing.TB, db *statesql.Database) { var expected types.Checkpoint require.NoError(tb, json.Unmarshal([]byte(checkpointData), &expected)) expAtx := map[types.ATXID]types.AtxSnapshot{} @@ -165,7 +165,7 @@ func TestRecover(t *testing.T) { } bsdir := filepath.Join(cfg.DataDir, bootstrap.DirName) require.NoError(t, fs.MkdirAll(bsdir, 0o700)) - db := sql.InMemory() + db := statesql.InMemory() localDB := localsql.InMemory() preserve, err := checkpoint.RecoverWithDb(ctx, logtest.New(t), db, localDB, fs, cfg) if tc.expErr != nil { @@ -174,7 +174,7 @@ func TestRecover(t *testing.T) { } require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) defer newDB.Close() @@ -212,7 +212,7 @@ func TestRecover_SameRecoveryInfo(t *testing.T) { } bsdir := filepath.Join(cfg.DataDir, bootstrap.DirName) require.NoError(t, fs.MkdirAll(bsdir, 0o700)) - db := sql.InMemory() + db := statesql.InMemory() localDB := localsql.InMemory() types.SetEffectiveGenesis(0) require.NoError(t, recovery.SetCheckpoint(db, types.LayerID(recoverLayer))) @@ -227,7 +227,7 @@ func TestRecover_SameRecoveryInfo(t *testing.T) { func validateAndPreserveData( tb testing.TB, - db *sql.Database, + db *statesql.Database, deps []*checkpoint.AtxDep, ) { lg := zaptest.NewLogger(tb) @@ -489,7 +489,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) @@ -533,7 +533,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve(t *testing.T) { require.ElementsMatch(t, atxRef, atxIDs(preserve.Deps)) require.ElementsMatch(t, proofRef, proofRefs(preserve.Proofs)) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -579,7 +579,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) @@ -654,7 +654,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_IncludePending(t *testing.T) { require.ElementsMatch(t, atxRef, atxIDs(preserve.Deps)) require.ElementsMatch(t, proofRef, proofRefs(preserve.Proofs)) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -698,7 +698,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) @@ -755,7 +755,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_Still_Initializing(t *testing.T) require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -794,7 +794,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) vAtxs, proofs := createAtxChain(t, sig) @@ -837,7 +837,7 @@ func TestRecover_OwnAtxNotInCheckpoint_Preserve_DepIsGolden(t *testing.T) { require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -876,7 +876,7 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) vAtxs, proofs := createAtxChain(t, sig) @@ -902,7 +902,7 @@ func TestRecover_OwnAtxNotInCheckpoint_DontPreserve(t *testing.T) { require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) @@ -946,7 +946,7 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { Restore: types.LayerID(recoverLayer), } - oldDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + oldDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, oldDB) require.NoError(t, atxs.Add(oldDB, atx)) @@ -956,7 +956,7 @@ func TestRecover_OwnAtxInCheckpoint(t *testing.T) { require.NoError(t, err) require.Nil(t, preserve) - newDB, err := sql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) + newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) require.NoError(t, err) require.NotNil(t, newDB) t.Cleanup(func() { assert.NoError(t, newDB.Close()) }) diff --git a/checkpoint/runner.go b/checkpoint/runner.go index 89e7797cfe..f40cfe391a 100644 --- a/checkpoint/runner.go +++ b/checkpoint/runner.go @@ -10,10 +10,10 @@ import ( "github.com/spf13/afero" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -28,7 +28,7 @@ const ( func checkpointDB( ctx context.Context, - db *sql.Database, + db *statesql.Database, snapshot types.LayerID, numAtxs int, ) (*types.Checkpoint, error) { @@ -115,7 +115,7 @@ func checkpointDB( func Generate( ctx context.Context, fs afero.Fs, - db *sql.Database, + db *statesql.Database, dataDir string, snapshot types.LayerID, numAtxs int, diff --git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index f394ecfb04..7661f79f4b 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -18,10 +18,10 @@ import ( "github.com/spacemeshos/go-spacemesh/checkpoint" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -231,7 +231,12 @@ func asAtxSnapshot(v *types.ActivationTx, cmt *types.ATXID) types.AtxSnapshot { } } -func createMesh(t *testing.T, db *sql.Database, miners map[types.NodeID][]*types.ActivationTx, accts []*types.Account) { +func createMesh( + t *testing.T, + db *statesql.Database, + miners map[types.NodeID][]*types.ActivationTx, + accts []*types.Account, +) { for _, vatxs := range miners { vatxs = slices.Clone(vatxs) // ATXs are expected to be in reverse epoch order and we want older ATXs @@ -293,7 +298,7 @@ func TestRunner_Generate(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) createMesh(t, db, tc.atxes, tc.accts) @@ -343,7 +348,7 @@ func TestRunner_Generate_Error(t *testing.T) { } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() snapshot := types.LayerID(5) var atx *types.ActivationTx if tc.missingCommitment { diff --git a/cmd/activeset/activeset.go b/cmd/activeset/activeset.go index 6c3acd6d0c..868264bc3e 100644 --- a/cmd/activeset/activeset.go +++ b/cmd/activeset/activeset.go @@ -9,8 +9,8 @@ import ( "strconv" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func main() { @@ -30,7 +30,7 @@ Example: if len(dbpath) == 0 { must(errors.New("dbpath is empty"), "dbpath is empty\n") } - db, err := sql.Open("file:" + dbpath) + db, err := statesql.Open("file:" + dbpath) must(err, "can't open db at dbpath=%v. err=%s\n", dbpath, err) ids, err := atxs.GetIDsByEpoch(context.Background(), db, types.EpochID(publish)) diff --git a/cmd/bootstrapper/generator_test.go b/cmd/bootstrapper/generator_test.go index 7b1d4153e7..2eedfa95a4 100644 --- a/cmd/bootstrapper/generator_test.go +++ b/cmd/bootstrapper/generator_test.go @@ -25,6 +25,7 @@ import ( "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -101,7 +102,7 @@ func verifyUpdate(tb testing.TB, data []byte, epoch types.EpochID, expBeacon str func TestGenerator_Generate(t *testing.T) { targetEpoch := types.EpochID(3) - db := sql.InMemory() + db := statesql.InMemory() createAtxs(t, db, targetEpoch-1, types.RandomActiveSet(activeSetSize)) cfg, cleanup := launchServer(t, datastore.NewCachedDB(db, zaptest.NewLogger(t))) t.Cleanup(cleanup) @@ -168,7 +169,7 @@ func TestGenerator_Generate(t *testing.T) { func TestGenerator_CheckAPI(t *testing.T) { targetEpoch := types.EpochID(3) - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) createAtxs(t, db, targetEpoch-1, types.RandomActiveSet(activeSetSize)) cfg, cleanup := launchServer(t, datastore.NewCachedDB(db, lg.Zap())) diff --git a/cmd/bootstrapper/server_test.go b/cmd/bootstrapper/server_test.go index 910d448c79..1423048602 100644 --- a/cmd/bootstrapper/server_test.go +++ b/cmd/bootstrapper/server_test.go @@ -20,7 +20,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/log/logtest" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) //go:embed checkpointdata.json @@ -57,7 +57,7 @@ func updateCheckpoint(t *testing.T, ctx context.Context, data string) { } func TestServer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() cfg, cleanup := launchServer(t, datastore.NewCachedDB(db, zaptest.NewLogger(t))) t.Cleanup(cleanup) diff --git a/cmd/merge-nodes/internal/errors.go b/cmd/merge-nodes/internal/errors.go index 53af669de0..b3e3449b0b 100644 --- a/cmd/merge-nodes/internal/errors.go +++ b/cmd/merge-nodes/internal/errors.go @@ -4,7 +4,4 @@ import ( "errors" ) -var ( - ErrSupervisedNode = errors.New("merging of supervised smeshing nodes is not supported") - ErrInvalidSchema = errors.New("database has an invalid schema version") -) +var ErrSupervisedNode = errors.New("merging of supervised smeshing nodes is not supported") diff --git a/cmd/merge-nodes/internal/merge_action.go b/cmd/merge-nodes/internal/merge_action.go index 43b9929fdc..58c7586b59 100644 --- a/cmd/merge-nodes/internal/merge_action.go +++ b/cmd/merge-nodes/internal/merge_action.go @@ -186,35 +186,17 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { func openDB(dbLog *zap.Logger, path string) (*localsql.Database, error) { dbPath := filepath.Join(path, localDbFile) if _, err := os.Stat(dbPath); err != nil { - return nil, fmt.Errorf("open database %s: %w", dbPath, err) - } - - migrations, err := sql.LocalMigrations() - if err != nil { - return nil, fmt.Errorf("get local migrations: %w", err) + return nil, fmt.Errorf("stat source database %s: %w", dbPath, err) } db, err := localsql.Open("file:"+dbPath, sql.WithLogger(dbLog), - sql.WithMigrations(nil), // do not migrate database when opening + sql.WithEnableMigrations(false), ) if err != nil { return nil, fmt.Errorf("open source database %s: %w", dbPath, err) } - // check if the source database has the right schema - var version int - _, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool { - version = stmt.ColumnInt(0) - return true - }) - if err != nil { - return nil, fmt.Errorf("get source database schema for %s: %w", dbPath, err) - } - if version != len(migrations) { - db.Close() - return nil, ErrInvalidSchema - } return db, nil } diff --git a/cmd/merge-nodes/internal/merge_action_test.go b/cmd/merge-nodes/internal/merge_action_test.go index f2fdc11135..3315821c8c 100644 --- a/cmd/merge-nodes/internal/merge_action_test.go +++ b/cmd/merge-nodes/internal/merge_action_test.go @@ -24,20 +24,25 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) +func oldSchema(t *testing.T) *sql.Schema { + schema, err := localsql.Schema() + require.NoError(t, err) + schema.Migrations = schema.Migrations[:2] + return schema +} + func Test_MergeDBs_InvalidTargetScheme(t *testing.T) { tmpDst := t.TempDir() - migrations, err := sql.LocalMigrations() - require.NoError(t, err) - db, err := localsql.Open("file:"+filepath.Join(tmpDst, localDbFile), - sql.WithMigrations(migrations[:2]), // old schema + sql.WithDatabaseSchema(oldSchema(t)), + sql.WithForceMigrations(true), ) require.NoError(t, err) require.NoError(t, db.Close()) err = MergeDBs(context.Background(), zaptest.NewLogger(t), "", tmpDst) - require.ErrorIs(t, err, ErrInvalidSchema) + require.ErrorIs(t, err, sql.ErrOld) require.ErrorContains(t, err, "target database") } @@ -82,22 +87,20 @@ func Test_MergeDBs_InvalidSourcePath(t *testing.T) { func Test_MergeDBs_InvalidSourceScheme(t *testing.T) { tmpDst := t.TempDir() - migrations, err := sql.LocalMigrations() - require.NoError(t, err) - db, err := localsql.Open("file:" + filepath.Join(tmpDst, localDbFile)) require.NoError(t, err) require.NoError(t, db.Close()) tmpSrc := t.TempDir() db, err = localsql.Open("file:"+filepath.Join(tmpSrc, localDbFile), - sql.WithMigrations(migrations[:2]), // old schema + sql.WithDatabaseSchema(oldSchema(t)), + sql.WithForceMigrations(true), ) require.NoError(t, err) require.NoError(t, db.Close()) err = MergeDBs(context.Background(), zaptest.NewLogger(t), tmpSrc, tmpDst) - require.ErrorIs(t, err, ErrInvalidSchema) + require.ErrorIs(t, err, sql.ErrOld) require.ErrorContains(t, err, "source database") } diff --git a/datastore/store_test.go b/datastore/store_test.go index cc52b1415e..6155664d3a 100644 --- a/datastore/store_test.go +++ b/datastore/store_test.go @@ -26,6 +26,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/poets" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -54,7 +55,7 @@ func getBytes( } func TestMalfeasanceProof_Honest(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, zaptest.NewLogger(t)) require.Equal(t, 0, cdb.MalfeasanceCacheSize()) @@ -115,7 +116,7 @@ func TestMalfeasanceProof_Honest(t *testing.T) { } func TestMalfeasanceProof_Dishonest(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, zaptest.NewLogger(t)) require.Equal(t, 0, cdb.MalfeasanceCacheSize()) @@ -143,7 +144,7 @@ func TestMalfeasanceProof_Dishonest(t *testing.T) { } func TestBlobStore_GetATXBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -186,7 +187,7 @@ func TestBlobStore_GetATXBlob(t *testing.T) { } func TestBlobStore_GetBallotBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -221,7 +222,7 @@ func TestBlobStore_GetBallotBlob(t *testing.T) { } func TestBlobStore_GetBlockBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -256,7 +257,7 @@ func TestBlobStore_GetBlockBlob(t *testing.T) { } func TestBlobStore_GetPoetBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -285,7 +286,7 @@ func TestBlobStore_GetPoetBlob(t *testing.T) { } func TestBlobStore_GetProposalBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() proposals := store.New() bs := datastore.NewBlobStore(db, proposals) ctx := context.Background() @@ -323,7 +324,7 @@ func TestBlobStore_GetProposalBlob(t *testing.T) { } func TestBlobStore_GetTXBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -351,7 +352,7 @@ func TestBlobStore_GetTXBlob(t *testing.T) { } func TestBlobStore_GetMalfeasanceBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -385,7 +386,7 @@ func TestBlobStore_GetMalfeasanceBlob(t *testing.T) { } func TestBlobStore_GetActiveSet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() bs := datastore.NewBlobStore(db, store.New()) ctx := context.Background() @@ -409,7 +410,7 @@ func TestBlobStore_GetActiveSet(t *testing.T) { } func Test_MarkingMalicious(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() store := atxsdata.New() id := types.RandomNodeID() cdb := datastore.NewCachedDB(db, zaptest.NewLogger(t), datastore.WithConsensusCache(store)) diff --git a/fetch/fetch_test.go b/fetch/fetch_test.go index dff1aa0cd4..85bdcfe7a1 100644 --- a/fetch/fetch_test.go +++ b/fetch/fetch_test.go @@ -21,7 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/server" "github.com/spacemeshos/go-spacemesh/proposals/store" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testFetch struct { @@ -80,7 +80,7 @@ func createFetch(tb testing.TB) *testFetch { } lg := logtest.New(tb) - tf.Fetch = NewFetch(datastore.NewCachedDB(sql.InMemory(), lg.Zap()), store.New(), nil, + tf.Fetch = NewFetch(datastore.NewCachedDB(statesql.InMemory(), lg.Zap()), store.New(), nil, WithContext(context.TODO()), WithConfig(cfg), WithLogger(lg), @@ -117,7 +117,7 @@ func badReceiver(context.Context, types.Hash32, p2p.Peer, []byte) error { func TestFetch_Start(t *testing.T) { lg := logtest.New(t) - f := NewFetch(datastore.NewCachedDB(sql.InMemory(), lg.Zap()), store.New(), nil, + f := NewFetch(datastore.NewCachedDB(statesql.InMemory(), lg.Zap()), store.New(), nil, WithContext(context.TODO()), WithConfig(DefaultConfig()), WithLogger(lg), @@ -384,7 +384,7 @@ func TestFetch_PeerDroppedWhenMessageResultsInValidationReject(t *testing.T) { }) defer eg.Wait() - fetcher := NewFetch(datastore.NewCachedDB(sql.InMemory(), lg.Zap()), store.New(), h, + fetcher := NewFetch(datastore.NewCachedDB(statesql.InMemory(), lg.Zap()), store.New(), h, WithContext(ctx), WithConfig(cfg), WithLogger(lg), diff --git a/fetch/handler_test.go b/fetch/handler_test.go index 87c48f89bc..1fcd79e174 100644 --- a/fetch/handler_test.go +++ b/fetch/handler_test.go @@ -24,17 +24,18 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testHandler struct { *handler - db *sql.Database + db *statesql.Database cdb *datastore.CachedDB } func createTestHandler(t testing.TB, opts ...sql.Opt) *testHandler { lg := logtest.New(t) - db := sql.InMemory(opts...) + db := statesql.InMemory(opts...) cdb := datastore.NewCachedDB(db, lg.Zap()) return &testHandler{ handler: newHandler(cdb, datastore.NewBlobStore(cdb, store.New()), lg), diff --git a/fetch/mesh_data_test.go b/fetch/mesh_data_test.go index da6ef89ccf..b1bf857e55 100644 --- a/fetch/mesh_data_test.go +++ b/fetch/mesh_data_test.go @@ -24,7 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/server" "github.com/spacemeshos/go-spacemesh/proposals/store" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -867,7 +867,7 @@ func Test_GetAtxsLimiting(t *testing.T) { cfg.QueueSize = 1000 cfg.GetAtxsConcurrency = getAtxConcurrency - cdb := datastore.NewCachedDB(sql.InMemory(), zaptest.NewLogger(t)) + cdb := datastore.NewCachedDB(statesql.InMemory(), zaptest.NewLogger(t)) client := server.New(mesh.Hosts()[0], hashProtocol, nil) host, err := p2p.Upgrade(mesh.Hosts()[0]) require.NoError(t, err) diff --git a/fetch/p2p_test.go b/fetch/p2p_test.go index 5557bbcea1..9ab1a16d3a 100644 --- a/fetch/p2p_test.go +++ b/fetch/p2p_test.go @@ -27,6 +27,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/poets" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -37,13 +38,13 @@ type blobKey struct { type testP2PFetch struct { t *testing.T - clientDB *sql.Database + clientDB *statesql.Database // client proposals clientPDB *store.Store clientCDB *datastore.CachedDB clientFetch *Fetch serverID peer.ID - serverDB *sql.Database + serverDB *statesql.Database // server proposals serverPDB *store.Store serverCDB *datastore.CachedDB @@ -104,8 +105,8 @@ func createP2PFetch( if sqlCache { sqlOpts = []sql.Opt{sql.WithQueryCache(true)} } - clientDB := sql.InMemory(sqlOpts...) - serverDB := sql.InMemory(sqlOpts...) + clientDB := statesql.InMemory(sqlOpts...) + serverDB := statesql.InMemory(sqlOpts...) tpf := &testP2PFetch{ t: t, clientDB: clientDB, diff --git a/genvm/core/context_test.go b/genvm/core/context_test.go index e13a8287ff..92d8cb627d 100644 --- a/genvm/core/context_test.go +++ b/genvm/core/context_test.go @@ -10,23 +10,23 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/core" "github.com/spacemeshos/go-spacemesh/genvm/core/mocks" "github.com/spacemeshos/go-spacemesh/genvm/registry" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestTransfer(t *testing.T) { t.Run("NoBalance", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} require.ErrorIs(t, ctx.Transfer(core.Address{}, 100), core.ErrNoBalance) }) t.Run("MaxSpend", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} ctx.PrincipalAccount.Balance = 1000 ctx.Header.MaxSpend = 100 require.NoError(t, ctx.Transfer(core.Address{1}, 50)) require.ErrorIs(t, ctx.Transfer(core.Address{2}, 100), core.ErrMaxSpend) }) t.Run("ReducesBalance", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} ctx.PrincipalAccount.Balance = 1000 ctx.Header.MaxSpend = 1000 for _, amount := range []uint64{50, 100, 200, 255} { @@ -67,7 +67,7 @@ func TestConsume(t *testing.T) { func TestApply(t *testing.T) { t.Run("UpdatesNonce", func(t *testing.T) { - ss := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + ss := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) ctx := core.Context{Loader: ss} ctx.PrincipalAccount.Address = core.Address{1} ctx.Header.Nonce = 10 @@ -80,7 +80,7 @@ func TestApply(t *testing.T) { require.Equal(t, ctx.PrincipalAccount.NextNonce, account.NextNonce) }) t.Run("ConsumeMaxGas", func(t *testing.T) { - ss := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + ss := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) ctx := core.Context{Loader: ss} ctx.PrincipalAccount.Balance = 1000 @@ -97,7 +97,7 @@ func TestApply(t *testing.T) { require.Equal(t, ctx.Fee(), ctx.Header.MaxGas*ctx.Header.GasPrice) }) t.Run("PreserveTransferOrder", func(t *testing.T) { - ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{sql.InMemory()})} + ctx := core.Context{Loader: core.NewStagedCache(core.DBLoader{statesql.InMemory()})} ctx.PrincipalAccount.Address = core.Address{1} ctx.PrincipalAccount.Balance = 1000 ctx.Header.MaxSpend = 1000 @@ -129,7 +129,7 @@ func TestRelay(t *testing.T) { remote = core.Address{'r', 'e', 'm'} ) t.Run("not spawned", func(t *testing.T) { - cache := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + cache := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) ctx := core.Context{Loader: cache} call := func(remote core.Host) error { require.Fail(t, "not expected to be called") @@ -138,7 +138,7 @@ func TestRelay(t *testing.T) { require.ErrorIs(t, ctx.Relay(template, remote, call), core.ErrNotSpawned) }) t.Run("mismatched template", func(t *testing.T) { - cache := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + cache := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) require.NoError(t, cache.Update(core.Account{ Address: remote, TemplateAddress: &core.Address{'m', 'i', 's'}, @@ -166,7 +166,7 @@ func TestRelay(t *testing.T) { reg := registry.New() reg.Register(template, handler) - cache := core.NewStagedCache(core.DBLoader{sql.InMemory()}) + cache := core.NewStagedCache(core.DBLoader{statesql.InMemory()}) receiver2 := core.Address{'f'} const ( total = 1000 diff --git a/genvm/core/staged_cache_test.go b/genvm/core/staged_cache_test.go index 8a6b29519e..a018930955 100644 --- a/genvm/core/staged_cache_test.go +++ b/genvm/core/staged_cache_test.go @@ -6,11 +6,11 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/genvm/core" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestCacheGetCopies(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ss := core.NewStagedCache(core.DBLoader{db}) address := core.Address{1} account, err := ss.Get(address) @@ -23,7 +23,7 @@ func TestCacheGetCopies(t *testing.T) { } func TestCacheUpdatePreserveOrder(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ss := core.NewStagedCache(core.DBLoader{db}) order := []core.Address{{3}, {1}, {2}} for _, address := range order { diff --git a/genvm/templates/vault/vault_test.go b/genvm/templates/vault/vault_test.go index 28fef09d14..e1f461fd9c 100644 --- a/genvm/templates/vault/vault_test.go +++ b/genvm/templates/vault/vault_test.go @@ -9,7 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/genvm/core" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestVested(t *testing.T) { @@ -392,7 +392,7 @@ func TestSpend(t *testing.T) { } ctx := core.Context{ LayerID: types.LayerID(tc.lid), - Loader: core.NewStagedCache(core.DBLoader{Executor: sql.InMemory()}), + Loader: core.NewStagedCache(core.DBLoader{Executor: statesql.InMemory()}), Header: types.TxHeader{MaxSpend: math.MaxUint64}, PrincipalAccount: types.Account{ Address: owner, @@ -419,7 +419,7 @@ func TestSpend(t *testing.T) { } ctx := core.Context{ LayerID: types.LayerID(2), - Loader: core.NewStagedCache(core.DBLoader{Executor: sql.InMemory()}), + Loader: core.NewStagedCache(core.DBLoader{Executor: statesql.InMemory()}), Header: types.TxHeader{MaxSpend: math.MaxUint64}, PrincipalAccount: types.Account{ Address: owner, diff --git a/genvm/vm.go b/genvm/vm.go index 7174763511..24a156f7d7 100644 --- a/genvm/vm.go +++ b/genvm/vm.go @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/rewards" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/system" ) @@ -58,7 +59,7 @@ func WithConfig(cfg Config) Opt { } // New returns VM instance. -func New(db *sql.Database, opts ...Opt) *VM { +func New(db *statesql.Database, opts ...Opt) *VM { vm := &VM{ logger: log.NewNop(), db: db, @@ -78,7 +79,7 @@ func New(db *sql.Database, opts ...Opt) *VM { // VM handles modifications to the account state. type VM struct { logger log.Log - db *sql.Database + db *statesql.Database cfg Config registry *registry.Registry } diff --git a/genvm/vm_test.go b/genvm/vm_test.go index fa78d5248e..7fbf5d0cc0 100644 --- a/genvm/vm_test.go +++ b/genvm/vm_test.go @@ -34,9 +34,9 @@ import ( "github.com/spacemeshos/go-spacemesh/hash" "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func testContext(lid types.LayerID) ApplyContext { @@ -48,7 +48,7 @@ func testContext(lid types.LayerID) ApplyContext { func newTester(tb testing.TB) *tester { return &tester{ TB: tb, - VM: New(sql.InMemory(), + VM: New(statesql.InMemory(), WithLogger(logtest.New(tb)), WithConfig(Config{GasLimit: math.MaxUint64}), ), @@ -279,7 +279,7 @@ type tester struct { } func (t *tester) persistent() *tester { - db, err := sql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) + db, err := statesql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) t.Cleanup(func() { require.NoError(t, db.Close()) }) require.NoError(t, err) t.VM = New(db, WithLogger(logtest.New(t)), @@ -2572,7 +2572,7 @@ func TestVestingData(t *testing.T) { spendAccountNonce := t2.nonces[0] spendAmount := uint64(1_000_000) - vm := New(sql.InMemory(), WithLogger(logtest.New(t))) + vm := New(statesql.InMemory(), WithLogger(logtest.New(t))) require.NoError(t, vm.ApplyGenesis( []core.Account{ { diff --git a/go.mod b/go.mod index 9ed0fe0dab..b93c60170e 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/rs/cors v1.11.0 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/seehuhn/mt19937 v1.0.0 + github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e github.com/spacemeshos/api/release/go v1.42.0 github.com/spacemeshos/economics v0.1.3 github.com/spacemeshos/fixed v0.1.1 diff --git a/go.sum b/go.sum index 3139dde7ef..b78db2b5f6 100644 --- a/go.sum +++ b/go.sum @@ -600,6 +600,8 @@ github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3 github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e h1:H+jDTUeF+SVd4ApwnSFoew8ZwGNRfgb9EsZc7LcocAg= +github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e/go.mod h1:VsUklG6OQo7Ctunu0gS3AtEOCEc2kMB6r5rKzxAes58= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/spacemeshos/api/release/go v1.42.0 h1:K85zw+KZA1UA3VNwvXD2UIND7NLyAiJo4Kz6ZznFEEc= github.com/spacemeshos/api/release/go v1.42.0/go.mod h1:aCDRfna5MA7LJWZPa4k+vTRvBUf1Swz8kcziPcdp6i8= diff --git a/hare3/eligibility/oracle_test.go b/hare3/eligibility/oracle_test.go index 60652a04a7..39b2ba4cf5 100644 --- a/hare3/eligibility/oracle_test.go +++ b/hare3/eligibility/oracle_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -46,14 +47,14 @@ func TestMain(m *testing.M) { type testOracle struct { *Oracle tb testing.TB - db *sql.Database + db *statesql.Database atxsdata *atxsdata.Data mBeacon *mocks.MockBeaconGetter mVerifier *MockvrfVerifier } func defaultOracle(tb testing.TB) *testOracle { - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() ctrl := gomock.NewController(tb) diff --git a/hare3/hare.go b/hare3/hare.go index 946278ebd6..c5f2bb3780 100644 --- a/hare3/hare.go +++ b/hare3/hare.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/beacons" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -146,7 +147,7 @@ type nodeclock interface { func New( nodeclock nodeclock, pubsub pubsub.PublishSubsciber, - db *sql.Database, + db *statesql.Database, atxsdata *atxsdata.Data, proposals *store.Store, verifier *signing.EdVerifier, @@ -208,7 +209,7 @@ type Hare struct { // dependencies nodeclock nodeclock pubsub pubsub.PublishSubsciber - db *sql.Database + db *statesql.Database atxsdata *atxsdata.Data proposals *store.Store verifier *signing.EdVerifier diff --git a/hare3/hare_test.go b/hare3/hare_test.go index c53f78fbe5..8221f0d584 100644 --- a/hare3/hare_test.go +++ b/hare3/hare_test.go @@ -24,11 +24,11 @@ import ( pmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/proposals/store" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/beacons" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -114,7 +114,7 @@ type node struct { vrfsigner *signing.VRFSigner atx *types.ActivationTx oracle *eligibility.Oracle - db *sql.Database + db *statesql.Database atxsdata *atxsdata.Data proposals *store.Store @@ -146,7 +146,7 @@ func (n *node) reuseSigner(signer *signing.EdSigner) *node { } func (n *node) withDb() *node { - n.db = sql.InMemory() + n.db = statesql.InMemory() n.atxsdata = atxsdata.New() n.proposals = store.New() return n @@ -897,7 +897,7 @@ func TestProposals(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() proposals := store.New() hare := New( diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index 0218bc906d..e73358ddde 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -24,6 +24,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -48,7 +49,7 @@ func createIdentity(tb testing.TB, db sql.Executor, sig *signing.EdSigner) { } func TestHandler_HandleMalfeasanceProof_multipleATXs(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) ctrl := gomock.NewController(t) @@ -270,7 +271,7 @@ func TestHandler_HandleMalfeasanceProof_multipleATXs(t *testing.T) { } func TestHandler_HandleMalfeasanceProof_multipleBallots(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) ctrl := gomock.NewController(t) trt := malfeasance.NewMocktortoise(ctrl) @@ -491,7 +492,7 @@ func TestHandler_HandleMalfeasanceProof_multipleBallots(t *testing.T) { } func TestHandler_HandleMalfeasanceProof_hareEquivocation(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) ctrl := gomock.NewController(t) trt := malfeasance.NewMocktortoise(ctrl) @@ -760,7 +761,7 @@ func TestHandler_HandleMalfeasanceProof_hareEquivocation(t *testing.T) { } func TestHandler_CrossDomain(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) ctrl := gomock.NewController(t) trt := malfeasance.NewMocktortoise(ctrl) @@ -824,7 +825,7 @@ func TestHandler_CrossDomain(t *testing.T) { } func TestHandler_HandleSyncedMalfeasanceProof_multipleATXs(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) ctrl := gomock.NewController(t) trt := malfeasance.NewMocktortoise(ctrl) @@ -887,7 +888,7 @@ func TestHandler_HandleSyncedMalfeasanceProof_multipleATXs(t *testing.T) { } func TestHandler_HandleSyncedMalfeasanceProof_multipleBallots(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) ctrl := gomock.NewController(t) trt := malfeasance.NewMocktortoise(ctrl) @@ -949,7 +950,7 @@ func TestHandler_HandleSyncedMalfeasanceProof_multipleBallots(t *testing.T) { } func TestHandler_HandleSyncedMalfeasanceProof_hareEquivocation(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) ctrl := gomock.NewController(t) trt := malfeasance.NewMocktortoise(ctrl) @@ -1014,7 +1015,7 @@ func TestHandler_HandleSyncedMalfeasanceProof_hareEquivocation(t *testing.T) { } func TestHandler_HandleSyncedMalfeasanceProof_wrongHash(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lg := logtest.New(t) ctrl := gomock.NewController(t) trt := malfeasance.NewMocktortoise(ctrl) @@ -1079,7 +1080,7 @@ func TestHandler_HandleSyncedMalfeasanceProof_wrongHash(t *testing.T) { type testMalfeasanceHandler struct { *malfeasance.Handler - db *sql.Database + db *statesql.Database sig *signing.EdSigner mPostVerifier *malfeasance.MockpostVerifier @@ -1087,7 +1088,7 @@ type testMalfeasanceHandler struct { } func newTestMalfeasanceHandler(t testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) diff --git a/mesh/executor_test.go b/mesh/executor_test.go index 01330cfb6e..790066d2b6 100644 --- a/mesh/executor_test.go +++ b/mesh/executor_test.go @@ -22,6 +22,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMain(m *testing.M) { @@ -33,7 +34,7 @@ func TestMain(m *testing.M) { type testExecutor struct { tb testing.TB exec *mesh.Executor - db *sql.Database + db *statesql.Database atxsdata *atxsdata.Data mcs *mocks.MockconservativeState mvm *mocks.MockvmState @@ -43,7 +44,7 @@ func newTestExecutor(t *testing.T) *testExecutor { ctrl := gomock.NewController(t) te := &testExecutor{ tb: t, - db: sql.InMemory(), + db: statesql.InMemory(), atxsdata: atxsdata.New(), mvm: mocks.NewMockvmState(ctrl), mcs: mocks.NewMockconservativeState(ctrl), diff --git a/mesh/mesh.go b/mesh/mesh.go index eee3391bb1..aeb460673b 100644 --- a/mesh/mesh.go +++ b/mesh/mesh.go @@ -29,13 +29,14 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/rewards" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) // Mesh is the logic layer above our mesh.DB database. type Mesh struct { logger log.Log - cdb *sql.Database + cdb *statesql.Database atxsdata *atxsdata.Data clock layerClock @@ -58,7 +59,7 @@ type Mesh struct { // NewMesh creates a new instant of a mesh. func NewMesh( - db *sql.Database, + db *statesql.Database, atxsdata *atxsdata.Data, c layerClock, trtl system.Tortoise, diff --git a/mesh/mesh_test.go b/mesh/mesh_test.go index 30abe325d0..da4ec6887e 100644 --- a/mesh/mesh_test.go +++ b/mesh/mesh_test.go @@ -27,6 +27,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -39,7 +40,7 @@ const ( type testMesh struct { *Mesh - db *sql.Database + db *statesql.Database // it is used in malfeasence.Validate, which is called in the tests cdb *datastore.CachedDB atxsdata *atxsdata.Data @@ -53,7 +54,7 @@ func createTestMesh(t *testing.T) *testMesh { t.Helper() types.SetLayersPerEpoch(3) lg := logtest.New(t) - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() ctrl := gomock.NewController(t) tm := &testMesh{ diff --git a/miner/active_set_generator_test.go b/miner/active_set_generator_test.go index 3eead31e24..6b4319a60e 100644 --- a/miner/active_set_generator_test.go +++ b/miner/active_set_generator_test.go @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/activeset" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type expect struct { @@ -65,7 +66,7 @@ func unixPtr(sec, nsec int64) *time.Time { func newTesterActiveSetGenerator(tb testing.TB, cfg config) *testerActiveSetGenerator { var ( - db = sql.InMemory() + db = statesql.InMemory() localdb = localsql.InMemory() atxsdata = atxsdata.New() ctrl = gomock.NewController(tb) @@ -97,7 +98,7 @@ type testerActiveSetGenerator struct { tb testing.TB gen *activeSetGenerator - db *sql.Database + db *statesql.Database localdb *localsql.Database atxsdata *atxsdata.Data ctrl *gomock.Controller diff --git a/miner/proposal_builder_test.go b/miner/proposal_builder_test.go index 542927cc2c..6e048886f3 100644 --- a/miner/proposal_builder_test.go +++ b/miner/proposal_builder_test.go @@ -25,7 +25,6 @@ import ( pmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/proposals" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/ballots" @@ -35,6 +34,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -748,7 +748,7 @@ func TestBuild(t *testing.T) { publisher = pmocks.NewMockPublisher(ctrl) tortoise = mocks.NewMockvotesEncoder(ctrl) syncer = smocks.NewMockSyncStateProvider(ctrl) - db = sql.InMemory() + db = statesql.InMemory() localdb = localsql.InMemory() atxsdata = atxsdata.New() ) @@ -904,7 +904,7 @@ func TestStartStop(t *testing.T) { publisher = pmocks.NewMockPublisher(ctrl) tortoise = mocks.NewMockvotesEncoder(ctrl) syncer = smocks.NewMockSyncStateProvider(ctrl) - db = sql.InMemory() + db = statesql.InMemory() localdb = localsql.InMemory() atxsdata = atxsdata.New() ) diff --git a/node/node.go b/node/node.go index 313437c3dc..368e96c814 100644 --- a/node/node.go +++ b/node/node.go @@ -77,6 +77,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" dbmetrics "github.com/spacemeshos/go-spacemesh/sql/metrics" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer" "github.com/spacemeshos/go-spacemesh/syncer/atxsync" "github.com/spacemeshos/go-spacemesh/syncer/blockssync" @@ -384,7 +385,7 @@ type App struct { fileLock *flock.Flock signers []*signing.EdSigner Config *config.Config - db *sql.Database + db *statesql.Database cachedDB *datastore.CachedDB dbMetrics *dbmetrics.DBMetricsCollector localDB *localsql.Database @@ -1886,14 +1887,17 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { if err := os.MkdirAll(dbPath, os.ModePerm); err != nil { return fmt.Errorf("failed to create %s: %w", dbPath, err) } - migrations, err := sql.StateMigrations() + dbLog := app.addLogger(StateDbLogger, lg) + schema, err := statesql.Schema() if err != nil { - return fmt.Errorf("failed to load migrations: %w", err) + return fmt.Errorf("error loading db schema: %w", err) + } + if len(app.Config.DatabaseSkipMigrations) > 0 { + schema.SkipMigrations(app.Config.DatabaseSkipMigrations...) } - dbLog := app.addLogger(StateDbLogger, lg) dbopts := []sql.Opt{ sql.WithLogger(dbLog.Zap()), - sql.WithMigrations(migrations), + sql.WithDatabaseSchema(schema), sql.WithConnections(app.Config.DatabaseConnections), sql.WithLatencyMetering(app.Config.DatabaseLatencyMetering), sql.WithVacuumState(app.Config.DatabaseVacuumState), @@ -1904,10 +1908,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { activesets.CacheKindActiveSetBlob: app.Config.DatabaseQueryCacheSizes.ActiveSetBlob, }), } - if len(app.Config.DatabaseSkipMigrations) > 0 { - dbopts = append(dbopts, sql.WithSkipMigrations(app.Config.DatabaseSkipMigrations...)) - } - sqlDB, err := sql.Open("file:"+filepath.Join(dbPath, dbFile), dbopts...) + sqlDB, err := statesql.Open("file:"+filepath.Join(dbPath, dbFile), dbopts...) if err != nil { return fmt.Errorf("open sqlite db %w", err) } @@ -1949,13 +1950,8 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { app.log.With().Info("malicious ATX check completed", log.Duration("duration", time.Since(start))) } - migrations, err = sql.LocalMigrations() - if err != nil { - return fmt.Errorf("load local migrations: %w", err) - } localDB, err := localsql.Open("file:"+filepath.Join(dbPath, localDbFile), sql.WithLogger(dbLog.Zap()), - sql.WithMigrations(migrations), sql.WithConnections(app.Config.DatabaseConnections), ) if err != nil { diff --git a/node/node_version_check_test.go b/node/node_version_check_test.go index 8e8d8a6d76..787a7b84e6 100644 --- a/node/node_version_check_test.go +++ b/node/node_version_check_test.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/config" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestUpgradeToV15(t *testing.T) { @@ -37,10 +38,14 @@ func TestUpgradeToV15(t *testing.T) { uri := path.Join(cfg.DataDir(), localDbFile) - migrations, err := sql.LocalMigrations() + schema, err := statesql.Schema() require.NoError(t, err) - db, err := sql.Open(uri, sql.WithMigrations(migrations[:2])) + schema.Migrations = schema.Migrations[:2] + + db, err := statesql.Open(uri, + sql.WithDatabaseSchema(schema), + sql.WithForceMigrations(true)) require.NoError(t, err) require.NoError(t, db.Close()) diff --git a/proposals/handler.go b/proposals/handler.go index 8109d8d9eb..5344465b35 100644 --- a/proposals/handler.go +++ b/proposals/handler.go @@ -26,6 +26,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/ballots" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -49,7 +50,7 @@ type Handler struct { logger log.Log cfg Config - db *sql.Database + db *statesql.Database atxsdata *atxsdata.Data activeSets *lru.Cache[types.Hash32, uint64] edVerifier *signing.EdVerifier @@ -108,7 +109,7 @@ func WithConfig(cfg Config) Opt { // NewHandler creates new Handler. func NewHandler( - db *sql.Database, + db *statesql.Database, atxsdata *atxsdata.Data, proposals proposalsConsumer, edVerifier *signing.EdVerifier, diff --git a/proposals/handler_test.go b/proposals/handler_test.go index 2fbff099f3..0023ae113d 100644 --- a/proposals/handler_test.go +++ b/proposals/handler_test.go @@ -31,6 +31,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/blocks" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -93,7 +94,7 @@ func fullMockSet(tb testing.TB) *mockSet { func createTestHandler(t *testing.T) *testHandler { types.SetLayersPerEpoch(layersPerEpoch) ms := fullMockSet(t) - db := sql.InMemory() + db := statesql.InMemory() atxsdata := atxsdata.New() ms.md.EXPECT().GetBallot(gomock.Any()).AnyTimes().DoAndReturn(func(id types.BallotID) *tortoise.BallotData { ballot, err := ballots.Get(db, id) @@ -236,7 +237,7 @@ func createProposal(t *testing.T, opts ...any) *types.Proposal { return p } -func createAtx(t *testing.T, db *sql.Database, epoch types.EpochID, atxID types.ATXID, nodeID types.NodeID) { +func createAtx(t *testing.T, db *statesql.Database, epoch types.EpochID, atxID types.ATXID, nodeID types.NodeID) { atx := &types.ActivationTx{ PublishEpoch: epoch, NumUnits: 1, diff --git a/prune/prune.go b/prune/prune.go index 0f7ddb4218..c5b4551ea9 100644 --- a/prune/prune.go +++ b/prune/prune.go @@ -7,9 +7,9 @@ import ( "go.uber.org/zap" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -22,7 +22,7 @@ func WithLogger(logger *zap.Logger) Opt { } } -func New(db *sql.Database, safeDist uint32, activesetEpoch types.EpochID, opts ...Opt) *Pruner { +func New(db *statesql.Database, safeDist uint32, activesetEpoch types.EpochID, opts ...Opt) *Pruner { p := &Pruner{ logger: zap.NewNop(), db: db, @@ -37,7 +37,7 @@ func New(db *sql.Database, safeDist uint32, activesetEpoch types.EpochID, opts . type Pruner struct { logger *zap.Logger - db *sql.Database + db *statesql.Database safeDist uint32 activesetEpoch types.EpochID } diff --git a/prune/prune_test.go b/prune/prune_test.go index 91011f98f9..0eb4ee5e7e 100644 --- a/prune/prune_test.go +++ b/prune/prune_test.go @@ -11,13 +11,14 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) func TestPrune(t *testing.T) { types.SetLayersPerEpoch(3) - db := sql.InMemory() + db := statesql.InMemory() current := types.LayerID(10) lyrProps := make([]*types.Proposal, 0, current) diff --git a/sql/accounts/accounts_test.go b/sql/accounts/accounts_test.go index 21d34dbea2..4556ad332a 100644 --- a/sql/accounts/accounts_test.go +++ b/sql/accounts/accounts_test.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/builder" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func genSeq(address types.Address, n int) []*types.Account { @@ -21,7 +22,7 @@ func genSeq(address types.Address, n int) []*types.Account { func TestUpdate(t *testing.T) { address := types.Address{1, 2, 3} - db := sql.InMemory() + db := statesql.InMemory() seq := genSeq(address, 2) for _, update := range seq { require.NoError(t, Update(db, update)) @@ -34,7 +35,7 @@ func TestUpdate(t *testing.T) { func TestHas(t *testing.T) { address := types.Address{1, 2, 3} - db := sql.InMemory() + db := statesql.InMemory() has, err := Has(db, address) require.NoError(t, err) require.False(t, has) @@ -50,7 +51,7 @@ func TestHas(t *testing.T) { func TestRevert(t *testing.T) { address := types.Address{1, 1} seq := genSeq(address, 10) - db := sql.InMemory() + db := statesql.InMemory() for _, update := range seq { require.NoError(t, Update(db, update)) } @@ -62,7 +63,7 @@ func TestRevert(t *testing.T) { } func TestAll(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() addresses := []types.Address{{1, 1}, {2, 2}, {3, 3}} n := []int{10, 7, 20} for i, address := range addresses { @@ -81,7 +82,7 @@ func TestAll(t *testing.T) { } func TestSnapshot(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := Snapshot(db, types.LayerID(1)) require.ErrorIs(t, err, sql.ErrNotFound) @@ -108,7 +109,7 @@ func TestSnapshot(t *testing.T) { } func TestIterateAccountsOps(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() for i := 0; i < 100; i++ { addr := types.Address{} diff --git a/sql/activesets/activesets_test.go b/sql/activesets/activesets_test.go index 8828fa417e..8aa3c9193b 100644 --- a/sql/activesets/activesets_test.go +++ b/sql/activesets/activesets_test.go @@ -9,6 +9,7 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestActiveSet(t *testing.T) { @@ -19,7 +20,7 @@ func TestActiveSet(t *testing.T) { Epoch: 2, Set: []types.ATXID{{1}, {2}}, } - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, Add(db, ids[0], set)) require.ErrorIs(t, Add(db, ids[0], set), sql.ErrObjectExists) @@ -68,7 +69,7 @@ func TestCachedActiveSet(t *testing.T) { Epoch: 2, Set: []types.ATXID{{3}, {4}}, } - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) require.NoError(t, Add(db, ids[0], set0)) require.NoError(t, Add(db, ids[1], set1)) diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 9162ca17db..9a64954d82 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -16,6 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 5 @@ -28,7 +29,7 @@ func TestMain(m *testing.M) { } func TestGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxList := make([]*types.ActivationTx, 0) for i := 0; i < 3; i++ { @@ -54,7 +55,7 @@ func TestGet(t *testing.T) { } func TestAll(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxList := make([]*types.ActivationTx, 0) for i := 0; i < 3; i++ { @@ -76,7 +77,7 @@ func TestAll(t *testing.T) { } func TestHasID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atxList := make([]*types.ActivationTx, 0) for i := 0; i < 3; i++ { @@ -102,7 +103,7 @@ func TestHasID(t *testing.T) { } func TestGetFirstIDByNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -137,7 +138,7 @@ func TestGetFirstIDByNodeID(t *testing.T) { } func TestLatestN(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -228,7 +229,7 @@ func TestLatestN(t *testing.T) { } func TestGetByEpochAndNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -260,7 +261,7 @@ func TestGetByEpochAndNodeID(t *testing.T) { } func TestGetLastIDByNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -295,7 +296,7 @@ func TestGetLastIDByNodeID(t *testing.T) { } func TestGetIDByEpochAndNodeID(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig1, err := signing.NewEdSigner() require.NoError(t, err) @@ -339,7 +340,7 @@ func TestGetIDByEpochAndNodeID(t *testing.T) { } func TestGetIDsByEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() sig1, err := signing.NewEdSigner() @@ -375,7 +376,7 @@ func TestGetIDsByEpoch(t *testing.T) { } func TestGetIDsByEpochCached(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) ctx := context.Background() sig1, err := signing.NewEdSigner() @@ -449,7 +450,7 @@ func TestGetIDsByEpochCached(t *testing.T) { } func Test_IterateAtxsWithMalfeasance(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() e1 := types.EpochID(1) m := make(map[types.ATXID]bool) @@ -479,7 +480,7 @@ func Test_IterateAtxsWithMalfeasance(t *testing.T) { } func Test_IterateAtxIdsWithMalfeasance(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() e1 := types.EpochID(1) m := make(map[types.ATXID]bool) @@ -510,7 +511,7 @@ func Test_IterateAtxIdsWithMalfeasance(t *testing.T) { func TestVRFNonce(t *testing.T) { // Arrange - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -543,7 +544,7 @@ func TestVRFNonce(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() sig, err := signing.NewEdSigner() @@ -598,7 +599,7 @@ func TestLoadBlob(t *testing.T) { } func TestGetBlobCached(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) ctx := context.Background() sig, err := signing.NewEdSigner() @@ -620,7 +621,7 @@ func TestGetBlobCached(t *testing.T) { // Test that we don't put in the cache a reference to the blob that was passed to LoadBlob. // Each cache entry must use a unique slice for the blob. func TestGetBlobCached_CacheEntriesAreDistinct(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) atx := types.ActivationTx{AtxBlob: types.AtxBlob{Blob: []byte("original blob")}} atx.SetID(types.RandomATXID()) @@ -648,7 +649,7 @@ func TestGetBlobCached_CacheEntriesAreDistinct(t *testing.T) { } func TestCachedBlobEviction(t *testing.T) { - db := sql.InMemory( + db := statesql.InMemory( sql.WithQueryCache(true), sql.WithQueryCacheSizes(map[sql.QueryCacheKind]int{ atxs.CacheKindATXBlob: 10, @@ -689,7 +690,7 @@ func TestCachedBlobEviction(t *testing.T) { } func TestCheckpointATX(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() sig, err := signing.NewEdSigner() @@ -736,7 +737,7 @@ func TestCheckpointATX(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nonExistingATXID := types.ATXID(types.CalcHash32([]byte("0"))) _, err := atxs.Get(db, nonExistingATXID) @@ -809,7 +810,7 @@ type header struct { filteredOut bool } -func createAtx(tb testing.TB, db *sql.Database, hdr header) (types.ATXID, *signing.EdSigner) { +func createAtx(tb testing.TB, db *statesql.Database, hdr header) (types.ATXID, *signing.EdSigner) { sig, err := signing.NewEdSigner() require.NoError(tb, err) @@ -923,7 +924,7 @@ func TestGetIDWithMaxHeight(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() var sigs []*signing.EdSigner var ids []types.ATXID filtered := make(map[types.ATXID]struct{}) @@ -964,7 +965,7 @@ func TestLatest(t *testing.T) { {"out of order", []uint32{3, 4, 1, 2}, 4}, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() for i, epoch := range tc.epochs { full := &types.ActivationTx{ PublishEpoch: types.EpochID(epoch), @@ -983,7 +984,7 @@ func TestLatest(t *testing.T) { } func Test_PrevATXCollisions(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -1036,13 +1037,13 @@ func TestCoinbase(t *testing.T) { t.Parallel() t.Run("not found", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() _, err := atxs.Coinbase(db, types.NodeID{}) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("found", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) atx := newAtx(t, sig, func(a *types.ActivationTx) { a.Coinbase = types.Address{1, 2, 3} }) @@ -1053,7 +1054,7 @@ func TestCoinbase(t *testing.T) { }) t.Run("picks last", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) atx1 := newAtx(t, sig, withPublishEpoch(1), func(a *types.ActivationTx) { a.Coinbase = types.Address{1, 2, 3} }) diff --git a/sql/ballots/ballots_test.go b/sql/ballots/ballots_test.go index 624d648548..63826f4236 100644 --- a/sql/ballots/ballots_test.go +++ b/sql/ballots/ballots_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 3 @@ -26,7 +27,7 @@ func TestMain(m *testing.M) { } func TestLayer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() start := types.LayerID(1) pub := types.BytesToNodeID([]byte{1, 1, 1}) ballots := []types.Ballot{ @@ -65,7 +66,7 @@ func TestLayer(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nodeID := types.RandomNodeID() ballot := types.NewExistingBallot(types.BallotID{1}, types.RandomEdSignature(), nodeID, types.LayerID(0)) _, err := Get(db, ballot.ID()) @@ -85,7 +86,7 @@ func TestAdd(t *testing.T) { } func TestHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ballot := types.NewExistingBallot(types.BallotID{1}, types.EmptyEdSignature, types.EmptyNodeID, types.LayerID(0)) exists, err := Has(db, ballot.ID()) @@ -99,7 +100,7 @@ func TestHas(t *testing.T) { } func TestLatest(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() latest, err := LatestLayer(db) require.NoError(t, err) require.Equal(t, types.LayerID(0), latest) @@ -123,7 +124,7 @@ func TestLatest(t *testing.T) { } func TestLayerBallotBySmesher(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(1) nodeID1 := types.RandomNodeID() nodeID2 := types.RandomNodeID() @@ -158,7 +159,7 @@ func newAtx(signer *signing.EdSigner, layerID types.LayerID) *types.ActivationTx } func TestFirstInEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(layersPerEpoch * 2) sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -285,7 +286,7 @@ func TestAllFirstInEpoch(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() for _, ballot := range tc.ballots { require.NoError(t, Add(db, &ballot)) } @@ -301,7 +302,7 @@ func TestAllFirstInEpoch(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() ballot1 := types.NewExistingBallot( diff --git a/sql/beacons/beacons_test.go b/sql/beacons/beacons_test.go index 1d5648d30d..8a012e54ba 100644 --- a/sql/beacons/beacons_test.go +++ b/sql/beacons/beacons_test.go @@ -7,12 +7,13 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const baseEpoch = 3 func TestGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() beacons := []types.Beacon{ types.HexToBeacon("0x1"), @@ -35,7 +36,7 @@ func TestGet(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := Get(db, types.EpochID(baseEpoch)) require.ErrorIs(t, err, sql.ErrNotFound) @@ -50,7 +51,7 @@ func TestAdd(t *testing.T) { } func TestSet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := Get(db, types.EpochID(baseEpoch)) require.ErrorIs(t, err, sql.ErrNotFound) diff --git a/sql/blocks/blocks_test.go b/sql/blocks/blocks_test.go index 8b64d2ab65..038b542caa 100644 --- a/sql/blocks/blocks_test.go +++ b/sql/blocks/blocks_test.go @@ -10,10 +10,11 @@ import ( "github.com/spacemeshos/go-spacemesh/codec" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestAddGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() block := types.NewExistingBlock( types.BlockID{1, 1}, types.InnerBlock{LayerIndex: types.LayerID(1)}, @@ -26,7 +27,7 @@ func TestAddGet(t *testing.T) { } func TestAlreadyExists(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() block := types.NewExistingBlock( types.BlockID{1}, types.InnerBlock{}, @@ -36,7 +37,7 @@ func TestAlreadyExists(t *testing.T) { } func TestHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() block := types.NewExistingBlock( types.BlockID{1}, types.InnerBlock{}, @@ -52,7 +53,7 @@ func TestHas(t *testing.T) { } func TestValidity(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -86,7 +87,7 @@ func TestValidity(t *testing.T) { } func TestLayerFilter(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() start := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -122,7 +123,7 @@ func TestLayerFilter(t *testing.T) { } func TestLayerOrdered(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() start := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -153,7 +154,7 @@ func TestLayerOrdered(t *testing.T) { } func TestContextualValidity(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(1) blocks := []*types.Block{ types.NewExistingBlock( @@ -197,7 +198,7 @@ func TestContextualValidity(t *testing.T) { } func TestGetLayer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid1 := types.LayerID(11) block1 := types.NewExistingBlock( types.BlockID{1, 1}, @@ -222,12 +223,12 @@ func TestGetLayer(t *testing.T) { func TestLastValid(t *testing.T) { t.Run("empty", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := LastValid(db) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("all valid", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() blocks := map[types.BlockID]struct { lid types.LayerID }{ @@ -248,7 +249,7 @@ func TestLastValid(t *testing.T) { require.Equal(t, 33, int(last)) }) t.Run("last is invalid", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() blocks := map[types.BlockID]struct { invalid bool lid types.LayerID @@ -274,7 +275,7 @@ func TestLastValid(t *testing.T) { } func TestLoadBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() lid1 := types.LayerID(11) @@ -315,7 +316,7 @@ func TestLoadBlob(t *testing.T) { } func TestLayerForMangledBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := db.Exec("insert into blocks (id, layer, block) values (?1, ?2, ?3);", func(stmt *sql.Statement) { stmt.BindBytes(1, []byte(`mangled-block-id`)) diff --git a/sql/certificates/certs_test.go b/sql/certificates/certs_test.go index cfe4ff9104..2056f6b091 100644 --- a/sql/certificates/certs_test.go +++ b/sql/certificates/certs_test.go @@ -8,6 +8,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 5 @@ -44,7 +45,7 @@ func makeCert(lid types.LayerID, bid types.BlockID) *types.Certificate { } func TestCertificates(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) got, err := Get(db, lid) @@ -94,7 +95,7 @@ func TestCertificates(t *testing.T) { } func TestHareOutput(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) ho, err := GetHareOutput(db, lid) @@ -132,7 +133,7 @@ func TestHareOutput(t *testing.T) { } func TestCertifiedBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lyrBlocks := map[types.LayerID]types.BlockID{ types.LayerID(layersPerEpoch - 1): {1}, // epoch 0 @@ -161,7 +162,7 @@ func TestCertifiedBlock(t *testing.T) { } func TestDeleteCert(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, Add(db, types.LayerID(2), &types.Certificate{BlockID: types.BlockID{2}})) require.NoError(t, Add(db, types.LayerID(3), &types.Certificate{BlockID: types.BlockID{3}})) require.NoError(t, Add(db, types.LayerID(4), &types.Certificate{BlockID: types.BlockID{4}})) @@ -177,7 +178,7 @@ func TestDeleteCert(t *testing.T) { } func TestFirstInEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lyrBlocks := map[types.LayerID]types.BlockID{ types.LayerID(layersPerEpoch - 1): {1}, // epoch 0 diff --git a/sql/database.go b/sql/database.go index 799a19293f..8de21300d6 100644 --- a/sql/database.go +++ b/sql/database.go @@ -5,8 +5,6 @@ import ( "errors" "fmt" "maps" - "slices" - "sort" "strings" "sync" "sync/atomic" @@ -30,6 +28,9 @@ var ( ErrObjectExists = errors.New("database: object exists") // ErrTooNew is returned if database version is newer than expected. ErrTooNew = errors.New("database version is too new") + // ErrOld is returned when the database version differs from the expected one + // and migrations are disabled. + ErrOld = errors.New("old database version") ) const ( @@ -61,29 +62,25 @@ type Encoder func(*Statement) type Decoder func(*Statement) bool func defaultConf() *conf { - migrations, err := StateMigrations() - if err != nil { - panic(err) - } - return &conf{ - connections: 16, - migrations: migrations, - skipMigration: map[int]struct{}{}, - logger: zap.NewNop(), + enableMigrations: true, + connections: 16, + logger: zap.NewNop(), + schema: &Schema{}, } } type conf struct { - flags sqlite.OpenFlags - connections int - skipMigration map[int]struct{} - vacuumState int - migrations []Migration - enableLatency bool - cache bool - cacheSizes map[QueryCacheKind]int - logger *zap.Logger + enableMigrations bool + forceFresh bool + forceMigrations bool + connections int + vacuumState int + enableLatency bool + cache bool + cacheSizes map[QueryCacheKind]int + logger *zap.Logger + schema *Schema } // WithConnections overwrites number of pooled connections. @@ -93,48 +90,18 @@ func WithConnections(n int) Opt { } } +// WithLogger specifies logger for the database. func WithLogger(logger *zap.Logger) Opt { return func(c *conf) { c.logger = logger } } -// WithMigrations overwrites embedded migrations. -// Migrations are sorted by order before applying. -func WithMigrations(migrations []Migration) Opt { - return func(c *conf) { - sort.Slice(migrations, func(i, j int) bool { - return migrations[i].Order() < migrations[j].Order() - }) - c.migrations = migrations - } -} - -// WithMigration adds migration to the list of migrations. -// It will overwrite an existing migration with the same order. -func WithMigration(migration Migration) Opt { +// WithEnableMigrations enables or disables migrations on the database. +// The migrations are enabled by default. +func WithEnableMigrations(enable bool) Opt { return func(c *conf) { - for i, m := range c.migrations { - if m.Order() == migration.Order() { - c.migrations[i] = migration - return - } - if m.Order() > migration.Order() { - c.migrations = slices.Insert(c.migrations, i, migration) - return - } - } - c.migrations = append(c.migrations, migration) - } -} - -// WithSkipMigrations will update database version with executing associated migrations. -// It should be used at your own risk. -func WithSkipMigrations(i ...int) Opt { - return func(c *conf) { - for _, index := range i { - c.skipMigration[index] = struct{}{} - } + c.enableMigrations = enable } } @@ -172,12 +139,33 @@ func WithQueryCacheSizes(sizes map[QueryCacheKind]int) Opt { } } +// WithForceMigrations forces database to run all the migrations instead +// of using a schema snapshot in case of a fresh database. +func WithForceMigrations(force bool) Opt { + return func(c *conf) { + c.forceMigrations = true + } +} + +// WithSchema specifies database schema script. +func WithDatabaseSchema(schema *Schema) Opt { + return func(c *conf) { + c.schema = schema + } +} + +func withForceFresh(fresh bool) Opt { + return func(c *conf) { + c.forceFresh = fresh + } +} + // Opt for configuring database. type Opt func(c *conf) // InMemory database for testing. func InMemory(opts ...Opt) *Database { - opts = append(opts, WithConnections(1)) + opts = append(opts, WithConnections(1), withForceFresh(true)) db, err := Open("file::memory:?mode=memory", opts...) if err != nil { panic(err) @@ -195,75 +183,57 @@ func Open(uri string, opts ...Opt) (*Database, error) { for _, opt := range opts { opt(config) } - pool, err := sqlitex.Open(uri, config.flags, config.connections) + var flags sqlite.OpenFlags + if !config.forceFresh { + flags = sqlite.SQLITE_OPEN_READWRITE | + sqlite.SQLITE_OPEN_WAL | + sqlite.SQLITE_OPEN_URI | + sqlite.SQLITE_OPEN_NOMUTEX + } + freshDB := config.forceFresh + pool, err := sqlitex.Open(uri, flags, config.connections) if err != nil { - return nil, fmt.Errorf("open db %s: %w", uri, err) + if config.forceFresh || sqlite.ErrCode(err) != sqlite.SQLITE_CANTOPEN { + return nil, fmt.Errorf("open db %s: %w", uri, err) + } + flags |= sqlite.SQLITE_OPEN_CREATE + freshDB = true + pool, err = sqlitex.Open(uri, flags, config.connections) + if err != nil { + return nil, fmt.Errorf("create db %s: %w", uri, err) + } } db := &Database{pool: pool} if config.enableLatency { db.latency = newQueryLatency() } - //nolint:nestif - if config.migrations != nil { - before, err := version(db) - if err != nil { - return nil, err - } - after := 0 - if len(config.migrations) > 0 { - after = config.migrations[len(config.migrations)-1].Order() + if freshDB && !config.forceMigrations { + if err := config.schema.Apply(db); err != nil { + return nil, errors.Join( + fmt.Errorf("error running schema script: %w", err), + db.Close()) } - if before > after { - pool.Close() - config.logger.Error("database version is newer than expected - downgrade is not supported", - zap.String("uri", uri), - zap.Int("current version", before), - zap.Int("target version", after), - ) - return nil, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) + } else { + if err := config.schema.Migrate( + config.logger.With(zap.String("uri", uri)), + db, config.vacuumState, config.enableMigrations, + ); err != nil { + return nil, errors.Join(err, db.Close()) } - config.logger.Info("running migrations", + } + + loaded, err := LoadDBSchemaScript(db) + if err != nil { + return nil, fmt.Errorf("error loading database schema: %w", err) + } + diff := config.schema.Diff(loaded) + if diff != "" { + config.logger.Warn("database schema drift detected", zap.String("uri", uri), - zap.Int("current version", before), - zap.Int("target version", after), + zap.String("diff", diff), ) - for i, m := range config.migrations { - if m.Order() <= before { - continue - } - if err := db.WithTx(context.Background(), func(tx *Tx) error { - if _, ok := config.skipMigration[m.Order()]; !ok { - if err := m.Apply(tx); err != nil { - for j := i; j >= 0 && config.migrations[j].Order() > before; j-- { - if e := config.migrations[j].Rollback(); e != nil { - err = errors.Join(err, fmt.Errorf("rollback %s: %w", m.Name(), e)) - break - } - } - - return fmt.Errorf("apply %s: %w", m.Name(), err) - } - } - // version is set intentionally even if actual migration was skipped - if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.Order()), nil, nil); err != nil { - return fmt.Errorf("update user_version to %d: %w", m.Order(), err) - } - return nil - }); err != nil { - err = errors.Join(err, db.Close()) - return nil, err - } - - if config.vacuumState != 0 && before <= config.vacuumState { - if err := Vacuum(db); err != nil { - err = errors.Join(err, db.Close()) - return nil, err - } - } - before = m.Order() - } - } + if config.cache { config.logger.Debug("using query cache", zap.Any("sizes", config.cacheSizes)) db.queryCache = &queryCache{cacheSizesByKind: config.cacheSizes} diff --git a/sql/database_test.go b/sql/database_test.go index ab7d71438a..04b559e7db 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -12,26 +12,16 @@ import ( ) func Test_Transaction_Isolation(t *testing.T) { - ctrl := gomock.NewController(t) - testMigration := NewMockMigration(ctrl) - testMigration.EXPECT().Name().Return("test").AnyTimes() - testMigration.EXPECT().Order().Return(1).AnyTimes() - testMigration.EXPECT().Apply(gomock.Any()).DoAndReturn(func(e Executor) error { - if _, err := e.Exec(`create table testing1 ( - id varchar primary key, - field int - )`, nil, nil); err != nil { - return err - } - return nil - }) - db := InMemory( - WithMigrations([]Migration{testMigration}), WithConnections(10), WithLatencyMetering(true), + WithDatabaseSchema(&Schema{ + Script: `create table testing1 ( + id varchar primary key, + field int + );`, + }), ) - tx, err := db.Tx(context.Background()) require.NoError(t, err) @@ -74,7 +64,10 @@ func Test_Migration_Rollback(t *testing.T) { dbFile := filepath.Join(t.TempDir(), "test.sql") _, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), + WithForceMigrations(true), ) require.ErrorContains(t, err, "migration 2 failed") } @@ -88,7 +81,10 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { dbFile := filepath.Join(t.TempDir(), "test.sql") db, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1}), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -100,7 +96,9 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { migration2.EXPECT().Rollback().Return(nil) _, err = Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), ) require.ErrorContains(t, err, "migration 2 failed") } @@ -115,10 +113,14 @@ func TestDatabaseSkipMigrations(t *testing.T) { migration2.EXPECT().Order().Return(2).AnyTimes() migration2.EXPECT().Apply(gomock.Any()).Return(nil) + schema := &Schema{ + Migrations: MigrationList{migration1, migration2}, + } + schema.SkipMigrations(1) dbFile := filepath.Join(t.TempDir(), "test.sql") db, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), - WithSkipMigrations(1), + WithDatabaseSchema(schema), + WithForceMigrations(true), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -138,13 +140,18 @@ func TestDatabaseVacuumState(t *testing.T) { dbFile := filepath.Join(dir, "test.sql") db, err := Open("file:"+dbFile, - WithMigrations([]Migration{migration1}), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), ) require.NoError(t, err) require.NoError(t, db.Close()) db, err = Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), WithVacuumState(2), ) require.NoError(t, err) @@ -180,14 +187,16 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { migration2.EXPECT().Order().Return(2).AnyTimes() dbFile := filepath.Join(dir, "test.sql") - db, err := Open("file:" + dbFile) + db, err := Open("file:"+dbFile, WithForceMigrations(true)) require.NoError(t, err) _, err = db.Exec("PRAGMA user_version = 3", nil, nil) require.NoError(t, err) require.NoError(t, db.Close()) _, err = Open("file:"+dbFile, - WithMigrations([]Migration{migration1, migration2}), + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), ) require.ErrorIs(t, err, ErrTooNew) } diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index 58f5b4df4d..f5c718e05d 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -11,10 +11,11 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestMalicious(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() nodeID := types.NodeID{1, 1, 1, 1} mal, err := IsMalicious(db, nodeID) @@ -56,7 +57,7 @@ func TestMalicious(t *testing.T) { } func Test_GetMalicious(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() got, err := GetMalicious(db) require.NoError(t, err) require.Nil(t, got) @@ -74,7 +75,7 @@ func Test_GetMalicious(t *testing.T) { } func TestLoadMalfeasanceBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() nid1 := types.RandomNodeID() diff --git a/sql/layers/layers_test.go b/sql/layers/layers_test.go index f16a5e40d7..70f1f57bf1 100644 --- a/sql/layers/layers_test.go +++ b/sql/layers/layers_test.go @@ -8,6 +8,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const layersPerEpoch = 4 @@ -20,7 +21,7 @@ func TestMain(m *testing.M) { } func TestWeakCoin(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) _, err := GetWeakCoin(db, lid) @@ -38,7 +39,7 @@ func TestWeakCoin(t *testing.T) { } func TestAppliedBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) _, err := GetApplied(db, lid) @@ -67,7 +68,7 @@ func TestAppliedBlock(t *testing.T) { } func TestFirstAppliedInEpoch(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() blks := map[types.LayerID]types.BlockID{ types.EpochID(1).FirstLayer(): {1}, types.EpochID(2).FirstLayer(): types.EmptyBlockID, @@ -107,7 +108,7 @@ func TestFirstAppliedInEpoch(t *testing.T) { } func TestUnsetAppliedFrom(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid := types.LayerID(10) last := lid.Add(99) for i := lid; !i.After(last); i = i.Add(1) { @@ -123,7 +124,7 @@ func TestUnsetAppliedFrom(t *testing.T) { } func TestStateHash(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() layers := []uint32{9, 10, 8, 7} hashes := []types.Hash32{{1}, {2}, {3}, {4}} for i := range layers { @@ -147,7 +148,7 @@ func TestStateHash(t *testing.T) { } func TestSetHashes(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() _, err := GetAggregatedHash(db, types.LayerID(11)) require.ErrorIs(t, err, sql.ErrNotFound) @@ -178,7 +179,7 @@ func TestSetHashes(t *testing.T) { } func TestProcessed(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() lid, err := GetProcessed(db) require.NoError(t, err) require.Equal(t, types.LayerID(0), lid) @@ -193,7 +194,7 @@ func TestProcessed(t *testing.T) { } func TestGetAggHashes(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() hashes := make(map[types.LayerID]types.Hash32) diff --git a/sql/localsql/local.go b/sql/localsql/local.go deleted file mode 100644 index cbf1d9f2f4..0000000000 --- a/sql/localsql/local.go +++ /dev/null @@ -1,38 +0,0 @@ -package localsql - -import "github.com/spacemeshos/go-spacemesh/sql" - -type Database struct { - *sql.Database -} - -func Open(uri string, opts ...sql.Opt) (*Database, error) { - migrations, err := sql.LocalMigrations() - if err != nil { - return nil, err - } - defaultOpts := []sql.Opt{ - sql.WithConnections(16), - sql.WithMigrations(migrations), - } - opts = append(defaultOpts, opts...) - db, err := sql.Open(uri, opts...) - if err != nil { - return nil, err - } - return &Database{Database: db}, nil -} - -func InMemory(opts ...sql.Opt) *Database { - migrations, err := sql.LocalMigrations() - if err != nil { - panic(err) - } - defaultOpts := []sql.Opt{ - sql.WithConnections(1), - sql.WithMigrations(migrations), - } - opts = append(defaultOpts, opts...) - db := sql.InMemory(opts...) - return &Database{Database: db} -} diff --git a/sql/localsql/local_test.go b/sql/localsql/local_test.go deleted file mode 100644 index 1fed142b1a..0000000000 --- a/sql/localsql/local_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package localsql - -import ( - "path/filepath" - "strings" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/spacemeshos/go-spacemesh/sql" -) - -func TestDatabase_MigrateTwice_NoOp(t *testing.T) { - file := filepath.Join(t.TempDir(), "test.db") - db, err := Open("file:" + file) - require.NoError(t, err) - - var sqls1 []string - _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - sqls1 = append(sqls1, sql) - return true - }) - require.NoError(t, err) - require.NoError(t, db.Close()) - - db, err = Open("file:" + file) - require.NoError(t, err) - var sqls2 []string - _, err = db.Exec("SELECT sql FROM sqlite_schema;", nil, func(stmt *sql.Statement) bool { - sql := stmt.ColumnText(0) - sql = strings.Join(strings.Fields(sql), " ") // remove whitespace - sqls2 = append(sqls2, sql) - return true - }) - require.NoError(t, err) - require.NoError(t, db.Close()) - - require.Equal(t, sqls1, sqls2) -} diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go new file mode 100644 index 0000000000..88f41f2a17 --- /dev/null +++ b/sql/localsql/localsql.go @@ -0,0 +1,59 @@ +package localsql + +import ( + "embed" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +//go:embed schema/schema.sql schema/migrations/*.sql +var embedded embed.FS + +// Database represents a local database. +type Database struct { + *sql.Database +} + +// Schema returns the schema for the local database. +func Schema() (*sql.Schema, error) { + migrations, err := sql.LoadSQLMigrations(embedded) + if err != nil { + return nil, err + } + // NOTE: coded state migrations can be added here + // They can be a part of this localsql package + return sql.LoadSchema(embedded, migrations) +} + +// Open opens a local database. +func Open(uri string, opts ...sql.Opt) (*Database, error) { + schema, err := Schema() + if err != nil { + return nil, err + } + defaultOpts := []sql.Opt{ + sql.WithConnections(16), + sql.WithDatabaseSchema(schema), + } + opts = append(defaultOpts, opts...) + db, err := sql.Open(uri, opts...) + if err != nil { + return nil, err + } + return &Database{Database: db}, nil +} + +// Open opens an in-memory local database. +func InMemory(opts ...sql.Opt) *Database { + schema, err := Schema() + if err != nil { + panic(err) + } + defaultOpts := []sql.Opt{ + sql.WithConnections(1), + sql.WithDatabaseSchema(schema), + } + opts = append(defaultOpts, opts...) + db := sql.InMemory(opts...) + return &Database{Database: db} +} diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go new file mode 100644 index 0000000000..a6dbf958bb --- /dev/null +++ b/sql/localsql/localsql_test.go @@ -0,0 +1,74 @@ +package localsql + +import ( + "path/filepath" + "slices" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +func TestDatabase_MigrateTwice_NoOp(t *testing.T) { + file := filepath.Join(t.TempDir(), "test.db") + db, err := Open("file:"+file, sql.WithForceMigrations(true)) + require.NoError(t, err) + + sql1, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + require.NoError(t, db.Close()) + + db, err = Open("file:" + file) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + + require.Equal(t, sql1, sql2) + + var version int + _, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool { + version = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + require.NoError(t, err) + schema, err := Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), version) + + require.NoError(t, db.Close()) +} + +func TestSchema(t *testing.T) { + for _, tc := range []struct { + name string + forceMigrations bool + }{ + {name: "no migrations", forceMigrations: false}, + {name: "force migrations", forceMigrations: true}, + } { + t.Run(tc.name, func(t *testing.T) { + db := InMemory(sql.WithForceMigrations(tc.forceMigrations)) + loadedScript, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + expSchema, err := Schema() + require.NoError(t, err) + diff := expSchema.Diff(loadedScript) + if diff != "" { + s := &sql.Schema{ + Script: loadedScript, + } + require.NoError(t, s.WriteToFile(".")) + t.Logf("updated schema written to %s", sql.UpdatedSchemaPath) + } + require.Empty(t, diff, "local schema diff") + }) + } +} diff --git a/sql/migrations/local/0001_initial.sql b/sql/localsql/schema/migrations/0001_initial.sql similarity index 100% rename from sql/migrations/local/0001_initial.sql rename to sql/localsql/schema/migrations/0001_initial.sql diff --git a/sql/migrations/local/0002_extend_initial_post.sql b/sql/localsql/schema/migrations/0002_extend_initial_post.sql similarity index 100% rename from sql/migrations/local/0002_extend_initial_post.sql rename to sql/localsql/schema/migrations/0002_extend_initial_post.sql diff --git a/sql/migrations/local/0003_add_nipost_builder_state.sql b/sql/localsql/schema/migrations/0003_add_nipost_builder_state.sql similarity index 100% rename from sql/migrations/local/0003_add_nipost_builder_state.sql rename to sql/localsql/schema/migrations/0003_add_nipost_builder_state.sql diff --git a/sql/migrations/local/0004_atx_sync.sql b/sql/localsql/schema/migrations/0004_atx_sync.sql similarity index 100% rename from sql/migrations/local/0004_atx_sync.sql rename to sql/localsql/schema/migrations/0004_atx_sync.sql diff --git a/sql/migrations/local/0005_fast_startup.sql b/sql/localsql/schema/migrations/0005_fast_startup.sql similarity index 100% rename from sql/migrations/local/0005_fast_startup.sql rename to sql/localsql/schema/migrations/0005_fast_startup.sql diff --git a/sql/migrations/local/0006_prepared_activeset.sql b/sql/localsql/schema/migrations/0006_prepared_activeset.sql similarity index 100% rename from sql/migrations/local/0006_prepared_activeset.sql rename to sql/localsql/schema/migrations/0006_prepared_activeset.sql diff --git a/sql/migrations/local/0007_malfeasance_sync.sql b/sql/localsql/schema/migrations/0007_malfeasance_sync.sql similarity index 100% rename from sql/migrations/local/0007_malfeasance_sync.sql rename to sql/localsql/schema/migrations/0007_malfeasance_sync.sql diff --git a/sql/migrations/local/0008_next.sql b/sql/localsql/schema/migrations/0008_next.sql similarity index 100% rename from sql/migrations/local/0008_next.sql rename to sql/localsql/schema/migrations/0008_next.sql diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql new file mode 100755 index 0000000000..bcc8ca05c8 --- /dev/null +++ b/sql/localsql/schema/schema.sql @@ -0,0 +1,82 @@ +PRAGMA user_version = 8; +CREATE TABLE atx_sync_requests +( + epoch INT NOT NULL, + timestamp INT NOT NULL, total INTEGER, downloaded INTEGER, + PRIMARY KEY (epoch) +) WITHOUT ROWID; +CREATE TABLE atx_sync_state +( + epoch INT NOT NULL, + id CHAR(32) NOT NULL, + requests INT NOT NULL DEFAULT 0, + PRIMARY KEY (epoch, id) +) WITHOUT ROWID; +CREATE TABLE "challenge" +( + id CHAR(32) PRIMARY KEY, + epoch UNSIGNED INT NOT NULL, + sequence UNSIGNED INT NOT NULL, + prev_atx CHAR(32) NOT NULL, + pos_atx CHAR(32) NOT NULL, + commit_atx CHAR(32), + post_nonce UNSIGNED INT, + post_indices VARCHAR, + post_pow UNSIGNED LONG INT +, poet_proof_ref CHAR(32), poet_proof_membership VARCHAR) WITHOUT ROWID; +CREATE TABLE malfeasance_sync_state +( + timestamp INT NOT NULL +); +CREATE TABLE nipost +( + id CHAR(32) PRIMARY KEY, + post_nonce UNSIGNED INT NOT NULL, + post_indices VARCHAR NOT NULL, + post_pow UNSIGNED LONG INT NOT NULL, + + num_units UNSIGNED INT NOT NULL, + vrf_nonce UNSIGNED LONG INT NOT NULL, + + poet_proof_membership VARCHAR NOT NULL, + poet_proof_ref CHAR(32) NOT NULL, + labels_per_unit UNSIGNED INT NOT NULL +) WITHOUT ROWID; +CREATE TABLE poet_certificates +( + node_id BLOB NOT NULL, + certifier_id BLOB NOT NULL, + certificate BLOB NOT NULL, + signature BLOB NOT NULL +); +CREATE UNIQUE INDEX idx_poet_certificates ON poet_certificates (node_id, certifier_id); +CREATE TABLE poet_registration +( + id CHAR(32) NOT NULL, + hash CHAR(32) NOT NULL, + address VARCHAR NOT NULL, + round_id VARCHAR NOT NULL, + round_end INT NOT NULL, + + PRIMARY KEY (id, address) +) WITHOUT ROWID; +CREATE TABLE "post" +( + id CHAR(32) PRIMARY KEY, + post_nonce UNSIGNED INT NOT NULL, + post_indices VARCHAR NOT NULL, + post_pow UNSIGNED LONG INT NOT NULL, + + num_units UNSIGNED INT NOT NULL, + commit_atx CHAR(32) NOT NULL, + vrf_nonce UNSIGNED LONG INT NOT NULL +, challenge BLOB NOT NULL DEFAULT x'0000000000000000000000000000000000000000000000000000000000000000'); +CREATE TABLE prepared_activeset +( + kind UNSIGNED INT NOT NULL, + epoch UNSIGNED INT NOT NULL, + id CHAR(32) NOT NULL, + weight UNSIGNED INT NOT NULL, + data BLOB NOT NULL, + PRIMARY KEY (kind, epoch) +) WITHOUT ROWID; diff --git a/sql/metrics/prometheus.go b/sql/metrics/prometheus.go index 6a7ba2ef5f..acab0b4532 100644 --- a/sql/metrics/prometheus.go +++ b/sql/metrics/prometheus.go @@ -11,6 +11,7 @@ import ( "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/metrics" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -22,7 +23,7 @@ const ( type DBMetricsCollector struct { logger log.Logger checkInterval time.Duration - db *sql.Database + db *statesql.Database tablesList map[string]struct{} eg errgroup.Group cancel context.CancelFunc @@ -35,7 +36,7 @@ type DBMetricsCollector struct { // NewDBMetricsCollector creates new DBMetricsCollector. func NewDBMetricsCollector( ctx context.Context, - db *sql.Database, + db *statesql.Database, logger log.Logger, checkInterval time.Duration, ) *DBMetricsCollector { diff --git a/sql/migrations.go b/sql/migrations.go index 3b9a019e53..2565e295ce 100644 --- a/sql/migrations.go +++ b/sql/migrations.go @@ -3,15 +3,40 @@ package sql import ( "bufio" "bytes" - "embed" "fmt" "io/fs" + "slices" "strconv" "strings" ) -//go:embed migrations/**/*.sql -var embedded embed.FS +// MigrationList denotes a list of migrations. +type MigrationList []Migration + +// AddMigration adds a Migration to the MigrationList, overriding the migration with the +// same order number if it already exists. The function returns updated migration list. +// The state of the original migration list is undefined after calling this function. +func (l MigrationList) AddMigration(migration Migration) MigrationList { + for i, m := range l { + if m.Order() == migration.Order() { + l[i] = migration + return l + } + if m.Order() > migration.Order() { + l = slices.Insert(l, i, migration) + return l + } + } + return append(l, migration) +} + +// Version returns database version for the specified migration list. +func (l MigrationList) Version() int { + if len(l) == 0 { + return 0 + } + return l[len(l)-1].Order() +} type sqlMigration struct { order int @@ -65,18 +90,9 @@ func version(db Executor) (int, error) { return current, nil } -func StateMigrations() ([]Migration, error) { - return sqlMigrations("state") -} - -func LocalMigrations() ([]Migration, error) { - return sqlMigrations("local") -} - -func sqlMigrations(dbname string) ([]Migration, error) { - var migrations []Migration - root := fmt.Sprintf("migrations/%s", dbname) - err := fs.WalkDir(embedded, root, func(path string, d fs.DirEntry, err error) error { +func LoadSQLMigrations(fsys fs.FS) (MigrationList, error) { + var migrations MigrationList + err := fs.WalkDir(fsys, "schema/migrations", func(path string, d fs.DirEntry, err error) error { if err != nil { return fmt.Errorf("walkdir %s: %w", path, err) } @@ -91,7 +107,7 @@ func sqlMigrations(dbname string) ([]Migration, error) { if err != nil { return fmt.Errorf("invalid migration %s: %w", d.Name(), err) } - f, err := embedded.Open(path) + f, err := fsys.Open(path) if err != nil { return fmt.Errorf("read file %s: %w", path, err) } diff --git a/sql/migrations_test.go b/sql/migrations_test.go deleted file mode 100644 index 13bfb86b80..0000000000 --- a/sql/migrations_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package sql - -import ( - "slices" - "testing" - - "github.com/stretchr/testify/require" -) - -func Test_MigrationsAppliedOnce(t *testing.T) { - db := InMemory() - - var version int - _, err := db.Exec("PRAGMA user_version;", nil, func(stmt *Statement) bool { - version = stmt.ColumnInt(0) - return true - }) - require.NoError(t, err) - - migrations, err := StateMigrations() - require.NoError(t, err) - expectedVersion := slices.MaxFunc(migrations, func(a, b Migration) int { - return a.Order() - b.Order() - }) - require.Equal(t, expectedVersion.Order(), version) -} diff --git a/sql/poets/poets_test.go b/sql/poets/poets_test.go index 9545f0075b..be7d88e192 100644 --- a/sql/poets/poets_test.go +++ b/sql/poets/poets_test.go @@ -8,10 +8,11 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() refs := []types.PoetProofRef{ {0xca, 0xfe}, @@ -51,7 +52,7 @@ func TestHas(t *testing.T) { } func TestGet(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() refs := []types.PoetProofRef{ {0xca, 0xfe}, @@ -102,7 +103,7 @@ func TestGet(t *testing.T) { } func TestAdd(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ref := types.PoetProofRef{0xca, 0xfe} poet := []byte("proof0") @@ -121,7 +122,7 @@ func TestAdd(t *testing.T) { } func TestGetRef(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sids := [][]byte{ []byte("sid1"), diff --git a/sql/recovery/recovery_test.go b/sql/recovery/recovery_test.go index 976091e077..7f4e396925 100644 --- a/sql/recovery/recovery_test.go +++ b/sql/recovery/recovery_test.go @@ -6,12 +6,12 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/recovery" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestRecoveryInfo(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() restore := types.LayerID(12) got, err := recovery.CheckpointInfo(db) diff --git a/sql/rewards/rewards_test.go b/sql/rewards/rewards_test.go index 5cd5715a53..3cf7ef11a4 100644 --- a/sql/rewards/rewards_test.go +++ b/sql/rewards/rewards_test.go @@ -9,10 +9,11 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestRewards(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() var part uint64 = math.MaxUint64 / 2 lyrReward := part / 2 @@ -199,13 +200,17 @@ func TestRewards(t *testing.T) { } func Test_0008Migration_EmptyDBIsNoOp(t *testing.T) { - migrations, err := sql.StateMigrations() + schema, err := statesql.Schema() require.NoError(t, err) - sort.Slice(migrations, func(i, j int) bool { return migrations[i].Order() < migrations[j].Order() }) + sort.Slice(schema.Migrations, func(i, j int) bool { + return schema.Migrations[i].Order() < schema.Migrations[j].Order() + }) + origMigrations := schema.Migrations + schema.Migrations = schema.Migrations[:7] // apply previous migrations - db := sql.InMemory( - sql.WithMigrations(migrations[:7]), + db := statesql.InMemory( + sql.WithDatabaseSchema(schema), ) // verify that the DB is empty @@ -217,7 +222,7 @@ func Test_0008Migration_EmptyDBIsNoOp(t *testing.T) { require.NoError(t, err) // apply the migration - err = migrations[7].Apply(db) + err = origMigrations[7].Apply(db) require.NoError(t, err) // verify that db is still empty @@ -230,13 +235,17 @@ func Test_0008Migration_EmptyDBIsNoOp(t *testing.T) { } func Test_0008Migration(t *testing.T) { - migrations, err := sql.StateMigrations() + schema, err := statesql.Schema() require.NoError(t, err) - sort.Slice(migrations, func(i, j int) bool { return migrations[i].Order() < migrations[j].Order() }) + sort.Slice(schema.Migrations, func(i, j int) bool { + return schema.Migrations[i].Order() < schema.Migrations[j].Order() + }) + origMigrations := schema.Migrations + schema.Migrations = schema.Migrations[:7] // apply previous migrations - db := sql.InMemory( - sql.WithMigrations(migrations[:7]), + db := statesql.InMemory( + sql.WithDatabaseSchema(schema), ) // verify that the DB is empty @@ -279,7 +288,7 @@ func Test_0008Migration(t *testing.T) { require.NoError(t, err) // apply the migration - err = migrations[7].Apply(db) + err = origMigrations[7].Apply(db) require.NoError(t, err) // verify that one row is still present diff --git a/sql/schema.go b/sql/schema.go new file mode 100644 index 0000000000..4f495c9eb7 --- /dev/null +++ b/sql/schema.go @@ -0,0 +1,188 @@ +package sql + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + + godiffpatch "github.com/sourcegraph/go-diff-patch" + "go.uber.org/zap" +) + +const ( + SchemaPath = "schema/schema.sql" + UpdatedSchemaPath = "schema/schema.sql.updated" +) + +// LoadSchema loads the schema embedded in the executable. +func LoadSchema(fsys fs.FS, migrations []Migration) (*Schema, error) { + text, err := fs.ReadFile(fsys, SchemaPath) + if err != nil { + return nil, fmt.Errorf("error reading schema file %s: %w", SchemaPath, err) + } + return &Schema{Script: string(text), Migrations: migrations}, nil +} + +// LoadDBSchemaScript retrieves the database schema as text. +func LoadDBSchemaScript(db Executor) (string, error) { + var sb strings.Builder + version, err := version(db) + if err != nil { + return "", err + } + fmt.Fprintf(&sb, "PRAGMA user_version = %d;\n", version) + if _, err = db.Exec( + // Type is either 'index' or 'table', we want tables + // to go first. Also, we ignore _litestream tables + `select sql || ';' from sqlite_master + where sql is not null and tbl_name not like '_litestream%' + order by tbl_name, type desc, name`, + nil, func(st *Statement) bool { + fmt.Fprintln(&sb, st.ColumnText(0)) + return true + }); err != nil { + return "", err + } + return sb.String(), nil +} + +// Schema represents database schema. +type Schema struct { + Script string + Migrations MigrationList + skipMigration map[int]struct{} +} + +// Diff diffs the database schema against the actual schema. +// If there's no differences, it returns an empty string. +func (s *Schema) Diff(actualScript string) string { + if s.Script == actualScript { + return "" + } + diff := godiffpatch.GeneratePatch(SchemaPath, s.Script, actualScript) + if diff == "" { + return "" + } + return diff +} + +// WriteToFile writes the schema to the corresponding updated schema file. +func (s *Schema) WriteToFile(basedir string) error { + path := filepath.Join(basedir, UpdatedSchemaPath) + if err := os.WriteFile(path, []byte(s.Script), 0o777); err != nil { + return fmt.Errorf("error writing schema file %s: %w", path, err) + } + return nil +} + +// SkipMigrations skips the specified migrations. +func (s *Schema) SkipMigrations(i ...int) { + if s.skipMigration == nil { + s.skipMigration = make(map[int]struct{}) + } + for _, index := range i { + s.skipMigration[index] = struct{}{} + } +} + +// Apply applies the schema to the database. +func (s *Schema) Apply(db *Database) error { + return db.WithTx(context.Background(), func(tx *Tx) error { + scanner := bufio.NewScanner(strings.NewReader(s.Script)) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if i := bytes.Index(data, []byte(";")); i >= 0 { + return i + 1, data[0 : i+1], nil + } + return 0, nil, nil + }) + for scanner.Scan() { + if _, err := tx.Exec(scanner.Text(), nil, nil); err != nil { + return fmt.Errorf("exec %s: %w", scanner.Text(), err) + } + } + return nil + }) +} + +// Migrate performs database migration. In case if migrations are disabled, the database +// version is checked but no migrations are run, and if the database is too old and +// migrations are disabled, an error is returned. +func (s *Schema) Migrate(logger *zap.Logger, db *Database, vacuumState int, enable bool) error { + if len(s.Migrations) == 0 { + return nil + } + before, err := version(db) + if err != nil { + return err + } + after := 0 + if len(s.Migrations) > 0 { + after = s.Migrations.Version() + } + if before > after { + logger.Error("database version is newer than expected - downgrade is not supported", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) + } + + if before == after { + return nil + } + + if !enable { + logger.Error("database version is too old", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return fmt.Errorf("%w: %d < %d", ErrOld, before, after) + } + + logger.Info("running migrations", + zap.Int("current version", before), + zap.Int("target version", after), + ) + for i, m := range s.Migrations { + if m.Order() <= before { + continue + } + if err := db.WithTx(context.Background(), func(tx *Tx) error { + if _, ok := s.skipMigration[m.Order()]; !ok { + if err := m.Apply(tx); err != nil { + for j := i; j >= 0 && s.Migrations[j].Order() > before; j-- { + if e := s.Migrations[j].Rollback(); e != nil { + err = errors.Join(err, fmt.Errorf("rollback %s: %w", m.Name(), e)) + break + } + } + + return fmt.Errorf("apply %s: %w", m.Name(), err) + } + } + // version is set intentionally even if actual migration was skipped + if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.Order()), nil, nil); err != nil { + return fmt.Errorf("update user_version to %d: %w", m.Order(), err) + } + return nil + }); err != nil { + err = errors.Join(err, db.Close()) + return err + } + + if vacuumState != 0 && before <= vacuumState { + if err := Vacuum(db); err != nil { + err = errors.Join(err, db.Close()) + return err + } + } + before = m.Order() + } + return nil +} diff --git a/sql/migrations/state/0001_initial.sql b/sql/statesql/schema/migrations/0001_initial.sql similarity index 100% rename from sql/migrations/state/0001_initial.sql rename to sql/statesql/schema/migrations/0001_initial.sql diff --git a/sql/migrations/state/0002_v1.0.3.sql b/sql/statesql/schema/migrations/0002_v1.0.3.sql similarity index 100% rename from sql/migrations/state/0002_v1.0.3.sql rename to sql/statesql/schema/migrations/0002_v1.0.3.sql diff --git a/sql/migrations/state/0003_v1.1.5.sql b/sql/statesql/schema/migrations/0003_v1.1.5.sql similarity index 100% rename from sql/migrations/state/0003_v1.1.5.sql rename to sql/statesql/schema/migrations/0003_v1.1.5.sql diff --git a/sql/migrations/state/0004_v1.1.7.sql b/sql/statesql/schema/migrations/0004_v1.1.7.sql similarity index 100% rename from sql/migrations/state/0004_v1.1.7.sql rename to sql/statesql/schema/migrations/0004_v1.1.7.sql diff --git a/sql/migrations/state/0005_v1.2.0.sql b/sql/statesql/schema/migrations/0005_v1.2.0.sql similarity index 100% rename from sql/migrations/state/0005_v1.2.0.sql rename to sql/statesql/schema/migrations/0005_v1.2.0.sql diff --git a/sql/migrations/state/0006_v1.2.2.sql b/sql/statesql/schema/migrations/0006_v1.2.2.sql similarity index 100% rename from sql/migrations/state/0006_v1.2.2.sql rename to sql/statesql/schema/migrations/0006_v1.2.2.sql diff --git a/sql/migrations/state/0007_v1.3.0.sql b/sql/statesql/schema/migrations/0007_v1.3.0.sql similarity index 100% rename from sql/migrations/state/0007_v1.3.0.sql rename to sql/statesql/schema/migrations/0007_v1.3.0.sql diff --git a/sql/migrations/state/0008_rewards.sql b/sql/statesql/schema/migrations/0008_rewards.sql similarity index 100% rename from sql/migrations/state/0008_rewards.sql rename to sql/statesql/schema/migrations/0008_rewards.sql diff --git a/sql/migrations/state/0009_prune_activesets.sql b/sql/statesql/schema/migrations/0009_prune_activesets.sql similarity index 100% rename from sql/migrations/state/0009_prune_activesets.sql rename to sql/statesql/schema/migrations/0009_prune_activesets.sql diff --git a/sql/migrations/state/0010_rowid.sql b/sql/statesql/schema/migrations/0010_rowid.sql similarity index 100% rename from sql/migrations/state/0010_rowid.sql rename to sql/statesql/schema/migrations/0010_rowid.sql diff --git a/sql/migrations/state/0011_atxs_extra_index.sql b/sql/statesql/schema/migrations/0011_atxs_extra_index.sql similarity index 100% rename from sql/migrations/state/0011_atxs_extra_index.sql rename to sql/statesql/schema/migrations/0011_atxs_extra_index.sql diff --git a/sql/migrations/state/0012_atx_validity.sql b/sql/statesql/schema/migrations/0012_atx_validity.sql similarity index 100% rename from sql/migrations/state/0012_atx_validity.sql rename to sql/statesql/schema/migrations/0012_atx_validity.sql diff --git a/sql/migrations/state/0013_atx_coinbase_index.sql b/sql/statesql/schema/migrations/0013_atx_coinbase_index.sql similarity index 100% rename from sql/migrations/state/0013_atx_coinbase_index.sql rename to sql/statesql/schema/migrations/0013_atx_coinbase_index.sql diff --git a/sql/migrations/state/0014_remove_proposals.sql b/sql/statesql/schema/migrations/0014_remove_proposals.sql similarity index 100% rename from sql/migrations/state/0014_remove_proposals.sql rename to sql/statesql/schema/migrations/0014_remove_proposals.sql diff --git a/sql/migrations/state/0015_nonce_index.sql b/sql/statesql/schema/migrations/0015_nonce_index.sql similarity index 100% rename from sql/migrations/state/0015_nonce_index.sql rename to sql/statesql/schema/migrations/0015_nonce_index.sql diff --git a/sql/migrations/state/0016_atx_blob.sql b/sql/statesql/schema/migrations/0016_atx_blob.sql similarity index 100% rename from sql/migrations/state/0016_atx_blob.sql rename to sql/statesql/schema/migrations/0016_atx_blob.sql diff --git a/sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql b/sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql similarity index 100% rename from sql/migrations/state/0017_atxs_prev_id_nonce_placeholder.sql rename to sql/statesql/schema/migrations/0017_atxs_prev_id_nonce_placeholder.sql diff --git a/sql/migrations/state/0018_atx_blob_version.sql b/sql/statesql/schema/migrations/0018_atx_blob_version.sql similarity index 100% rename from sql/migrations/state/0018_atx_blob_version.sql rename to sql/statesql/schema/migrations/0018_atx_blob_version.sql diff --git a/sql/statesql/schema/schema.sql b/sql/statesql/schema/schema.sql new file mode 100755 index 0000000000..85e381f87b --- /dev/null +++ b/sql/statesql/schema/schema.sql @@ -0,0 +1,151 @@ +PRAGMA user_version = 18; +CREATE TABLE accounts +( + address CHAR(24), + balance UNSIGNED LONG INT, + next_nonce UNSIGNED LONG INT, + layer_updated UNSIGNED LONG INT, + template CHAR(24), + state BLOB, + PRIMARY KEY (address, layer_updated DESC) +); +CREATE INDEX accounts_by_layer_updated ON accounts (layer_updated); +CREATE TABLE activesets +( + id CHAR(32) PRIMARY KEY, + active_set BLOB +, epoch INT DEFAULT 0 NOT NULL) WITHOUT ROWID; +CREATE INDEX activesets_by_epoch ON activesets (epoch asc); +CREATE TABLE atx_blobs +( + id CHAR(32), + atx BLOB +, version INTEGER); +CREATE UNIQUE INDEX atx_blobs_id ON atx_blobs (id); +CREATE TABLE atxs +( + id CHAR(32), + prev_id CHAR(32), + epoch INT NOT NULL, + effective_num_units INT NOT NULL, + commitment_atx CHAR(32), + nonce UNSIGNED LONG INT, + base_tick_height UNSIGNED LONG INT, + tick_count UNSIGNED LONG INT, + sequence UNSIGNED LONG INT, + pubkey CHAR(32), + coinbase CHAR(24), + received INT NOT NULL, + validity INTEGER DEFAULT false +); +CREATE INDEX atxs_by_coinbase ON atxs (coinbase); +CREATE INDEX atxs_by_epoch_by_pubkey ON atxs (epoch, pubkey); +CREATE INDEX atxs_by_epoch_by_pubkey_nonce ON atxs (pubkey, epoch desc, nonce) WHERE nonce IS NOT NULL; +CREATE INDEX atxs_by_epoch_id on atxs (epoch, id); +CREATE INDEX atxs_by_pubkey_by_epoch_desc ON atxs (pubkey, epoch desc); +CREATE UNIQUE INDEX atxs_id ON atxs (id); +CREATE TABLE ballots +( + id CHAR(20) PRIMARY KEY, + atx CHAR(32) NOT NULL, + layer INT NOT NULL, + pubkey VARCHAR, + ballot BLOB +); +CREATE INDEX ballots_by_atx_by_layer ON ballots (atx, layer asc); +CREATE INDEX ballots_by_layer_by_pubkey ON ballots (layer asc, pubkey); +CREATE TABLE beacons +( + epoch INT NOT NULL PRIMARY KEY, + beacon CHAR(4) +) WITHOUT ROWID; +CREATE TABLE block_transactions +( + tid CHAR(32), + bid CHAR(20), + layer INT NOT NULL, + PRIMARY KEY (tid, bid) +) WITHOUT ROWID; +CREATE TABLE blocks +( + id CHAR(20) PRIMARY KEY, + layer INT NOT NULL, + validity SMALL INT, + block BLOB +); +CREATE INDEX blocks_by_layer ON blocks (layer, id asc); +CREATE TABLE certificates +( + layer INT NOT NULL, + block VARCHAR NOT NULL, + cert BLOB, + valid bool NOT NULL, + PRIMARY KEY (layer, block) +); +CREATE TABLE identities +( + pubkey VARCHAR PRIMARY KEY, + proof BLOB +, received INT DEFAULT 0 NOT NULL) WITHOUT ROWID; +CREATE TABLE layers +( + id INT PRIMARY KEY DESC, + weak_coin SMALL INT, + processed SMALL INT, + applied_block VARCHAR, + state_hash CHAR(32), + aggregated_hash CHAR(32) +) WITHOUT ROWID; +CREATE INDEX layers_by_processed ON layers (processed); +CREATE TABLE poets +( + ref VARCHAR PRIMARY KEY, + poet BLOB, + service_id VARCHAR, + round_id VARCHAR +); +CREATE INDEX poets_by_service_id_by_round_id ON poets (service_id, round_id); +CREATE TABLE proposal_transactions +( + tid CHAR(32), + pid CHAR(20), + layer INT NOT NULL, + PRIMARY KEY (tid, pid) +) WITHOUT ROWID; +CREATE TABLE recovery +( + id INTEGER PRIMARY KEY CHECK (id = 1), + restore INT NOT NULL +); +CREATE TABLE rewards +( + pubkey CHAR(32), + coinbase CHAR(24) NOT NULL, + layer INT NOT NULL, + total_reward UNSIGNED LONG INT, + layer_reward UNSIGNED LONG INT, + PRIMARY KEY (pubkey, layer) +); +CREATE INDEX rewards_by_coinbase ON rewards (coinbase, layer); +CREATE INDEX rewards_by_layer ON rewards (layer asc); +CREATE TABLE transactions +( + id CHAR(32) PRIMARY KEY, + tx BLOB, + header BLOB, + result BLOB, + layer INT, + block CHAR(20), + principal CHAR(24), + nonce BLOB, + timestamp INT NOT NULL +) WITHOUT ROWID; +CREATE INDEX transaction_by_layer_principal ON transactions (layer asc, principal); +CREATE INDEX transaction_by_principal_nonce ON transactions (principal, nonce); +CREATE TABLE transactions_results_addresses +( + address CHAR(24), + tid CHAR(32), + PRIMARY KEY (tid, address) +) WITHOUT ROWID; +CREATE INDEX transactions_results_addresses_by_address ON transactions_results_addresses(address); diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go new file mode 100644 index 0000000000..6536243498 --- /dev/null +++ b/sql/statesql/statesql.go @@ -0,0 +1,58 @@ +package statesql + +import ( + "embed" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +//go:embed schema/schema.sql schema/migrations/*.sql +var embedded embed.FS + +// Database represents a state database. +type Database struct { + *sql.Database +} + +// Schema returns the schema for the state database. +func Schema() (*sql.Schema, error) { + migrations, err := sql.LoadSQLMigrations(embedded) + if err != nil { + return nil, err + } + // NOTE: coded state migrations can be added here + // They can be a part of this statesql package + return sql.LoadSchema(embedded, migrations) +} + +// Open opens a state database. +func Open(uri string, opts ...sql.Opt) (*Database, error) { + schema, err := Schema() + if err != nil { + return nil, err + } + opts = append([]sql.Opt{sql.WithDatabaseSchema(schema)}, opts...) + db, err := sql.Open(uri, opts...) + if err != nil { + return nil, err + } + return &Database{Database: db}, nil +} + +// Open opens an in-memory state database. +func InMemory(opts ...sql.Opt) *Database { + schema, err := Schema() + if err != nil { + panic(err) + } + defaultOpts := []sql.Opt{ + sql.WithDatabaseSchema(schema), + } + opts = append(defaultOpts, opts...) + db := sql.InMemory(opts...) + return &Database{Database: db} +} + +// TBD: QQQQQ: check disabling migrations in database_test.go +// TBD: QQQQQ: add sql/test package with test skeletons +// TBD: QQQQQ: verify identity merging code diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go new file mode 100644 index 0000000000..d13a76671e --- /dev/null +++ b/sql/statesql/statesql_test.go @@ -0,0 +1,72 @@ +package statesql + +import ( + "path/filepath" + "slices" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +func TestDatabase_MigrateTwice_NoOp(t *testing.T) { + file := filepath.Join(t.TempDir(), "test.db") + db, err := Open("file:"+file, sql.WithForceMigrations(true)) + require.NoError(t, err) + + sql1, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + require.NoError(t, db.Close()) + + db, err = Open("file:" + file) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + + require.Equal(t, sql1, sql2) + + var version int + _, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool { + version = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + require.NoError(t, err) + schema, err := Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), version) + + require.NoError(t, db.Close()) +} + +func TestSchema(t *testing.T) { + for _, tc := range []struct { + name string + forceMigrations bool + }{ + {name: "no migrations", forceMigrations: false}, + {name: "force migrations", forceMigrations: true}, + } { + t.Run(tc.name, func(t *testing.T) { + db := InMemory(sql.WithForceMigrations(tc.forceMigrations)) + loadedScript, err := sql.LoadDBSchemaScript(db) + require.NoError(t, err) + expSchema, err := Schema() + require.NoError(t, err) + diff := expSchema.Diff(loadedScript) + if diff != "" { + s := &sql.Schema{Script: loadedScript} + require.NoError(t, s.WriteToFile(".")) + t.Logf("updated schema written to %s", sql.UpdatedSchemaPath) + } + require.Empty(t, diff, "schema diff") + }) + } +} diff --git a/sql/transactions/iterator_test.go b/sql/transactions/iterator_test.go index fd3cad3a8e..4b9bdb0dcb 100644 --- a/sql/transactions/iterator_test.go +++ b/sql/transactions/iterator_test.go @@ -15,6 +15,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/fixture" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func matchTx(tx types.TransactionWithResult, filter ResultsFilter) bool { @@ -59,7 +60,7 @@ func filterTxs(txs []types.TransactionWithResult, filter ResultsFilter) []types. } func TestIterateResults(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() gen := fixture.NewTransactionResultGenerator() txs := make([]types.TransactionWithResult, 100) @@ -142,7 +143,7 @@ func TestIterateResults(t *testing.T) { } func TestIterateSnapshot(t *testing.T) { - db, err := sql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) + db, err := statesql.Open("file:" + filepath.Join(t.TempDir(), "test.sql")) t.Cleanup(func() { require.NoError(t, db.Close()) }) require.NoError(t, err) gen := fixture.NewTransactionResultGenerator() diff --git a/sql/transactions/transactions_test.go b/sql/transactions/transactions_test.go index ed2eef3cc9..dc5d4714f8 100644 --- a/sql/transactions/transactions_test.go +++ b/sql/transactions/transactions_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/sdk/wallet" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -70,7 +71,7 @@ func checkMeshTXEqual(t *testing.T, expected, got types.MeshTransaction) { } func TestAddGetHas(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer1, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -109,7 +110,7 @@ func TestAddGetHas(t *testing.T) { } func TestAddUpdatesHeader(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() txs := []*types.Transaction{ { RawTx: types.NewRawTx([]byte{1, 2, 3}), @@ -142,7 +143,7 @@ func TestAddUpdatesHeader(t *testing.T) { } func TestAddToProposal(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -166,7 +167,7 @@ func TestAddToProposal(t *testing.T) { } func TestDeleteProposalTxs(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() proposals := map[types.LayerID][]types.ProposalID{ types.LayerID(10): {{1, 1}, {1, 2}}, types.LayerID(11): {{2, 1}, {2, 2}}, @@ -197,7 +198,7 @@ func TestDeleteProposalTxs(t *testing.T) { } func TestAddToBlock(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -221,7 +222,7 @@ func TestAddToBlock(t *testing.T) { } func TestApply_AlreadyApplied(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) lid := types.LayerID(10) @@ -251,7 +252,7 @@ func TestApply_AlreadyApplied(t *testing.T) { } func TestUndoLayers_Empty(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { return transactions.UndoLayers(dtx, types.LayerID(199)) @@ -259,7 +260,7 @@ func TestUndoLayers_Empty(t *testing.T) { } func TestApplyAndUndoLayers(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) firstLayer := types.LayerID(10) @@ -300,7 +301,7 @@ func TestApplyAndUndoLayers(t *testing.T) { } func TestGetBlob(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ctx := context.Background() rng := rand.New(rand.NewSource(1001)) @@ -333,7 +334,7 @@ func TestGetBlob(t *testing.T) { } func TestGetByAddress(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer1, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -369,7 +370,7 @@ func TestGetByAddress(t *testing.T) { } func TestGetAcctPendingFromNonce(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) @@ -404,7 +405,7 @@ func TestGetAcctPendingFromNonce(t *testing.T) { } func TestAppliedLayer(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() rng := rand.New(rand.NewSource(1001)) signer, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) require.NoError(t, err) @@ -455,7 +456,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { TxHeader: &types.TxHeader{Principal: principals[1], Nonce: 0}, }, } - db := sql.InMemory() + db := statesql.InMemory() for _, tx := range txs { require.NoError(t, transactions.Add(db, &tx, time.Time{})) } @@ -520,7 +521,7 @@ func TestTransactionInProposal(t *testing.T) { {2}, {3}, } - db := sql.InMemory() + db := statesql.InMemory() for i := range lids { require.NoError(t, transactions.AddToProposal(db, tid, lids[i], pids[i])) } @@ -546,7 +547,7 @@ func TestTransactionInBlock(t *testing.T) { {2}, {3}, } - db := sql.InMemory() + db := statesql.InMemory() for i := range lids { require.NoError(t, transactions.AddToBlock(db, tid, lids[i], bids[i])) } diff --git a/syncer/atxsync/atxsync.go b/syncer/atxsync/atxsync.go index 41ac7ec7c0..ac5309dcb4 100644 --- a/syncer/atxsync/atxsync.go +++ b/syncer/atxsync/atxsync.go @@ -9,12 +9,12 @@ import ( "go.uber.org/zap/zapcore" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) -func getMissing(db *sql.Database, set []types.ATXID) ([]types.ATXID, error) { +func getMissing(db *statesql.Database, set []types.ATXID) ([]types.ATXID, error) { missing := []types.ATXID{} for _, atx := range set { exist, err := atxs.Has(db, atx) @@ -35,7 +35,7 @@ func Download( ctx context.Context, retryInterval time.Duration, logger *zap.Logger, - db *sql.Database, + db *statesql.Database, fetcher system.AtxFetcher, set []types.ATXID, ) error { diff --git a/syncer/atxsync/atxsync_test.go b/syncer/atxsync/atxsync_test.go index 33f649c727..6804e782cc 100644 --- a/syncer/atxsync/atxsync_test.go +++ b/syncer/atxsync/atxsync_test.go @@ -11,8 +11,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/log/logtest" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -97,7 +97,7 @@ func TestDownload(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { logger := logtest.New(t) - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(t) fetcher := mocks.NewMockAtxFetcher(ctrl) for _, atx := range tc.existing { diff --git a/syncer/atxsync/syncer_test.go b/syncer/atxsync/syncer_test.go index 79165f950d..bfa3f53410 100644 --- a/syncer/atxsync/syncer_test.go +++ b/syncer/atxsync/syncer_test.go @@ -14,10 +14,10 @@ import ( "github.com/spacemeshos/go-spacemesh/fetch" "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/atxsync" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/atxsync/mocks" "github.com/spacemeshos/go-spacemesh/system" ) @@ -42,7 +42,7 @@ func edata(ids ...string) *fetch.EpochData { func newTester(tb testing.TB, cfg Config) *tester { localdb := localsql.InMemory() - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(tb) fetcher := mocks.NewMockfetcher(ctrl) syncer := New(fetcher, db, localdb, WithConfig(cfg), WithLogger(logtest.New(tb).Zap())) @@ -61,7 +61,7 @@ type tester struct { tb testing.TB syncer *Syncer localdb *localsql.Database - db *sql.Database + db *statesql.Database cfg Config ctrl *gomock.Controller fetcher *mocks.Mockfetcher diff --git a/syncer/find_fork_test.go b/syncer/find_fork_test.go index 8400f0d46a..a80c8541cc 100644 --- a/syncer/find_fork_test.go +++ b/syncer/find_fork_test.go @@ -17,21 +17,21 @@ import ( "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer" "github.com/spacemeshos/go-spacemesh/syncer/mocks" ) type testForkFinder struct { *syncer.ForkFinder - db *sql.Database + db *statesql.Database mFetcher *mocks.Mockfetcher } func newTestForkFinderWithDuration(t *testing.T, d time.Duration, lg log.Log) *testForkFinder { mf := mocks.NewMockfetcher(gomock.NewController(t)) - db := sql.InMemory() + db := statesql.InMemory() require.NoError(t, layers.SetMeshHash(db, types.GetEffectiveGenesis(), types.RandomHash())) return &testForkFinder{ ForkFinder: syncer.NewForkFinder(lg, db, mf, d), @@ -88,7 +88,7 @@ func layerHash(layer int, good bool) types.Hash32 { return h2 } -func storeNodeHashes(t *testing.T, db *sql.Database, diverge, max int) { +func storeNodeHashes(t *testing.T, db *statesql.Database, diverge, max int) { for lid := 0; lid <= max; lid++ { if lid < diverge { require.NoError(t, layers.SetMeshHash(db, types.LayerID(uint32(lid)), layerHash(lid, true))) diff --git a/syncer/malsync/syncer_test.go b/syncer/malsync/syncer_test.go index f4297a7e87..4a9ec6f300 100644 --- a/syncer/malsync/syncer_test.go +++ b/syncer/malsync/syncer_test.go @@ -20,9 +20,9 @@ import ( "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/malsync/mocks" ) @@ -138,7 +138,7 @@ type tester struct { tb testing.TB syncer *Syncer localdb *localsql.Database - db *sql.Database + db *statesql.Database cfg Config ctrl *gomock.Controller fetcher *mocks.Mockfetcher @@ -151,7 +151,7 @@ type tester struct { func newTester(tb testing.TB, cfg Config) *tester { localdb := localsql.InMemory() - db := sql.InMemory() + db := statesql.InMemory() ctrl := gomock.NewController(tb) fetcher := mocks.NewMockfetcher(ctrl) clock := clockwork.NewFakeClock() diff --git a/syncer/syncer_test.go b/syncer/syncer_test.go index efa7caf5ae..4a591939ea 100644 --- a/syncer/syncer_test.go +++ b/syncer/syncer_test.go @@ -20,8 +20,8 @@ import ( "github.com/spacemeshos/go-spacemesh/mesh" mmocks "github.com/spacemeshos/go-spacemesh/mesh/mocks" "github.com/spacemeshos/go-spacemesh/p2p" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/certificates" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer/mocks" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -127,7 +127,7 @@ func newTestSyncer(t *testing.T, interval time.Duration) *testSyncer { mCertHdr: mocks.NewMockcertHandler(ctrl), mForkFinder: mocks.NewMockforkFinder(ctrl), } - db := sql.InMemory() + db := statesql.InMemory() ts.cdb = datastore.NewCachedDB(db, lg.Zap()) var err error atxsdata := atxsdata.New() diff --git a/systest/tests/distributed_post_verification_test.go b/systest/tests/distributed_post_verification_test.go index bd03703bd2..62042797c8 100644 --- a/systest/tests/distributed_post_verification_test.go +++ b/systest/tests/distributed_post_verification_test.go @@ -31,9 +31,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/handshake" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/systest/cluster" "github.com/spacemeshos/go-spacemesh/systest/testcontext" "github.com/spacemeshos/go-spacemesh/timesync" @@ -112,7 +112,7 @@ func TestPostMalfeasanceProof(t *testing.T) { postSetupMgr, err := activation.NewPostSetupManager( cfg.POST, logger.Named("post"), - datastore.NewCachedDB(sql.InMemory(), zap.NewNop()), + datastore.NewCachedDB(statesql.InMemory(), zap.NewNop()), atxsdata.New(), cl.GoldenATX(), syncer, @@ -155,7 +155,7 @@ func TestPostMalfeasanceProof(t *testing.T) { require.NoError(t, grpcPrivateServer.Start()) t.Cleanup(func() { assert.NoError(t, grpcPrivateServer.Close()) }) - db := sql.InMemory() + db := statesql.InMemory() localDb := localsql.InMemory() certClient := activation.NewCertifierClient(db, localDb, logger.Named("certifier")) certifier := activation.NewCertifier(localDb, logger, certClient) diff --git a/tortoise/model/core.go b/tortoise/model/core.go index ce7022fa33..d8c622f7df 100644 --- a/tortoise/model/core.go +++ b/tortoise/model/core.go @@ -19,6 +19,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -28,7 +29,7 @@ const ( ) func newCore(rng *rand.Rand, id string, logger *zap.Logger) *core { - cdb := datastore.NewCachedDB(sql.InMemory(), logger) + cdb := datastore.NewCachedDB(statesql.InMemory(), logger) sig, err := signing.NewEdSigner(signing.WithKeyFromRand(rng)) if err != nil { panic(err) diff --git a/tortoise/replay/replay_test.go b/tortoise/replay/replay_test.go index 77e71079ee..9f6dd399fa 100644 --- a/tortoise/replay/replay_test.go +++ b/tortoise/replay/replay_test.go @@ -15,8 +15,8 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/config" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/timesync" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -50,7 +50,7 @@ func TestReplayMainnet(t *testing.T) { ) require.NoError(t, err) - db, err := sql.Open(fmt.Sprintf("file:%s?mode=ro", *dbpath)) + db, err := statesql.Open(fmt.Sprintf("file:%s?mode=ro", *dbpath)) require.NoError(t, err) applied, err := layers.GetLastApplied(db) diff --git a/tortoise/sim/utils.go b/tortoise/sim/utils.go index e7f66f0664..c2385e4310 100644 --- a/tortoise/sim/utils.go +++ b/tortoise/sim/utils.go @@ -7,7 +7,7 @@ import ( "go.uber.org/zap" "github.com/spacemeshos/go-spacemesh/datastore" - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -16,13 +16,13 @@ const ( func newCacheDB(logger *zap.Logger, conf config) *datastore.CachedDB { var ( - db *sql.Database + db *statesql.Database err error ) if len(conf.Path) == 0 { - db = sql.InMemory() + db = statesql.InMemory() } else { - db, err = sql.Open(filepath.Join(conf.Path, atxpath)) + db, err = statesql.Open(filepath.Join(conf.Path, atxpath)) if err != nil { panic(err) } diff --git a/tortoise/threshold_test.go b/tortoise/threshold_test.go index 3b414ef06d..4b874ebc78 100644 --- a/tortoise/threshold_test.go +++ b/tortoise/threshold_test.go @@ -8,8 +8,8 @@ import ( "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func TestComputeThreshold(t *testing.T) { @@ -165,7 +165,7 @@ func TestReferenceHeight(t *testing.T) { }, } { t.Run(tc.desc, func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() for i, height := range tc.heights { atx := &types.ActivationTx{ PublishEpoch: types.EpochID(tc.epoch) - 1, diff --git a/tortoise/tortoise_test.go b/tortoise/tortoise_test.go index d8b2de1b9f..b34c25a225 100644 --- a/tortoise/tortoise_test.go +++ b/tortoise/tortoise_test.go @@ -28,6 +28,7 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/builder" "github.com/spacemeshos/go-spacemesh/sql/certificates" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/tortoise/opinionhash" "github.com/spacemeshos/go-spacemesh/tortoise/sim" ) @@ -467,7 +468,7 @@ func TestComputeExpectedWeight(t *testing.T) { } { t.Run(tc.desc, func(t *testing.T) { var ( - db = sql.InMemory() + db = statesql.InMemory() epochs = map[types.EpochID]*epochInfo{} first = tc.target.Add(1).GetEpoch() ) diff --git a/txs/cache.go b/txs/cache.go index f15c10d3bb..8a0594ea4f 100644 --- a/txs/cache.go +++ b/txs/cache.go @@ -19,6 +19,7 @@ import ( "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -332,7 +333,7 @@ func (ac *accountCache) add(logger *zap.Logger, tx *types.Transaction, received func (ac *accountCache) addPendingFromNonce( logger *zap.Logger, - db *sql.Database, + db *statesql.Database, nonce uint64, applied types.LayerID, ) error { @@ -423,7 +424,7 @@ func (ac *accountCache) getMempool(logger *zap.Logger) []*NanoTX { // because applying a layer changes the conservative balance in the cache. func (ac *accountCache) resetAfterApply( logger *zap.Logger, - db *sql.Database, + db *statesql.Database, nextNonce, newBalance uint64, applied types.LayerID, ) error { @@ -489,7 +490,7 @@ func groupTXsByPrincipal(logger *zap.Logger, mtxs []*types.MeshTransaction) map[ } // buildFromScratch builds the cache from database. -func (c *Cache) buildFromScratch(db *sql.Database) error { +func (c *Cache) buildFromScratch(db *statesql.Database) error { applied, err := layers.GetLastApplied(db) if err != nil { return fmt.Errorf("cache: get pending %w", err) @@ -606,7 +607,7 @@ func acceptable(err error) bool { func (c *Cache) Add( ctx context.Context, - db *sql.Database, + db *statesql.Database, tx *types.Transaction, received time.Time, mustPersist bool, @@ -653,7 +654,7 @@ func (c *Cache) has(tid types.TransactionID) bool { // LinkTXsWithProposal associates the transactions to a proposal. func (c *Cache) LinkTXsWithProposal( - db *sql.Database, + db *statesql.Database, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID, @@ -670,7 +671,7 @@ func (c *Cache) LinkTXsWithProposal( // LinkTXsWithBlock associates the transactions to a block. func (c *Cache) LinkTXsWithBlock( - db *sql.Database, + db *statesql.Database, lid types.LayerID, bid types.BlockID, tids []types.TransactionID, @@ -702,7 +703,7 @@ func (c *Cache) updateLayer(lid types.LayerID, bid types.BlockID, tids []types.T return nil } -func (c *Cache) applyEmptyLayer(db *sql.Database, lid types.LayerID) error { +func (c *Cache) applyEmptyLayer(db *statesql.Database, lid types.LayerID) error { c.mu.Lock() defer c.mu.Unlock() @@ -721,7 +722,7 @@ func (c *Cache) applyEmptyLayer(db *sql.Database, lid types.LayerID) error { // ApplyLayer retires the applied transactions from the cache and updates the balances. func (c *Cache) ApplyLayer( ctx context.Context, - db *sql.Database, + db *statesql.Database, lid types.LayerID, bid types.BlockID, results []types.TransactionWithResult, @@ -838,7 +839,7 @@ func (c *Cache) ApplyLayer( return nil } -func (c *Cache) RevertToLayer(db *sql.Database, revertTo types.LayerID) error { +func (c *Cache) RevertToLayer(db *statesql.Database, revertTo types.LayerID) error { if err := undoLayers(db, revertTo.Add(1)); err != nil { return err } @@ -879,7 +880,7 @@ func (c *Cache) GetMempool(logger *zap.Logger) map[types.Address][]*NanoTX { } // checkApplyOrder returns an error if layers were not applied in order. -func checkApplyOrder(logger *zap.Logger, db *sql.Database, toApply types.LayerID) error { +func checkApplyOrder(logger *zap.Logger, db *statesql.Database, toApply types.LayerID) error { lastApplied, err := layers.GetLastApplied(db) if err != nil { logger.Error("failed to get last applied layer", zap.Error(err)) @@ -895,7 +896,7 @@ func checkApplyOrder(logger *zap.Logger, db *sql.Database, toApply types.LayerID return nil } -func addToProposal(db *sql.Database, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID) error { +func addToProposal(db *statesql.Database, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID) error { return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { for _, tid := range tids { if err := transactions.AddToProposal(dbtx, tid, lid, pid); err != nil { @@ -906,7 +907,7 @@ func addToProposal(db *sql.Database, lid types.LayerID, pid types.ProposalID, ti }) } -func addToBlock(db *sql.Database, lid types.LayerID, bid types.BlockID, tids []types.TransactionID) error { +func addToBlock(db *statesql.Database, lid types.LayerID, bid types.BlockID, tids []types.TransactionID) error { return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { for _, tid := range tids { if err := transactions.AddToBlock(dbtx, tid, lid, bid); err != nil { @@ -917,7 +918,7 @@ func addToBlock(db *sql.Database, lid types.LayerID, bid types.BlockID, tids []t }) } -func undoLayers(db *sql.Database, from types.LayerID) error { +func undoLayers(db *statesql.Database, from types.LayerID) error { return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { err := transactions.UndoLayers(dbtx, from) if err != nil { diff --git a/txs/cache_test.go b/txs/cache_test.go index 4a8a7c364a..fce0897e8a 100644 --- a/txs/cache_test.go +++ b/txs/cache_test.go @@ -13,12 +13,13 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) type testCache struct { *Cache - db *sql.Database + db *statesql.Database } type testAcct struct { @@ -67,7 +68,7 @@ func newMeshTX( func genAndSaveTXs( t *testing.T, - db *sql.Database, + db *statesql.Database, signer *signing.EdSigner, from, to uint64, startTime time.Time, @@ -88,14 +89,14 @@ func genTXs(t *testing.T, signer *signing.EdSigner, from, to uint64, startTime t return mtxs } -func saveTXs(t *testing.T, db *sql.Database, mtxs []*types.MeshTransaction) { +func saveTXs(t *testing.T, db *statesql.Database, mtxs []*types.MeshTransaction) { t.Helper() for _, mtx := range mtxs { require.NoError(t, transactions.Add(db, &mtx.Transaction, mtx.Received)) } } -func checkTXStateFromDB(t *testing.T, db *sql.Database, txs []*types.MeshTransaction, state types.TXState) { +func checkTXStateFromDB(t *testing.T, db *statesql.Database, txs []*types.MeshTransaction, state types.TXState) { for _, mtx := range txs { got, err := transactions.Get(db, mtx.ID) require.NoError(t, err) @@ -103,7 +104,7 @@ func checkTXStateFromDB(t *testing.T, db *sql.Database, txs []*types.MeshTransac } } -func checkTXNotInDB(t *testing.T, db *sql.Database, tid types.TransactionID) { +func checkTXNotInDB(t *testing.T, db *statesql.Database, tid types.TransactionID) { _, err := transactions.Get(db, tid) require.ErrorIs(t, err, sql.ErrNotFound) } @@ -169,7 +170,7 @@ func createState(tb testing.TB, numAccounts int) map[types.Address]*testAcct { func createCache(tb testing.TB, numAccounts int) (*testCache, map[types.Address]*testAcct) { tb.Helper() accounts := createState(tb, numAccounts) - db := sql.InMemory() + db := statesql.InMemory() return &testCache{ Cache: NewCache(getStateFunc(accounts), zaptest.NewLogger(tb)), db: db, @@ -183,7 +184,7 @@ func createSingleAccountTestCache(tb testing.TB) (*testCache, *testAcct) { principal := types.GenerateAddress(signer.PublicKey().Bytes()) ta := &testAcct{signer: signer, principal: principal, nonce: rand.Uint64()%1000 + 1, balance: defaultBalance} states := map[types.Address]*testAcct{principal: ta} - db := sql.InMemory() + db := statesql.InMemory() return &testCache{ Cache: NewCache(getStateFunc(states), zaptest.NewLogger(tb)), db: db, diff --git a/txs/conservative_state.go b/txs/conservative_state.go index 88912319b3..15f0aa2d47 100644 --- a/txs/conservative_state.go +++ b/txs/conservative_state.go @@ -11,8 +11,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/system" ) @@ -54,12 +54,12 @@ type ConservativeState struct { logger *zap.Logger cfg CSConfig - db *sql.Database + db *statesql.Database cache *Cache } // NewConservativeState returns a ConservativeState. -func NewConservativeState(state vmState, db *sql.Database, opts ...ConservativeStateOpt) *ConservativeState { +func NewConservativeState(state vmState, db *statesql.Database, opts ...ConservativeStateOpt) *ConservativeState { cs := &ConservativeState{ vmState: state, cfg: defaultCSConfig(), diff --git a/txs/conservative_state_test.go b/txs/conservative_state_test.go index 46bad3a5a2..e3648582e6 100644 --- a/txs/conservative_state_test.go +++ b/txs/conservative_state_test.go @@ -23,8 +23,8 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/sdk/wallet" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" + "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" ) @@ -73,7 +73,7 @@ func newTxWthRecipient( type testConState struct { *ConservativeState logger *zap.Logger - db *sql.Database + db *statesql.Database mvm *MockvmState id peer.ID @@ -86,7 +86,7 @@ func (t *testConState) handler() *TxHandler { func createTestState(t *testing.T, gasLimit uint64) *testConState { ctrl := gomock.NewController(t) mvm := NewMockvmState(ctrl) - db := sql.InMemory() + db := statesql.InMemory() cfg := CSConfig{ BlockGasLimit: gasLimit, NumTXsPerProposal: numTXsInProposal, From 6a53628547a3ee38531ce74643800210994b4dc5 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 1 Jun 2024 08:10:30 +0400 Subject: [PATCH 02/62] sql: fixup: malsync and rewards --- sql/malsync/malsync.go | 40 +++++++++++++++++++++++++++---------- sql/rewards/rewards_test.go | 1 + sql/statesql/statesql.go | 1 + syncer/malsync/syncer.go | 10 ++++++---- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/sql/malsync/malsync.go b/sql/malsync/malsync.go index 9ae572b0ad..048ff43d90 100644 --- a/sql/malsync/malsync.go +++ b/sql/malsync/malsync.go @@ -8,6 +8,11 @@ import ( ) func GetSyncState(db sql.Executor) (time.Time, error) { + timestamp, _, err := getSyncState(db) + return timestamp, err +} + +func getSyncState(db sql.Executor) (time.Time, bool, error) { var timestamp time.Time rows, err := db.Exec("select timestamp from malfeasance_sync_state", nil, func(stmt *sql.Statement) bool { @@ -17,21 +22,36 @@ func GetSyncState(db sql.Executor) (time.Time, error) { } return true }) - if err != nil { - return time.Time{}, fmt.Errorf("error getting malfeasance sync state: %w", err) - } else if rows != 1 { - return time.Time{}, fmt.Errorf("expected malfeasance_sync_state to have 1 row but got %d rows", rows) + switch { + case err != nil: + return time.Time{}, false, fmt.Errorf("error getting malfeasance sync state: %w", err) + case rows == 0: + return timestamp, false, nil + case rows == 1: + return timestamp, true, nil + default: + return time.Time{}, false, fmt.Errorf("expected malfeasance_sync_state to have 1 row but got %d rows", rows) } - return timestamp, nil } func updateSyncState(db sql.Executor, ts int64) error { - _, err := db.Exec("update malfeasance_sync_state set timestamp = ?1", - func(stmt *sql.Statement) { - stmt.BindInt64(1, ts) - }, nil) + _, haveTS, err := getSyncState(db) if err != nil { - return fmt.Errorf("error updating malfeasance sync state: %w", err) + return err + } + enc := func(stmt *sql.Statement) { + stmt.BindInt64(1, ts) + } + if haveTS { + _, err := db.Exec("update malfeasance_sync_state set timestamp = ?1", enc, nil) + if err != nil { + return fmt.Errorf("error updating malfeasance sync state: %w", err) + } + } else { + _, err = db.Exec("insert into malfeasance_sync_state (timestamp) values(?1)", enc, nil) + if err != nil { + return fmt.Errorf("error initializing malfeasance sync state: %w", err) + } } return nil } diff --git a/sql/rewards/rewards_test.go b/sql/rewards/rewards_test.go index 3cf7ef11a4..af57f33c9a 100644 --- a/sql/rewards/rewards_test.go +++ b/sql/rewards/rewards_test.go @@ -246,6 +246,7 @@ func Test_0008Migration(t *testing.T) { // apply previous migrations db := statesql.InMemory( sql.WithDatabaseSchema(schema), + sql.WithForceMigrations(true), ) // verify that the DB is empty diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 6536243498..24744004c9 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -56,3 +56,4 @@ func InMemory(opts ...sql.Opt) *Database { // TBD: QQQQQ: check disabling migrations in database_test.go // TBD: QQQQQ: add sql/test package with test skeletons // TBD: QQQQQ: verify identity merging code +// TBD: QQQQQ: instead of "not like '_litestream%'", use regex in the config diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go index 67a96ddc50..a685f82f4f 100644 --- a/syncer/malsync/syncer.go +++ b/syncer/malsync/syncer.go @@ -341,8 +341,10 @@ func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan } } -func (s *Syncer) updateState() error { - if err := malsync.UpdateSyncState(s.localdb, s.clock.Now()); err != nil { +func (s *Syncer) updateState(ctx context.Context) error { + if err := s.localdb.WithTx(ctx, func(tx *sql.Tx) error { + return malsync.UpdateSyncState(tx, s.clock.Now()) + }); err != nil { return fmt.Errorf("error updating malsync state: %w", err) } @@ -360,13 +362,13 @@ func (s *Syncer) downloadMalfeasanceProofs(ctx context.Context, initial bool, up if nothingToDownload { sst.done() if initial && sst.numSyncedPeers() >= s.cfg.MinSyncPeers { - if err := s.updateState(); err != nil { + if err := s.updateState(ctx); err != nil { return err } s.logger.Info("initial sync of malfeasance proofs completed", log.ZContext(ctx)) return nil } else if !initial && gotUpdate { - if err := s.updateState(); err != nil { + if err := s.updateState(ctx); err != nil { return err } } From 5db7e18cf14caa94edd4a336452877f09f12dd85 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 1 Jun 2024 18:48:45 +0400 Subject: [PATCH 03/62] sql: add migration / schema drift related tests --- sql/database_test.go | 112 +++++++++++++++++++++++++++++++++++++++ sql/statesql/statesql.go | 2 - 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/sql/database_test.go b/sql/database_test.go index 04b559e7db..24d5d03e72 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -9,6 +9,10 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" ) func Test_Transaction_Isolation(t *testing.T) { @@ -103,6 +107,35 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { require.ErrorContains(t, err, "migration 2 failed") } +func Test_Migration_Disabled(t *testing.T) { + ctrl := gomock.NewController(t) + migration1 := NewMockMigration(ctrl) + migration1.EXPECT().Name().Return("test").AnyTimes() + migration1.EXPECT().Order().Return(1).AnyTimes() + migration1.EXPECT().Apply(gomock.Any()).Return(nil) + + dbFile := filepath.Join(t.TempDir(), "test.sql") + db, err := Open("file:"+dbFile, + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + ) + require.NoError(t, err) + require.NoError(t, db.Close()) + + migration2 := NewMockMigration(ctrl) + migration2.EXPECT().Order().Return(2).AnyTimes() + + _, err = Open("file:"+dbFile, + WithDatabaseSchema(&Schema{ + Migrations: MigrationList{migration1, migration2}, + }), + WithEnableMigrations(false), + ) + require.ErrorIs(t, err, ErrOld) +} + func TestDatabaseSkipMigrations(t *testing.T) { ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) @@ -200,3 +233,82 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { ) require.ErrorIs(t, err, ErrTooNew) } + +func TestSchemaDrift(t *testing.T) { + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + dbFile := filepath.Join(t.TempDir(), "test.sql") + schema := &Schema{ + // Not using ` here to avoid schema drift warnings due to whitespace + // TODO: ignore whitespace and comments during schema comparison + Script: "PRAGMA user_version = 0;\n" + + "CREATE TABLE testing1 (\n" + + " id varchar primary key,\n" + + " field int\n" + + ");\n", + } + db, err := Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + ) + require.NoError(t, err) + + _, err = db.Exec("create table newtbl (id int)", nil, nil) + require.NoError(t, err) + + require.NoError(t, db.Close()) + require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") + + db, err = Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + ) + require.NoError(t, db.Close()) + require.NoError(t, err) + require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") + require.Equal(t, "database schema drift detected", observedLogs.All()[0].Message) + require.Contains(t, observedLogs.All()[0].ContextMap()["diff"], + "+CREATE TABLE newtbl (id int);") +} + +func TestSchemaDrift_IgnoredTables(t *testing.T) { + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + dbFile := filepath.Join(t.TempDir(), "test.sql") + schema := &Schema{ + // Not using ` here to avoid schema drift warnings due to whitespace + // TODO: ignore whitespace and comments during schema comparison + Script: "PRAGMA user_version = 0;\n" + + "CREATE TABLE testing1 (\n" + + " id varchar primary key,\n" + + " field int\n" + + ");\n", + } + db, err := Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + ) + require.NoError(t, err) + + _, err = db.Exec("create table _litestream_test (id int)", nil, nil) + require.NoError(t, err) + + require.NoError(t, db.Close()) + require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") + + db, err = Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + ) + require.NoError(t, err) + require.NoError(t, db.Close()) + require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") +} diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 24744004c9..23e7d9c9f3 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -53,7 +53,5 @@ func InMemory(opts ...sql.Opt) *Database { return &Database{Database: db} } -// TBD: QQQQQ: check disabling migrations in database_test.go // TBD: QQQQQ: add sql/test package with test skeletons -// TBD: QQQQQ: verify identity merging code // TBD: QQQQQ: instead of "not like '_litestream%'", use regex in the config From c2938b217e4bf26324dda0eb5d13cde2b47039d1 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 1 Jun 2024 18:59:58 +0400 Subject: [PATCH 04/62] sql: make schema drift table ignore regexp configurable --- config/config.go | 8 +++++--- node/node.go | 1 + sql/database.go | 11 ++++++++++- sql/database_test.go | 2 ++ sql/localsql/localsql_test.go | 6 +++--- sql/schema.go | 26 ++++++++++++++++++++------ sql/statesql/statesql.go | 1 - sql/statesql/statesql_test.go | 6 +++--- 8 files changed, 44 insertions(+), 17 deletions(-) diff --git a/config/config.go b/config/config.go index 230d9a9a8a..3342c20e86 100644 --- a/config/config.go +++ b/config/config.go @@ -117,6 +117,7 @@ type BaseConfig struct { DatabaseSkipMigrations []int `mapstructure:"db-skip-migrations"` DatabaseQueryCache bool `mapstructure:"db-query-cache"` DatabaseQueryCacheSizes DatabaseQueryCacheSizes `mapstructure:"db-query-cache-sizes"` + DatabaseIgnoreTableRx string `mapstructure:"db-ignore-table-rx"` PruneActivesetsFrom types.EpochID `mapstructure:"prune-activesets-from"` @@ -245,9 +246,10 @@ func defaultBaseConfig() BaseConfig { ATXBlob: 10000, ActiveSetBlob: 200, }, - NetworkHRP: "sm", - ATXGradeDelay: 10 * time.Second, - PostValidDelay: 12 * time.Hour, + DatabaseIgnoreTableRx: "^_litestream", + NetworkHRP: "sm", + ATXGradeDelay: 10 * time.Second, + PostValidDelay: 12 * time.Hour, PprofHTTPServerListener: "localhost:6060", } diff --git a/node/node.go b/node/node.go index 368e96c814..86b9144705 100644 --- a/node/node.go +++ b/node/node.go @@ -1901,6 +1901,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { sql.WithConnections(app.Config.DatabaseConnections), sql.WithLatencyMetering(app.Config.DatabaseLatencyMetering), sql.WithVacuumState(app.Config.DatabaseVacuumState), + sql.WithIgnoreTableRx(app.Config.DatabaseIgnoreTableRx), sql.WithQueryCache(app.Config.DatabaseQueryCache), sql.WithQueryCacheSizes(map[sql.QueryCacheKind]int{ atxs.CacheKindEpochATXs: app.Config.DatabaseQueryCacheSizes.EpochATXs, diff --git a/sql/database.go b/sql/database.go index 8de21300d6..5d64e38b7a 100644 --- a/sql/database.go +++ b/sql/database.go @@ -81,6 +81,7 @@ type conf struct { cacheSizes map[QueryCacheKind]int logger *zap.Logger schema *Schema + ignoreTableRx string } // WithConnections overwrites number of pooled connections. @@ -154,6 +155,14 @@ func WithDatabaseSchema(schema *Schema) Opt { } } +// WithIgnoreTableRx specifies regular expression for table names to ignore during schema +// drift detection. +func WithIgnoreTableRx(rx string) Opt { + return func(c *conf) { + c.ignoreTableRx = rx + } +} + func withForceFresh(fresh bool) Opt { return func(c *conf) { c.forceFresh = fresh @@ -222,7 +231,7 @@ func Open(uri string, opts ...Opt) (*Database, error) { } } - loaded, err := LoadDBSchemaScript(db) + loaded, err := LoadDBSchemaScript(db, config.ignoreTableRx) if err != nil { return nil, fmt.Errorf("error loading database schema: %w", err) } diff --git a/sql/database_test.go b/sql/database_test.go index 24d5d03e72..33f93eb580 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -295,6 +295,7 @@ func TestSchemaDrift_IgnoredTables(t *testing.T) { db, err := Open("file:"+dbFile, WithDatabaseSchema(schema), WithLogger(logger), + WithIgnoreTableRx("^_litestream"), ) require.NoError(t, err) @@ -307,6 +308,7 @@ func TestSchemaDrift_IgnoredTables(t *testing.T) { db, err = Open("file:"+dbFile, WithDatabaseSchema(schema), WithLogger(logger), + WithIgnoreTableRx("^_litestream"), ) require.NoError(t, err) require.NoError(t, db.Close()) diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go index a6dbf958bb..d999c4d9aa 100644 --- a/sql/localsql/localsql_test.go +++ b/sql/localsql/localsql_test.go @@ -15,13 +15,13 @@ func TestDatabase_MigrateTwice_NoOp(t *testing.T) { db, err := Open("file:"+file, sql.WithForceMigrations(true)) require.NoError(t, err) - sql1, err := sql.LoadDBSchemaScript(db) + sql1, err := sql.LoadDBSchemaScript(db, "") require.NoError(t, err) require.NoError(t, db.Close()) db, err = Open("file:" + file) require.NoError(t, err) - sql2, err := sql.LoadDBSchemaScript(db) + sql2, err := sql.LoadDBSchemaScript(db, "") require.NoError(t, err) require.Equal(t, sql1, sql2) @@ -56,7 +56,7 @@ func TestSchema(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { db := InMemory(sql.WithForceMigrations(tc.forceMigrations)) - loadedScript, err := sql.LoadDBSchemaScript(db) + loadedScript, err := sql.LoadDBSchemaScript(db, "") require.NoError(t, err) expSchema, err := Schema() require.NoError(t, err) diff --git a/sql/schema.go b/sql/schema.go index 4f495c9eb7..6e2cb8a68e 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -9,6 +9,7 @@ import ( "io/fs" "os" "path/filepath" + "regexp" "strings" godiffpatch "github.com/sourcegraph/go-diff-patch" @@ -30,8 +31,19 @@ func LoadSchema(fsys fs.FS, migrations []Migration) (*Schema, error) { } // LoadDBSchemaScript retrieves the database schema as text. -func LoadDBSchemaScript(db Executor) (string, error) { - var sb strings.Builder +func LoadDBSchemaScript(db Executor, ignoreRx string) (string, error) { + var ( + err error + ignRx *regexp.Regexp + sb strings.Builder + ) + if ignoreRx != "" { + ignRx, err = regexp.Compile(ignoreRx) + if err != nil { + return "", fmt.Errorf("error compiling table ignore regexp %q: %w", + ignoreRx, err) + } + } version, err := version(db) if err != nil { return "", err @@ -40,14 +52,16 @@ func LoadDBSchemaScript(db Executor) (string, error) { if _, err = db.Exec( // Type is either 'index' or 'table', we want tables // to go first. Also, we ignore _litestream tables - `select sql || ';' from sqlite_master - where sql is not null and tbl_name not like '_litestream%' + `select tbl_name, sql || ';' from sqlite_master + where sql is not null order by tbl_name, type desc, name`, nil, func(st *Statement) bool { - fmt.Fprintln(&sb, st.ColumnText(0)) + if ignRx == nil || !ignRx.MatchString(st.ColumnText(0)) { + fmt.Fprintln(&sb, st.ColumnText(1)) + } return true }); err != nil { - return "", err + return "", fmt.Errorf("error retrieving DB schema: %w", err) } return sb.String(), nil } diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 23e7d9c9f3..0f34abc6bb 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -54,4 +54,3 @@ func InMemory(opts ...sql.Opt) *Database { } // TBD: QQQQQ: add sql/test package with test skeletons -// TBD: QQQQQ: instead of "not like '_litestream%'", use regex in the config diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go index d13a76671e..e975b16923 100644 --- a/sql/statesql/statesql_test.go +++ b/sql/statesql/statesql_test.go @@ -15,13 +15,13 @@ func TestDatabase_MigrateTwice_NoOp(t *testing.T) { db, err := Open("file:"+file, sql.WithForceMigrations(true)) require.NoError(t, err) - sql1, err := sql.LoadDBSchemaScript(db) + sql1, err := sql.LoadDBSchemaScript(db, "") require.NoError(t, err) require.NoError(t, db.Close()) db, err = Open("file:" + file) require.NoError(t, err) - sql2, err := sql.LoadDBSchemaScript(db) + sql2, err := sql.LoadDBSchemaScript(db, "") require.NoError(t, err) require.Equal(t, sql1, sql2) @@ -56,7 +56,7 @@ func TestSchema(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { db := InMemory(sql.WithForceMigrations(tc.forceMigrations)) - loadedScript, err := sql.LoadDBSchemaScript(db) + loadedScript, err := sql.LoadDBSchemaScript(db, "") require.NoError(t, err) expSchema, err := Schema() require.NoError(t, err) From ec3397ee79902f0f4bef300d01360851ca21069c Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 1 Jun 2024 19:11:41 +0400 Subject: [PATCH 05/62] sql: refactor localsql / statesql tests --- sql/localsql/localsql_test.go | 67 +++------------------------- sql/statesql/statesql.go | 2 - sql/statesql/statesql_test.go | 65 +++------------------------ sql/test/test.go | 83 +++++++++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 124 deletions(-) create mode 100644 sql/test/test.go diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go index d999c4d9aa..67b18fdf5d 100644 --- a/sql/localsql/localsql_test.go +++ b/sql/localsql/localsql_test.go @@ -1,74 +1,17 @@ package localsql import ( - "path/filepath" - "slices" "testing" - "github.com/stretchr/testify/require" - - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/test" ) -func TestDatabase_MigrateTwice_NoOp(t *testing.T) { - file := filepath.Join(t.TempDir(), "test.db") - db, err := Open("file:"+file, sql.WithForceMigrations(true)) - require.NoError(t, err) - - sql1, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) - require.NoError(t, db.Close()) - - db, err = Open("file:" + file) - require.NoError(t, err) - sql2, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) +var fns = test.DBFuncs[*Database]{Schema: Schema, Open: Open, InMemory: InMemory} - require.Equal(t, sql1, sql2) - - var version int - _, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool { - version = stmt.ColumnInt(0) - return true - }) - require.NoError(t, err) - - require.NoError(t, err) - schema, err := Schema() - require.NoError(t, err) - expectedVersion := slices.MaxFunc( - []sql.Migration(schema.Migrations), - func(a, b sql.Migration) int { - return a.Order() - b.Order() - }) - require.Equal(t, expectedVersion.Order(), version) - - require.NoError(t, db.Close()) +func TestDatabase_MigrateTwice_NoOp(t *testing.T) { + test.VerifyMigrateTwiceNoOp(t, fns) } func TestSchema(t *testing.T) { - for _, tc := range []struct { - name string - forceMigrations bool - }{ - {name: "no migrations", forceMigrations: false}, - {name: "force migrations", forceMigrations: true}, - } { - t.Run(tc.name, func(t *testing.T) { - db := InMemory(sql.WithForceMigrations(tc.forceMigrations)) - loadedScript, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) - expSchema, err := Schema() - require.NoError(t, err) - diff := expSchema.Diff(loadedScript) - if diff != "" { - s := &sql.Schema{ - Script: loadedScript, - } - require.NoError(t, s.WriteToFile(".")) - t.Logf("updated schema written to %s", sql.UpdatedSchemaPath) - } - require.Empty(t, diff, "local schema diff") - }) - } + test.VerifySchema(t, fns) } diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 0f34abc6bb..137b9b5016 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -52,5 +52,3 @@ func InMemory(opts ...sql.Opt) *Database { db := sql.InMemory(opts...) return &Database{Database: db} } - -// TBD: QQQQQ: add sql/test package with test skeletons diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go index e975b16923..cb18c78eb5 100644 --- a/sql/statesql/statesql_test.go +++ b/sql/statesql/statesql_test.go @@ -1,72 +1,17 @@ package statesql import ( - "path/filepath" - "slices" "testing" - "github.com/stretchr/testify/require" - - "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/test" ) -func TestDatabase_MigrateTwice_NoOp(t *testing.T) { - file := filepath.Join(t.TempDir(), "test.db") - db, err := Open("file:"+file, sql.WithForceMigrations(true)) - require.NoError(t, err) - - sql1, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) - require.NoError(t, db.Close()) - - db, err = Open("file:" + file) - require.NoError(t, err) - sql2, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) +var fns = test.DBFuncs[*Database]{Schema: Schema, Open: Open, InMemory: InMemory} - require.Equal(t, sql1, sql2) - - var version int - _, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool { - version = stmt.ColumnInt(0) - return true - }) - require.NoError(t, err) - - require.NoError(t, err) - schema, err := Schema() - require.NoError(t, err) - expectedVersion := slices.MaxFunc( - []sql.Migration(schema.Migrations), - func(a, b sql.Migration) int { - return a.Order() - b.Order() - }) - require.Equal(t, expectedVersion.Order(), version) - - require.NoError(t, db.Close()) +func TestDatabase_MigrateTwice_NoOp(t *testing.T) { + test.VerifyMigrateTwiceNoOp(t, fns) } func TestSchema(t *testing.T) { - for _, tc := range []struct { - name string - forceMigrations bool - }{ - {name: "no migrations", forceMigrations: false}, - {name: "force migrations", forceMigrations: true}, - } { - t.Run(tc.name, func(t *testing.T) { - db := InMemory(sql.WithForceMigrations(tc.forceMigrations)) - loadedScript, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) - expSchema, err := Schema() - require.NoError(t, err) - diff := expSchema.Diff(loadedScript) - if diff != "" { - s := &sql.Schema{Script: loadedScript} - require.NoError(t, s.WriteToFile(".")) - t.Logf("updated schema written to %s", sql.UpdatedSchemaPath) - } - require.Empty(t, diff, "schema diff") - }) - } + test.VerifySchema(t, fns) } diff --git a/sql/test/test.go b/sql/test/test.go new file mode 100644 index 0000000000..42b9368992 --- /dev/null +++ b/sql/test/test.go @@ -0,0 +1,83 @@ +package test + +import ( + "path/filepath" + "slices" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +type Database interface { + sql.Executor + Close() error +} + +type DBFuncs[DB Database] struct { + Schema func() (*sql.Schema, error) + Open func(uri string, opts ...sql.Opt) (DB, error) + InMemory func(opts ...sql.Opt) DB +} + +func VerifyMigrateTwiceNoOp[DB Database](t *testing.T, funcs DBFuncs[DB]) { + file := filepath.Join(t.TempDir(), "test.db") + db, err := funcs.Open("file:"+file, sql.WithForceMigrations(true)) + require.NoError(t, err) + + sql1, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + require.NoError(t, db.Close()) + + db, err = funcs.Open("file:" + file) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + + require.Equal(t, sql1, sql2) + + var version int + _, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool { + version = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + require.NoError(t, err) + schema, err := funcs.Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), version) + + require.NoError(t, db.Close()) +} + +func VerifySchema[DB Database](t *testing.T, funcs DBFuncs[DB]) { + for _, tc := range []struct { + name string + forceMigrations bool + }{ + {name: "no migrations", forceMigrations: false}, + {name: "force migrations", forceMigrations: true}, + } { + t.Run(tc.name, func(t *testing.T) { + db := funcs.InMemory(sql.WithForceMigrations(tc.forceMigrations)) + loadedScript, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + expSchema, err := funcs.Schema() + require.NoError(t, err) + diff := expSchema.Diff(loadedScript) + if diff != "" { + s := &sql.Schema{Script: loadedScript} + require.NoError(t, s.WriteToFile(".")) + t.Logf("updated schema written to %s", sql.UpdatedSchemaPath) + } + require.Empty(t, diff, "schema diff") + }) + } +} From 1af47bb335c32a51fa1f7e1fa425e1db45f9c24e Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 1 Jun 2024 19:58:02 +0400 Subject: [PATCH 06/62] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 625600e182..17430c652d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,9 @@ Upgrading to this version requires going through v1.5.x first. Removed migration ATXs. This vulnerability allows an attacker to claim rewards for a full tick amount although they should not be eligible for them. +* [#6003](https://github.com/spacemeshos/go-spacemesh/pull/6003) Improve database schema handling. + This includes schema drift detection which may happen after running unreleased versions. + ## Release v1.5.7 ### Improvements From 7bb2b1bd152fe91c46de725505439e6382aa827b Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 1 Jun 2024 20:16:57 +0400 Subject: [PATCH 07/62] config: fix db presets for db schema drift detection --- config/mainnet.go | 1 + config/presets/testnet.go | 1 + 2 files changed, 2 insertions(+) diff --git a/config/mainnet.go b/config/mainnet.go index 5c98f0248c..569d993d68 100644 --- a/config/mainnet.go +++ b/config/mainnet.go @@ -73,6 +73,7 @@ func MainnetConfig() Config { DatabaseConnections: 16, DatabasePruneInterval: 30 * time.Minute, DatabaseVacuumState: 15, + DatabaseIgnoreTableRx: "^_litestream", PruneActivesetsFrom: 12, // starting from epoch 13 activesets below 12 will be pruned ScanMalfeasantATXs: false, // opt-in NetworkHRP: "sm", diff --git a/config/presets/testnet.go b/config/presets/testnet.go index 892d924aaf..c6f6f29a2e 100644 --- a/config/presets/testnet.go +++ b/config/presets/testnet.go @@ -65,6 +65,7 @@ func testnet() config.Config { DatabaseConnections: 16, DatabaseSizeMeteringInterval: 10 * time.Minute, DatabasePruneInterval: 30 * time.Minute, + DatabaseIgnoreTableRx: "^_litestream", NetworkHRP: "stest", LayerDuration: 5 * time.Minute, From bc982404cc73059c9ac75b9638f49e714dad790d Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 5 Jun 2024 07:45:11 +0400 Subject: [PATCH 08/62] sql: fix review comments --- cmd/merge-nodes/internal/merge_action.go | 2 +- cmd/merge-nodes/internal/merge_action_test.go | 4 +- config/config.go | 10 +- node/node.go | 2 +- sql/database.go | 16 +- sql/database_test.go | 8 +- sql/localsql/localsql.go | 11 +- sql/localsql/localsql_test.go | 11 +- .../migrations/0009_malfeasance_sync_pk.sql | 12 ++ sql/localsql/schema/schema.sql | 3 +- sql/malsync/malsync.go | 42 ++---- sql/schema.go | 26 +--- sql/statesql/statesql.go | 13 +- sql/statesql/statesql_test.go | 11 +- sql/test/test.go | 141 +++++++++++------- 15 files changed, 164 insertions(+), 148 deletions(-) create mode 100644 sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql diff --git a/cmd/merge-nodes/internal/merge_action.go b/cmd/merge-nodes/internal/merge_action.go index 58c7586b59..241c5bb0a9 100644 --- a/cmd/merge-nodes/internal/merge_action.go +++ b/cmd/merge-nodes/internal/merge_action.go @@ -191,7 +191,7 @@ func openDB(dbLog *zap.Logger, path string) (*localsql.Database, error) { db, err := localsql.Open("file:"+dbPath, sql.WithLogger(dbLog), - sql.WithEnableMigrations(false), + sql.WithMigrationsDisabled(), ) if err != nil { return nil, fmt.Errorf("open source database %s: %w", dbPath, err) diff --git a/cmd/merge-nodes/internal/merge_action_test.go b/cmd/merge-nodes/internal/merge_action_test.go index 3315821c8c..3548b64c31 100644 --- a/cmd/merge-nodes/internal/merge_action_test.go +++ b/cmd/merge-nodes/internal/merge_action_test.go @@ -42,7 +42,7 @@ func Test_MergeDBs_InvalidTargetScheme(t *testing.T) { require.NoError(t, db.Close()) err = MergeDBs(context.Background(), zaptest.NewLogger(t), "", tmpDst) - require.ErrorIs(t, err, sql.ErrOld) + require.ErrorIs(t, err, sql.ErrOldSchema) require.ErrorContains(t, err, "target database") } @@ -100,7 +100,7 @@ func Test_MergeDBs_InvalidSourceScheme(t *testing.T) { require.NoError(t, db.Close()) err = MergeDBs(context.Background(), zaptest.NewLogger(t), tmpSrc, tmpDst) - require.ErrorIs(t, err, sql.ErrOld) + require.ErrorIs(t, err, sql.ErrOldSchema) require.ErrorContains(t, err, "source database") } diff --git a/config/config.go b/config/config.go index 3342c20e86..f389fc82a6 100644 --- a/config/config.go +++ b/config/config.go @@ -117,7 +117,7 @@ type BaseConfig struct { DatabaseSkipMigrations []int `mapstructure:"db-skip-migrations"` DatabaseQueryCache bool `mapstructure:"db-query-cache"` DatabaseQueryCacheSizes DatabaseQueryCacheSizes `mapstructure:"db-query-cache-sizes"` - DatabaseIgnoreTableRx string `mapstructure:"db-ignore-table-rx"` + DatabaseSchemaIgnoreRx string `mapstructure:"db-ignore-schema-rx"` PruneActivesetsFrom types.EpochID `mapstructure:"prune-activesets-from"` @@ -246,10 +246,10 @@ func defaultBaseConfig() BaseConfig { ATXBlob: 10000, ActiveSetBlob: 200, }, - DatabaseIgnoreTableRx: "^_litestream", - NetworkHRP: "sm", - ATXGradeDelay: 10 * time.Second, - PostValidDelay: 12 * time.Hour, + DatabaseSchemaIgnoreRx: "^_litestream", + NetworkHRP: "sm", + ATXGradeDelay: 10 * time.Second, + PostValidDelay: 12 * time.Hour, PprofHTTPServerListener: "localhost:6060", } diff --git a/node/node.go b/node/node.go index 86b9144705..685cccdeeb 100644 --- a/node/node.go +++ b/node/node.go @@ -1901,7 +1901,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { sql.WithConnections(app.Config.DatabaseConnections), sql.WithLatencyMetering(app.Config.DatabaseLatencyMetering), sql.WithVacuumState(app.Config.DatabaseVacuumState), - sql.WithIgnoreTableRx(app.Config.DatabaseIgnoreTableRx), + sql.WithIgnoreTableRx(app.Config.DatabaseSchemaIgnoreRx), sql.WithQueryCache(app.Config.DatabaseQueryCache), sql.WithQueryCacheSizes(map[sql.QueryCacheKind]int{ atxs.CacheKindEpochATXs: app.Config.DatabaseQueryCacheSizes.EpochATXs, diff --git a/sql/database.go b/sql/database.go index 5d64e38b7a..746b8c0e13 100644 --- a/sql/database.go +++ b/sql/database.go @@ -28,9 +28,9 @@ var ( ErrObjectExists = errors.New("database: object exists") // ErrTooNew is returned if database version is newer than expected. ErrTooNew = errors.New("database version is too new") - // ErrOld is returned when the database version differs from the expected one + // ErrOldSchema is returned when the database version differs from the expected one // and migrations are disabled. - ErrOld = errors.New("old database version") + ErrOldSchema = errors.New("old database version") ) const ( @@ -98,11 +98,11 @@ func WithLogger(logger *zap.Logger) Opt { } } -// WithEnableMigrations enables or disables migrations on the database. +// WithMigrationsDisabled disables migrations for the database. // The migrations are enabled by default. -func WithEnableMigrations(enable bool) Opt { +func WithMigrationsDisabled() Opt { return func(c *conf) { - c.enableMigrations = enable + c.enableMigrations = false } } @@ -163,9 +163,9 @@ func WithIgnoreTableRx(rx string) Opt { } } -func withForceFresh(fresh bool) Opt { +func withForceFresh() Opt { return func(c *conf) { - c.forceFresh = fresh + c.forceFresh = true } } @@ -174,7 +174,7 @@ type Opt func(c *conf) // InMemory database for testing. func InMemory(opts ...Opt) *Database { - opts = append(opts, WithConnections(1), withForceFresh(true)) + opts = append(opts, WithConnections(1), withForceFresh()) db, err := Open("file::memory:?mode=memory", opts...) if err != nil { panic(err) diff --git a/sql/database_test.go b/sql/database_test.go index 33f93eb580..e1e6332401 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -131,9 +131,9 @@ func Test_Migration_Disabled(t *testing.T) { WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1, migration2}, }), - WithEnableMigrations(false), + WithMigrationsDisabled(), ) - require.ErrorIs(t, err, ErrOld) + require.ErrorIs(t, err, ErrOldSchema) } func TestDatabaseSkipMigrations(t *testing.T) { @@ -271,8 +271,8 @@ func TestSchemaDrift(t *testing.T) { require.NoError(t, err) require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") require.Equal(t, "database schema drift detected", observedLogs.All()[0].Message) - require.Contains(t, observedLogs.All()[0].ContextMap()["diff"], - "+CREATE TABLE newtbl (id int);") + require.Regexp(t, `.*\n\s*\+\s*CREATE TABLE newtbl \(id int\);`, + observedLogs.All()[0].ContextMap()["diff"]) } func TestSchemaDrift_IgnoredTables(t *testing.T) { diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go index 88f41f2a17..012e66e73c 100644 --- a/sql/localsql/localsql.go +++ b/sql/localsql/localsql.go @@ -6,8 +6,11 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" ) -//go:embed schema/schema.sql schema/migrations/*.sql -var embedded embed.FS +//go:embed schema/schema.sql +var schemaScript string + +//go:embed schema/migrations/*.sql +var migrations embed.FS // Database represents a local database. type Database struct { @@ -16,13 +19,13 @@ type Database struct { // Schema returns the schema for the local database. func Schema() (*sql.Schema, error) { - migrations, err := sql.LoadSQLMigrations(embedded) + sqlMigrations, err := sql.LoadSQLMigrations(migrations) if err != nil { return nil, err } // NOTE: coded state migrations can be added here // They can be a part of this localsql package - return sql.LoadSchema(embedded, migrations) + return &sql.Schema{Script: schemaScript, Migrations: sqlMigrations}, nil } // Open opens a local database. diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go index 67b18fdf5d..c6ce6bf02f 100644 --- a/sql/localsql/localsql_test.go +++ b/sql/localsql/localsql_test.go @@ -3,15 +3,16 @@ package localsql import ( "testing" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/test" ) -var fns = test.DBFuncs[*Database]{Schema: Schema, Open: Open, InMemory: InMemory} +type fns struct{} -func TestDatabase_MigrateTwice_NoOp(t *testing.T) { - test.VerifyMigrateTwiceNoOp(t, fns) -} +func (fns) Schema() (*sql.Schema, error) { return Schema() } +func (fns) InMemory(opts ...sql.Opt) *Database { return InMemory(opts...) } +func (fns) Open(uri string, opts ...sql.Opt) (*Database, error) { return Open(uri, opts...) } func TestSchema(t *testing.T) { - test.VerifySchema(t, fns) + test.RunSchemaTests(t, fns{}) } diff --git a/sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql b/sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql new file mode 100644 index 0000000000..d0ade1f535 --- /dev/null +++ b/sql/localsql/schema/migrations/0009_malfeasance_sync_pk.sql @@ -0,0 +1,12 @@ +ALTER TABLE malfeasance_sync_state RENAME TO malfeasance_sync_state_old; + +CREATE TABLE malfeasance_sync_state +( + id INT NOT NULL PRIMARY KEY, + timestamp INT NOT NULL +); + +INSERT INTO malfeasance_sync_state (id, timestamp) +SELECT 1, timestamp FROM malfeasance_sync_state_old LIMIT 1; + +DROP TABLE malfeasance_sync_state_old; diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql index bcc8ca05c8..02c44d3ccb 100755 --- a/sql/localsql/schema/schema.sql +++ b/sql/localsql/schema/schema.sql @@ -1,4 +1,4 @@ -PRAGMA user_version = 8; +PRAGMA user_version = 9; CREATE TABLE atx_sync_requests ( epoch INT NOT NULL, @@ -26,6 +26,7 @@ CREATE TABLE "challenge" , poet_proof_ref CHAR(32), poet_proof_membership VARCHAR) WITHOUT ROWID; CREATE TABLE malfeasance_sync_state ( + id INT NOT NULL PRIMARY KEY, timestamp INT NOT NULL ); CREATE TABLE nipost diff --git a/sql/malsync/malsync.go b/sql/malsync/malsync.go index 048ff43d90..fac8b50fba 100644 --- a/sql/malsync/malsync.go +++ b/sql/malsync/malsync.go @@ -8,13 +8,8 @@ import ( ) func GetSyncState(db sql.Executor) (time.Time, error) { - timestamp, _, err := getSyncState(db) - return timestamp, err -} - -func getSyncState(db sql.Executor) (time.Time, bool, error) { var timestamp time.Time - rows, err := db.Exec("select timestamp from malfeasance_sync_state", + rows, err := db.Exec("select timestamp from malfeasance_sync_state where id = 1", nil, func(stmt *sql.Statement) bool { v := stmt.ColumnInt64(0) if v > 0 { @@ -24,34 +19,23 @@ func getSyncState(db sql.Executor) (time.Time, bool, error) { }) switch { case err != nil: - return time.Time{}, false, fmt.Errorf("error getting malfeasance sync state: %w", err) - case rows == 0: - return timestamp, false, nil - case rows == 1: - return timestamp, true, nil + return time.Time{}, fmt.Errorf("error getting malfeasance sync state: %w", err) + case rows <= 1: + return timestamp, nil default: - return time.Time{}, false, fmt.Errorf("expected malfeasance_sync_state to have 1 row but got %d rows", rows) + return time.Time{}, fmt.Errorf("expected malfeasance_sync_state to have 1 row but got %d rows", rows) } } func updateSyncState(db sql.Executor, ts int64) error { - _, haveTS, err := getSyncState(db) - if err != nil { - return err - } - enc := func(stmt *sql.Statement) { - stmt.BindInt64(1, ts) - } - if haveTS { - _, err := db.Exec("update malfeasance_sync_state set timestamp = ?1", enc, nil) - if err != nil { - return fmt.Errorf("error updating malfeasance sync state: %w", err) - } - } else { - _, err = db.Exec("insert into malfeasance_sync_state (timestamp) values(?1)", enc, nil) - if err != nil { - return fmt.Errorf("error initializing malfeasance sync state: %w", err) - } + if _, err := db.Exec( + `insert into malfeasance_sync_state (id, timestamp) values(1, ?1) + on conflict (id) do update set timestamp = ?1`, + func(stmt *sql.Statement) { + stmt.BindInt64(1, ts) + }, nil, + ); err != nil { + return fmt.Errorf("error initializing malfeasance sync state: %w", err) } return nil } diff --git a/sql/schema.go b/sql/schema.go index 6e2cb8a68e..82fc2badb7 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -6,13 +6,12 @@ import ( "context" "errors" "fmt" - "io/fs" "os" "path/filepath" "regexp" "strings" - godiffpatch "github.com/sourcegraph/go-diff-patch" + "github.com/google/go-cmp/cmp" "go.uber.org/zap" ) @@ -21,15 +20,6 @@ const ( UpdatedSchemaPath = "schema/schema.sql.updated" ) -// LoadSchema loads the schema embedded in the executable. -func LoadSchema(fsys fs.FS, migrations []Migration) (*Schema, error) { - text, err := fs.ReadFile(fsys, SchemaPath) - if err != nil { - return nil, fmt.Errorf("error reading schema file %s: %w", SchemaPath, err) - } - return &Schema{Script: string(text), Migrations: migrations}, nil -} - // LoadDBSchemaScript retrieves the database schema as text. func LoadDBSchemaScript(db Executor, ignoreRx string) (string, error) { var ( @@ -50,8 +40,7 @@ func LoadDBSchemaScript(db Executor, ignoreRx string) (string, error) { } fmt.Fprintf(&sb, "PRAGMA user_version = %d;\n", version) if _, err = db.Exec( - // Type is either 'index' or 'table', we want tables - // to go first. Also, we ignore _litestream tables + // Type is either 'index' or 'table', we want tables to go first `select tbl_name, sql || ';' from sqlite_master where sql is not null order by tbl_name, type desc, name`, @@ -76,14 +65,7 @@ type Schema struct { // Diff diffs the database schema against the actual schema. // If there's no differences, it returns an empty string. func (s *Schema) Diff(actualScript string) string { - if s.Script == actualScript { - return "" - } - diff := godiffpatch.GeneratePatch(SchemaPath, s.Script, actualScript) - if diff == "" { - return "" - } - return diff + return cmp.Diff(s.Script, actualScript) } // WriteToFile writes the schema to the corresponding updated schema file. @@ -156,7 +138,7 @@ func (s *Schema) Migrate(logger *zap.Logger, db *Database, vacuumState int, enab zap.Int("current version", before), zap.Int("target version", after), ) - return fmt.Errorf("%w: %d < %d", ErrOld, before, after) + return fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after) } logger.Info("running migrations", diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 137b9b5016..c33d274f38 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -6,8 +6,11 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" ) -//go:embed schema/schema.sql schema/migrations/*.sql -var embedded embed.FS +//go:embed schema/schema.sql +var schemaScript string + +//go:embed schema/migrations/*.sql +var migrations embed.FS // Database represents a state database. type Database struct { @@ -16,13 +19,13 @@ type Database struct { // Schema returns the schema for the state database. func Schema() (*sql.Schema, error) { - migrations, err := sql.LoadSQLMigrations(embedded) + sqlMigrations, err := sql.LoadSQLMigrations(migrations) if err != nil { return nil, err } // NOTE: coded state migrations can be added here - // They can be a part of this statesql package - return sql.LoadSchema(embedded, migrations) + // They can be a part of this localsql package + return &sql.Schema{Script: schemaScript, Migrations: sqlMigrations}, nil } // Open opens a state database. diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go index cb18c78eb5..79ea387c8d 100644 --- a/sql/statesql/statesql_test.go +++ b/sql/statesql/statesql_test.go @@ -3,15 +3,16 @@ package statesql import ( "testing" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/test" ) -var fns = test.DBFuncs[*Database]{Schema: Schema, Open: Open, InMemory: InMemory} +type fns struct{} -func TestDatabase_MigrateTwice_NoOp(t *testing.T) { - test.VerifyMigrateTwiceNoOp(t, fns) -} +func (fns) Schema() (*sql.Schema, error) { return Schema() } +func (fns) InMemory(opts ...sql.Opt) *Database { return InMemory(opts...) } +func (fns) Open(uri string, opts ...sql.Opt) (*Database, error) { return Open(uri, opts...) } func TestSchema(t *testing.T) { - test.VerifySchema(t, fns) + test.RunSchemaTests(t, fns{}) } diff --git a/sql/test/test.go b/sql/test/test.go index 42b9368992..a70397bfb4 100644 --- a/sql/test/test.go +++ b/sql/test/test.go @@ -6,6 +6,10 @@ import ( "testing" "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" "github.com/spacemeshos/go-spacemesh/sql" ) @@ -15,69 +19,94 @@ type Database interface { Close() error } -type DBFuncs[DB Database] struct { - Schema func() (*sql.Schema, error) - Open func(uri string, opts ...sql.Opt) (DB, error) - InMemory func(opts ...sql.Opt) DB +type DBFuncs[DB Database] interface { + Schema() (*sql.Schema, error) + Open(uri string, opts ...sql.Opt) (DB, error) + InMemory(opts ...sql.Opt) DB } -func VerifyMigrateTwiceNoOp[DB Database](t *testing.T, funcs DBFuncs[DB]) { - file := filepath.Join(t.TempDir(), "test.db") - db, err := funcs.Open("file:"+file, sql.WithForceMigrations(true)) - require.NoError(t, err) +func RunSchemaTests[DB Database](t *testing.T, funcs DBFuncs[DB]) { + t.Run("idempotent migration", func(t *testing.T) { + observer, observedLogs := observer.New(zapcore.InfoLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) - sql1, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) - require.NoError(t, db.Close()) + file := filepath.Join(t.TempDir(), "test.db") + db, err := funcs.Open("file:"+file, sql.WithForceMigrations(true), sql.WithLogger(logger)) + require.NoError(t, err) - db, err = funcs.Open("file:" + file) - require.NoError(t, err) - sql2, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) + var versionA int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionA = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) - require.Equal(t, sql1, sql2) + sql1, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + require.NoError(t, db.Close()) - var version int - _, err = db.Exec("PRAGMA user_version;", nil, func(stmt *sql.Statement) bool { - version = stmt.ColumnInt(0) - return true - }) - require.NoError(t, err) - - require.NoError(t, err) - schema, err := funcs.Schema() - require.NoError(t, err) - expectedVersion := slices.MaxFunc( - []sql.Migration(schema.Migrations), - func(a, b sql.Migration) int { - return a.Order() - b.Order() - }) - require.Equal(t, expectedVersion.Order(), version) + require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") + l := observedLogs.All()[0] + require.Equal(t, "running migrations", l.Message) + require.Equal(t, int64(0), l.ContextMap()["current version"]) + require.Equal(t, int64(versionA), l.ContextMap()["target version"]) - require.NoError(t, db.Close()) -} + db, err = funcs.Open("file:"+file, sql.WithLogger(logger)) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + + require.Equal(t, sql1, sql2) -func VerifySchema[DB Database](t *testing.T, funcs DBFuncs[DB]) { - for _, tc := range []struct { - name string - forceMigrations bool - }{ - {name: "no migrations", forceMigrations: false}, - {name: "force migrations", forceMigrations: true}, - } { - t.Run(tc.name, func(t *testing.T) { - db := funcs.InMemory(sql.WithForceMigrations(tc.forceMigrations)) - loadedScript, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) - expSchema, err := funcs.Schema() - require.NoError(t, err) - diff := expSchema.Diff(loadedScript) - if diff != "" { - s := &sql.Schema{Script: loadedScript} - require.NoError(t, s.WriteToFile(".")) - t.Logf("updated schema written to %s", sql.UpdatedSchemaPath) - } - require.Empty(t, diff, "schema diff") + var versionB int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionB = stmt.ColumnInt(0) + return true }) - } + require.NoError(t, err) + + require.NoError(t, err) + schema, err := funcs.Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), versionA) + require.Equal(t, expectedVersion.Order(), versionB) + + require.NoError(t, db.Close()) + // make sure there's no schema drift warnings in the logs + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") + }) + + t.Run("schema", func(t *testing.T) { + for _, tc := range []struct { + name string + forceMigrations bool + }{ + {name: "no migrations", forceMigrations: false}, + {name: "force migrations", forceMigrations: true}, + } { + t.Run(tc.name, func(t *testing.T) { + db := funcs.InMemory(sql.WithForceMigrations(tc.forceMigrations)) + loadedScript, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + expSchema, err := funcs.Schema() + require.NoError(t, err) + diff := expSchema.Diff(loadedScript) + if diff != "" { + s := &sql.Schema{Script: loadedScript} + require.NoError(t, s.WriteToFile(".")) + t.Logf("updated schema written to %s", sql.UpdatedSchemaPath) + } + require.Empty(t, diff, "schema diff") + }) + } + }) } From 895f50f0cbdfe1ec31e6a75a47999e0322a734ae Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 5 Jun 2024 07:45:19 +0400 Subject: [PATCH 09/62] sql: add database schema handling docs --- README.md | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/README.md b/README.md index 0c0182c416..27b9663159 100644 --- a/README.md +++ b/README.md @@ -513,6 +513,86 @@ $ grpcurl -plaintext 127.0.0.1:9093 spacemesh.v1.DebugService.NetworkInfo } ``` +#### Handling database schema changes + +go-spacemesh currently maintains 2 SQLite databases in the data folder: `state.sql` (state database) and `local.sql` (local database). It employs schema versioning for both databases, with a possibility to upgrade older versions of each database to the current schema version by means of running a series of migrations. Also, go-spacemesh tracks any schema drift (unexpected schema changes) in the databases. + +When a database is first created, the corresponding schema file embedded in go-spacemesh executable is used to initialize it: +* `sql/statesql/schema/schema.sql` for `state.sql` +* `sql/localsql/schema/schema.sql` for `local.sql` +The schema file includes `PRAGMA user_version = ...` which sets the version of the database schema. The version of the schema is equal to the number of migrations defined for the corresponding database (`state.sql` or `local.sql`). + +For an existing database, the `PRAGMA user_version` is checked against the expected version number. If the database's schema version is too new, go-spacemesh fails right away as an older go-spacemesh version cannot be expected to work with a database from a newer version. If the database version number is older than the expected version, go-spacemesh runs the necessary migration scripts embedded in go-spacemesh executable and updates `PRAGMA user_version = ...`. The migration scripts are located in the following folders: +* `sql/statesql/schema/migrations` for `state.sql` +* `sql/localsql/schema/migrations` for `local.sql` + +Additionally, some migrations ("coded migrations") can be implemented in Go code, in which case they reside in `.go` files located in `sql/statesql` and `sql/localsql` packages, respectively. It is worth noting that old coded migrations can be removed at some point, rendering database versions that are *too* old unusable with newer go-spacemesh versions. + +After all the migrations are run, go-spacemesh compares the schema of each database to the embedded schema scripts and if they differ, warns the user about any differences: +``` +logger.go:146: 2024-06-05T05:39:32.247+0400 WARN database schema drift detected {"uri": "file:/var/folders/r0/4mks2v4n5ysbntnf3xq6h_q80000gn/T/TestSchemaidempotent_migration3425594786/001/test.db", "diff": " (\n \t\"\"\"\n \t... // 81 identical lines\n \t PRIMARY KEY (kind, epoch)\n \t) WITHOUT ROWID;\n- \t\n- \t-- some change\n \t\"\"\"\n )\n"} +``` + +In this case, an empty line and `-- some change` was added to `schema.sql` by hand. The pretty-printed diff looks like this: +``` + ( + """ + ... // 81 identical lines + PRIMARY KEY (kind, epoch) + ) WITHOUT ROWID; +- +- -- some change + """ + ) +``` + +The possible reasons for schema drift can be the following: +* running an unreleased version of go-spacemesh using your data folder. The unreleased version may contain migrations that may be changed before the release happens +* manual changes in the database +* external SQLite tooling used on the database that adds some tables, indices etc. + +In the latter case, it is possible to make go-spacemesh ignore certain objects (tables and indices) when checking for schema drift. For this, you can use `main.db-schema-ignore-rx` setting to set a regular expression that is used to ignore tables and indices in the database during schema drift checks. The setting defaults to `_litestream` to help with certain tooling. + +The schema changes in go-spacemesh code should be always done by means of adding migrations. After that, the schema tests in `sql/localsql` and `sql/statesql` will start failing. When the tests fail, they display the difference between the schema stored in `schema.sql` and the schema that is loaded from the database after running all the migrations. +If the schema changes shown in the diff are expected, the schema file needs to be updated. + +```console +$ # run the tests +$ eval $(make print-test-env) go test ./sql/localsql ./sql/statesql +... +=== RUN TestSchema/schema/force_migrations + test.go:106: updated schema written to schema/schema.sql.updated + test.go:108: + Error Trace: /Users/ivan4th/work/spacemesh/go-spacemesh/sql/test/test.go:108 + Error: Should be empty, but was ( + """ + ... // 81 identical lines + PRIMARY KEY (kind, epoch) + ) WITHOUT ROWID; + - -- some change + """ + ) + Test: TestSchema/schema/force_migrations + Messages: schema diff +FAIL +FAIL github.com/spacemeshos/go-spacemesh/sql/localsql 0.163s +ok github.com/spacemeshos/go-spacemesh/sql/statesql 0.286s +FAIL +$ git status +... +Untracked files: + (use "git add ..." to include in what will be committed) + sql/localsql/schema/schema.sql.updated + +$ # update the schema file +$ mv sql/localsql/schema/schema.sql{.updated,} + +$ # rerun the tests +$ eval $(make print-test-env) go test -count=1 ./sql/localsql ./sql/statesql +ok github.com/spacemeshos/go-spacemesh/sql/localsql 0.166s +ok github.com/spacemeshos/go-spacemesh/sql/statesql 0.293s +``` + #### Next Steps - Please visit our [wiki](https://github.com/spacemeshos/go-spacemesh/wiki) From 625717453764806ff9dc5c87079c53aefbe3727c Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 5 Jun 2024 09:05:41 +0400 Subject: [PATCH 10/62] config: fix build errors --- config/mainnet.go | 20 ++++++++++---------- config/presets/testnet.go | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/config/mainnet.go b/config/mainnet.go index 76379609c0..71c426f876 100644 --- a/config/mainnet.go +++ b/config/mainnet.go @@ -67,16 +67,16 @@ func MainnetConfig() Config { hare3conf.EnableLayer = 35117 return Config{ BaseConfig: BaseConfig{ - DataDirParent: defaultDataDir, - FileLock: filepath.Join(os.TempDir(), "spacemesh.lock"), - MetricsPort: 1010, - DatabaseConnections: 16, - DatabasePruneInterval: 30 * time.Minute, - DatabaseVacuumState: 15, - DatabaseIgnoreTableRx: "^_litestream", - PruneActivesetsFrom: 12, // starting from epoch 13 activesets below 12 will be pruned - ScanMalfeasantATXs: false, // opt-in - NetworkHRP: "sm", + DataDirParent: defaultDataDir, + FileLock: filepath.Join(os.TempDir(), "spacemesh.lock"), + MetricsPort: 1010, + DatabaseConnections: 16, + DatabasePruneInterval: 30 * time.Minute, + DatabaseVacuumState: 15, + DatabaseSchemaIgnoreRx: "^_litestream", + PruneActivesetsFrom: 12, // starting from epoch 13 activesets below 12 will be pruned + ScanMalfeasantATXs: false, // opt-in + NetworkHRP: "sm", LayerDuration: 5 * time.Minute, LayerAvgSize: 50, diff --git a/config/presets/testnet.go b/config/presets/testnet.go index c6f6f29a2e..a9c4d028da 100644 --- a/config/presets/testnet.go +++ b/config/presets/testnet.go @@ -65,7 +65,7 @@ func testnet() config.Config { DatabaseConnections: 16, DatabaseSizeMeteringInterval: 10 * time.Minute, DatabasePruneInterval: 30 * time.Minute, - DatabaseIgnoreTableRx: "^_litestream", + DatabaseSchemaIgnoreRx: "^_litestream", NetworkHRP: "stest", LayerDuration: 5 * time.Minute, From aada7c61b5e61505404bbd4cfdd7b4deeb7ca45a Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 8 Jun 2024 01:54:21 +0400 Subject: [PATCH 11/62] Fix README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9dd9254d3a..3c637dbebe 100644 --- a/README.md +++ b/README.md @@ -564,7 +564,7 @@ $ eval $(make print-test-env) go test ./sql/localsql ./sql/statesql === RUN TestSchema/schema/force_migrations test.go:106: updated schema written to schema/schema.sql.updated test.go:108: - Error Trace: /Users/ivan4th/work/spacemesh/go-spacemesh/sql/test/test.go:108 + Error Trace: /Users/user/spacemesh/go-spacemesh/sql/test/test.go:108 Error: Should be empty, but was ( """ ... // 81 identical lines From dbea047c8a1f17df06cf1cdc36471cb319fb5fec Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 8 Jun 2024 01:58:18 +0400 Subject: [PATCH 12/62] Fix go.mod / go.sum --- go.mod | 1 - go.sum | 2 -- 2 files changed, 3 deletions(-) diff --git a/go.mod b/go.mod index 1b60b2a250..ededa1f2cc 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,6 @@ require ( github.com/rs/cors v1.11.0 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/seehuhn/mt19937 v1.0.0 - github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e github.com/spacemeshos/api/release/go v1.42.0 github.com/spacemeshos/economics v0.1.3 github.com/spacemeshos/fixed v0.1.1 diff --git a/go.sum b/go.sum index 3ccb1c111b..06c9c3aad2 100644 --- a/go.sum +++ b/go.sum @@ -600,8 +600,6 @@ github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3 github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= -github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e h1:H+jDTUeF+SVd4ApwnSFoew8ZwGNRfgb9EsZc7LcocAg= -github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e/go.mod h1:VsUklG6OQo7Ctunu0gS3AtEOCEc2kMB6r5rKzxAes58= github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= github.com/spacemeshos/api/release/go v1.42.0 h1:K85zw+KZA1UA3VNwvXD2UIND7NLyAiJo4Kz6ZznFEEc= github.com/spacemeshos/api/release/go v1.42.0/go.mod h1:aCDRfna5MA7LJWZPa4k+vTRvBUf1Swz8kcziPcdp6i8= From 6bf114aacbfd9cbde640ad29d8fc6b40b80b20d0 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 12 Jun 2024 04:31:29 +0400 Subject: [PATCH 13/62] sql: use go:generate for database schema files --- sql/database.go | 63 +++++++++++-------- sql/localsql/localsql.go | 2 + sql/localsql/localsql_test.go | 70 ++++++++++++++++++--- sql/migrations.go | 29 +++++---- sql/schema.go | 56 +++++++++++++++++ sql/schema_test.go | 54 ++++++++++++++++ sql/schemagen/main.go | 48 +++++++++++++++ sql/statesql/statesql.go | 2 + sql/statesql/statesql_test.go | 70 ++++++++++++++++++--- sql/test/test.go | 112 ---------------------------------- 10 files changed, 341 insertions(+), 165 deletions(-) create mode 100644 sql/schema_test.go create mode 100644 sql/schemagen/main.go delete mode 100644 sql/test/test.go diff --git a/sql/database.go b/sql/database.go index bae0b39c88..9e88567911 100644 --- a/sql/database.go +++ b/sql/database.go @@ -71,17 +71,18 @@ func defaultConf() *conf { } type conf struct { - enableMigrations bool - forceFresh bool - forceMigrations bool - connections int - vacuumState int - enableLatency bool - cache bool - cacheSizes map[QueryCacheKind]int - logger *zap.Logger - schema *Schema - ignoreTableRx string + enableMigrations bool + forceFresh bool + forceMigrations bool + connections int + vacuumState int + enableLatency bool + cache bool + cacheSizes map[QueryCacheKind]int + logger *zap.Logger + schema *Schema + ignoreTableRx string + ignoreSchemaDrift bool } // WithConnections overwrites number of pooled connections. @@ -163,6 +164,12 @@ func WithIgnoreTableRx(rx string) Opt { } } +func withIgnoreSchemaDrift(ignore bool) Opt { + return func(c *conf) { + c.ignoreSchemaDrift = ignore + } +} + func withForceFresh() Opt { return func(c *conf) { c.forceFresh = true @@ -172,10 +179,16 @@ func withForceFresh() Opt { // Opt for configuring database. type Opt func(c *conf) -// InMemory database for testing. -func InMemory(opts ...Opt) *Database { +// OpenInMemory creates an in-memory database. +func OpenInMemory(opts ...Opt) (*Database, error) { opts = append(opts, WithConnections(1), withForceFresh()) - db, err := Open("file::memory:?mode=memory", opts...) + return Open("file::memory:?mode=memory", opts...) +} + +// InMemory creates an in-memory database for testing and panics if +// there's an error. +func InMemory(opts ...Opt) *Database { + db, err := OpenInMemory(opts...) if err != nil { panic(err) } @@ -231,16 +244,18 @@ func Open(uri string, opts ...Opt) (*Database, error) { } } - loaded, err := LoadDBSchemaScript(db, config.ignoreTableRx) - if err != nil { - return nil, fmt.Errorf("error loading database schema: %w", err) - } - diff := config.schema.Diff(loaded) - if diff != "" { - config.logger.Warn("database schema drift detected", - zap.String("uri", uri), - zap.String("diff", diff), - ) + if !config.ignoreSchemaDrift { + loaded, err := LoadDBSchemaScript(db, config.ignoreTableRx) + if err != nil { + return nil, fmt.Errorf("error loading database schema: %w", err) + } + diff := config.schema.Diff(loaded) + if diff != "" { + config.logger.Warn("database schema drift detected", + zap.String("uri", uri), + zap.String("diff", diff), + ) + } } if config.cache { diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go index 012e66e73c..5084a996bf 100644 --- a/sql/localsql/localsql.go +++ b/sql/localsql/localsql.go @@ -6,6 +6,8 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" ) +//go:generate go run ../schemagen -dbtype local -output schema/schema.sql + //go:embed schema/schema.sql var schemaScript string diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go index c6ce6bf02f..f00f6be18d 100644 --- a/sql/localsql/localsql_test.go +++ b/sql/localsql/localsql_test.go @@ -1,18 +1,74 @@ package localsql import ( + "path/filepath" + "slices" "testing" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" + "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/test" ) -type fns struct{} +func TestIdempotentMigration(t *testing.T) { + observer, observedLogs := observer.New(zapcore.InfoLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + file := filepath.Join(t.TempDir(), "test.db") + db, err := Open("file:"+file, sql.WithForceMigrations(true), sql.WithLogger(logger)) + require.NoError(t, err) + + var versionA int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionA = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + sql1, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + require.NoError(t, db.Close()) + + require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") + l := observedLogs.All()[0] + require.Equal(t, "running migrations", l.Message) + require.Equal(t, int64(0), l.ContextMap()["current version"]) + require.Equal(t, int64(versionA), l.ContextMap()["target version"]) + + db, err = Open("file:"+file, sql.WithLogger(logger)) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + + require.Equal(t, sql1, sql2) + + var versionB int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionB = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) -func (fns) Schema() (*sql.Schema, error) { return Schema() } -func (fns) InMemory(opts ...sql.Opt) *Database { return InMemory(opts...) } -func (fns) Open(uri string, opts ...sql.Opt) (*Database, error) { return Open(uri, opts...) } + require.NoError(t, err) + schema, err := Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), versionA) + require.Equal(t, expectedVersion.Order(), versionB) -func TestSchema(t *testing.T) { - test.RunSchemaTests(t, fns{}) + require.NoError(t, db.Close()) + // make sure there's no schema drift warnings in the logs + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") } diff --git a/sql/migrations.go b/sql/migrations.go index 2565e295ce..9152fcf189 100644 --- a/sql/migrations.go +++ b/sql/migrations.go @@ -1,10 +1,9 @@ package sql import ( - "bufio" - "bytes" "fmt" "io/fs" + "regexp" "slices" "strconv" "strings" @@ -41,9 +40,11 @@ func (l MigrationList) Version() int { type sqlMigration struct { order int name string - content *bufio.Scanner + content string } +var sqlCommentRx = regexp.MustCompile(`(?m)--.*$`) + func (m *sqlMigration) Apply(db Executor) error { current, err := version(db) if err != nil { @@ -53,9 +54,14 @@ func (m *sqlMigration) Apply(db Executor) error { if m.order <= current { return nil } - for m.content.Scan() { - if _, err := db.Exec(m.content.Text(), nil, nil); err != nil { - return fmt.Errorf("exec %s: %w", m.content.Text(), err) + // TODO: use more advanced approach to split the SQL script + // into commands + for _, cmd := range strings.Split(m.content, ";") { + cmd = sqlCommentRx.ReplaceAllString(cmd, "") + if strings.TrimSpace(cmd) != "" { + if _, err := db.Exec(cmd, nil, nil); err != nil { + return fmt.Errorf("exec %s: %w", cmd, err) + } } } // binding values in pragma statement is not allowed @@ -107,21 +113,14 @@ func LoadSQLMigrations(fsys fs.FS) (MigrationList, error) { if err != nil { return fmt.Errorf("invalid migration %s: %w", d.Name(), err) } - f, err := fsys.Open(path) + script, err := fs.ReadFile(fsys, path) if err != nil { return fmt.Errorf("read file %s: %w", path, err) } - scanner := bufio.NewScanner(f) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if i := bytes.Index(data, []byte(";")); i >= 0 { - return i + 1, data[0 : i+1], nil - } - return 0, nil, nil - }) migrations = append(migrations, &sqlMigration{ order: order, name: d.Name(), - content: scanner, + content: string(script), }) return nil }) diff --git a/sql/schema.go b/sql/schema.go index 82fc2badb7..d0b1aa8d98 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "io" "os" "path/filepath" "regexp" @@ -182,3 +183,58 @@ func (s *Schema) Migrate(logger *zap.Logger, db *Database, vacuumState int, enab } return nil } + +// SchemaGenOpt represents a schema generator option +type SchemaGenOpt func(g *SchemaGen) + +func withDefaultOut(w io.Writer) SchemaGenOpt { + return func(g *SchemaGen) { + g.defaultOut = w + } +} + +// SchemaGen generates database schema files +type SchemaGen struct { + logger *zap.Logger + schema *Schema + defaultOut io.Writer +} + +// NewSchemaGen creates a new SchemaGen instance +func NewSchemaGen(logger *zap.Logger, schema *Schema, opts ...SchemaGenOpt) *SchemaGen { + g := &SchemaGen{logger: logger, schema: schema, defaultOut: os.Stdout} + for _, opt := range opts { + opt(g) + } + return g +} + +// Generate generates database schema and writes it to the specified file. +// If an empty string is specified as outputFile, the +func (g *SchemaGen) Generate(outputFile string) error { + db, err := OpenInMemory( + WithLogger(g.logger), + WithDatabaseSchema(g.schema), + WithForceMigrations(true), + withIgnoreSchemaDrift(true)) + if err != nil { + return fmt.Errorf("error opening in-memory db: %w", err) + } + defer func() { + if err := db.Close(); err != nil { + g.logger.Error("error closing in-memory db: %w", zap.Error(err)) + } + }() + loadedScript, err := LoadDBSchemaScript(db, "") + if err != nil { + return fmt.Errorf("error loading DB schema script: %w", err) + } + if outputFile == "" { + if _, err := io.WriteString(g.defaultOut, loadedScript); err != nil { + return fmt.Errorf("error writing schema file: %w", err) + } + } else if err := os.WriteFile(outputFile, []byte(loadedScript), 0777); err != nil { + return fmt.Errorf("error writing schema file %q: %w", outputFile, err) + } + return nil +} diff --git a/sql/schema_test.go b/sql/schema_test.go new file mode 100644 index 0000000000..f9ccc083fe --- /dev/null +++ b/sql/schema_test.go @@ -0,0 +1,54 @@ +package sql + +import ( + "os" + "path/filepath" + "strings" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" +) + +func TestSchemaGen(t *testing.T) { + observer, observedLogs := observer.New(zapcore.WarnLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + fs := fstest.MapFS{ + "schema/migrations/0001_first.sql": &fstest.MapFile{ + Data: []byte("create table foo(id int);"), + }, + "schema/migrations/0002_second.sql": &fstest.MapFile{ + Data: []byte("create table bar(id int);"), + }, + } + migrations, err := LoadSQLMigrations(fs) + require.NoError(t, err) + require.Len(t, migrations, 2) + schema := &Schema{ + Script: "this should not be run", + Migrations: migrations, + } + var sb strings.Builder + g := NewSchemaGen(logger, schema, withDefaultOut(&sb)) + tempDir := t.TempDir() + schemaPath := filepath.Join(tempDir, "schema.sql") + require.NoError(t, g.Generate(schemaPath)) + contents, err := os.ReadFile(schemaPath) + require.NoError(t, err) + require.Equal(t, + "PRAGMA user_version = 2;\nCREATE TABLE bar(id int);\nCREATE TABLE foo(id int);\n", + string(contents)) + require.NoError(t, g.Generate("")) + require.Equal(t, string(contents), sb.String()) + + require.Equal(t, 0, observedLogs.Len(), + "expected 0 warning messages in the log (schema drift warnings?)") +} diff --git a/sql/schemagen/main.go b/sql/schemagen/main.go new file mode 100644 index 0000000000..db064ce60b --- /dev/null +++ b/sql/schemagen/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "flag" + "os" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +var ( + level = zap.LevelFlag("level", zapcore.ErrorLevel, "set log verbosity level") + dbType = flag.String("dbtype", "state", "database type (state, local, default state)") + output = flag.String("output", "", "output file (defaults to stdin)") +) + +func main() { + var ( + err error + schema *sql.Schema + ) + flag.Parse() + core := zapcore.NewCore( + zapcore.NewConsoleEncoder(zap.NewProductionEncoderConfig()), + os.Stderr, + zap.NewAtomicLevelAt(*level), + ) + logger := zap.New(core).With(zap.String("dbType", *dbType)) + switch *dbType { + case "state": + schema, err = statesql.Schema() + case "local": + schema, err = localsql.Schema() + default: + logger.Fatal("unknown database type, must be state or local") + } + if err != nil { + logger.Fatal("error loading db schema", zap.Error(err)) + } + g := sql.NewSchemaGen(logger, schema) + if err := g.Generate(*output); err != nil { + logger.Fatal("error generating schema", zap.Error(err), zap.String("output", *output)) + } +} diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index c33d274f38..76139fac5f 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -6,6 +6,8 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" ) +//go:generate go run ../schemagen -dbtype state -output schema/schema.sql + //go:embed schema/schema.sql var schemaScript string diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go index 79ea387c8d..66a55cca8e 100644 --- a/sql/statesql/statesql_test.go +++ b/sql/statesql/statesql_test.go @@ -1,18 +1,74 @@ package statesql import ( + "path/filepath" + "slices" "testing" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" + "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/test" ) -type fns struct{} +func TestIdempotentMigration(t *testing.T) { + observer, observedLogs := observer.New(zapcore.InfoLevel) + logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( + func(core zapcore.Core) zapcore.Core { + return zapcore.NewTee(core, observer) + }, + ))) + + file := filepath.Join(t.TempDir(), "test.db") + db, err := Open("file:"+file, sql.WithForceMigrations(true), sql.WithLogger(logger)) + require.NoError(t, err) + + var versionA int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionA = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) + + sql1, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + require.NoError(t, db.Close()) + + require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") + l := observedLogs.All()[0] + require.Equal(t, "running migrations", l.Message) + require.Equal(t, int64(0), l.ContextMap()["current version"]) + require.Equal(t, int64(versionA), l.ContextMap()["target version"]) + + db, err = Open("file:"+file, sql.WithLogger(logger)) + require.NoError(t, err) + sql2, err := sql.LoadDBSchemaScript(db, "") + require.NoError(t, err) + + require.Equal(t, sql1, sql2) + + var versionB int + _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { + versionB = stmt.ColumnInt(0) + return true + }) + require.NoError(t, err) -func (fns) Schema() (*sql.Schema, error) { return Schema() } -func (fns) InMemory(opts ...sql.Opt) *Database { return InMemory(opts...) } -func (fns) Open(uri string, opts ...sql.Opt) (*Database, error) { return Open(uri, opts...) } + require.NoError(t, err) + schema, err := Schema() + require.NoError(t, err) + expectedVersion := slices.MaxFunc( + []sql.Migration(schema.Migrations), + func(a, b sql.Migration) int { + return a.Order() - b.Order() + }) + require.Equal(t, expectedVersion.Order(), versionA) + require.Equal(t, expectedVersion.Order(), versionB) -func TestSchema(t *testing.T) { - test.RunSchemaTests(t, fns{}) + require.NoError(t, db.Close()) + // make sure there's no schema drift warnings in the logs + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") } diff --git a/sql/test/test.go b/sql/test/test.go deleted file mode 100644 index a70397bfb4..0000000000 --- a/sql/test/test.go +++ /dev/null @@ -1,112 +0,0 @@ -package test - -import ( - "path/filepath" - "slices" - "testing" - - "github.com/stretchr/testify/require" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "go.uber.org/zap/zaptest" - "go.uber.org/zap/zaptest/observer" - - "github.com/spacemeshos/go-spacemesh/sql" -) - -type Database interface { - sql.Executor - Close() error -} - -type DBFuncs[DB Database] interface { - Schema() (*sql.Schema, error) - Open(uri string, opts ...sql.Opt) (DB, error) - InMemory(opts ...sql.Opt) DB -} - -func RunSchemaTests[DB Database](t *testing.T, funcs DBFuncs[DB]) { - t.Run("idempotent migration", func(t *testing.T) { - observer, observedLogs := observer.New(zapcore.InfoLevel) - logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( - func(core zapcore.Core) zapcore.Core { - return zapcore.NewTee(core, observer) - }, - ))) - - file := filepath.Join(t.TempDir(), "test.db") - db, err := funcs.Open("file:"+file, sql.WithForceMigrations(true), sql.WithLogger(logger)) - require.NoError(t, err) - - var versionA int - _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { - versionA = stmt.ColumnInt(0) - return true - }) - require.NoError(t, err) - - sql1, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) - require.NoError(t, db.Close()) - - require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") - l := observedLogs.All()[0] - require.Equal(t, "running migrations", l.Message) - require.Equal(t, int64(0), l.ContextMap()["current version"]) - require.Equal(t, int64(versionA), l.ContextMap()["target version"]) - - db, err = funcs.Open("file:"+file, sql.WithLogger(logger)) - require.NoError(t, err) - sql2, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) - - require.Equal(t, sql1, sql2) - - var versionB int - _, err = db.Exec("PRAGMA user_version", nil, func(stmt *sql.Statement) bool { - versionB = stmt.ColumnInt(0) - return true - }) - require.NoError(t, err) - - require.NoError(t, err) - schema, err := funcs.Schema() - require.NoError(t, err) - expectedVersion := slices.MaxFunc( - []sql.Migration(schema.Migrations), - func(a, b sql.Migration) int { - return a.Order() - b.Order() - }) - require.Equal(t, expectedVersion.Order(), versionA) - require.Equal(t, expectedVersion.Order(), versionB) - - require.NoError(t, db.Close()) - // make sure there's no schema drift warnings in the logs - require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") - }) - - t.Run("schema", func(t *testing.T) { - for _, tc := range []struct { - name string - forceMigrations bool - }{ - {name: "no migrations", forceMigrations: false}, - {name: "force migrations", forceMigrations: true}, - } { - t.Run(tc.name, func(t *testing.T) { - db := funcs.InMemory(sql.WithForceMigrations(tc.forceMigrations)) - loadedScript, err := sql.LoadDBSchemaScript(db, "") - require.NoError(t, err) - expSchema, err := funcs.Schema() - require.NoError(t, err) - diff := expSchema.Diff(loadedScript) - if diff != "" { - s := &sql.Schema{Script: loadedScript} - require.NoError(t, s.WriteToFile(".")) - t.Logf("updated schema written to %s", sql.UpdatedSchemaPath) - } - require.Empty(t, diff, "schema diff") - }) - } - }) -} From bd62a8eb426332ac21087e6a08fe0817e990c6a9 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 12 Jun 2024 17:08:07 +0400 Subject: [PATCH 14/62] sql: make schema drift fatal by default --- config/config.go | 1 + node/node.go | 1 + sql/atxs/atxs_test.go | 4 ++-- sql/database.go | 20 +++++++++++++++++--- sql/database_test.go | 22 +++++++++++++++++++--- sql/rewards/rewards_test.go | 1 + sql/schema.go | 2 +- sql/vacuum_test.go | 2 +- 8 files changed, 43 insertions(+), 10 deletions(-) diff --git a/config/config.go b/config/config.go index f389fc82a6..e9e54b2ff0 100644 --- a/config/config.go +++ b/config/config.go @@ -118,6 +118,7 @@ type BaseConfig struct { DatabaseQueryCache bool `mapstructure:"db-query-cache"` DatabaseQueryCacheSizes DatabaseQueryCacheSizes `mapstructure:"db-query-cache-sizes"` DatabaseSchemaIgnoreRx string `mapstructure:"db-ignore-schema-rx"` + DatabaseSchemaAllowDrift bool `mapstructure:"db-allow-schema-drift"` PruneActivesetsFrom types.EpochID `mapstructure:"prune-activesets-from"` diff --git a/node/node.go b/node/node.go index b7723fdd49..2f21f6d8f4 100644 --- a/node/node.go +++ b/node/node.go @@ -1934,6 +1934,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { sql.WithLatencyMetering(app.Config.DatabaseLatencyMetering), sql.WithVacuumState(app.Config.DatabaseVacuumState), sql.WithIgnoreTableRx(app.Config.DatabaseSchemaIgnoreRx), + sql.WithAllowSchemaDrift(app.Config.DatabaseSchemaAllowDrift), sql.WithQueryCache(app.Config.DatabaseQueryCache), sql.WithQueryCacheSizes(map[sql.QueryCacheKind]int{ atxs.CacheKindEpochATXs: app.Config.DatabaseQueryCacheSizes.EpochATXs, diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 35696e1afc..1fa6c3b563 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -103,7 +103,7 @@ func TestHasID(t *testing.T) { } func Test_IdentityExists(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -669,7 +669,7 @@ func TestGetBlobCached_CacheEntriesAreDistinct(t *testing.T) { // Test that the cached blob is not shared with the caller // but copied into the provided blob. func TestGetBlobCached_OverwriteSafety(t *testing.T) { - db := sql.InMemory(sql.WithQueryCache(true)) + db := statesql.InMemory(sql.WithQueryCache(true)) atx := types.ActivationTx{} atx.SetID(types.RandomATXID()) atx.AtxBlob.Blob = []byte("original blob") diff --git a/sql/database.go b/sql/database.go index 9e88567911..94391eaf77 100644 --- a/sql/database.go +++ b/sql/database.go @@ -82,6 +82,7 @@ type conf struct { logger *zap.Logger schema *Schema ignoreTableRx string + allowSchemaDrift bool ignoreSchemaDrift bool } @@ -164,9 +165,16 @@ func WithIgnoreTableRx(rx string) Opt { } } -func withIgnoreSchemaDrift(ignore bool) Opt { +// WithAllowSchemaDrift prevents Open from failing upon schema drift. A warning is printed instead +func WithAllowSchemaDrift(allow bool) Opt { return func(c *conf) { - c.ignoreSchemaDrift = ignore + c.allowSchemaDrift = allow + } +} + +func withIgnoreSchemaDrift() Opt { + return func(c *conf) { + c.ignoreSchemaDrift = true } } @@ -250,11 +258,17 @@ func Open(uri string, opts ...Opt) (*Database, error) { return nil, fmt.Errorf("error loading database schema: %w", err) } diff := config.schema.Diff(loaded) - if diff != "" { + switch { + case diff == "": // ok + case config.allowSchemaDrift: config.logger.Warn("database schema drift detected", zap.String("uri", uri), zap.String("diff", diff), ) + default: + return nil, errors.Join( + fmt.Errorf("schema drift detected (uri %s):\n%s", uri, diff), + db.Close()) } } diff --git a/sql/database_test.go b/sql/database_test.go index e1e6332401..29ceaca47d 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -25,6 +25,7 @@ func Test_Transaction_Isolation(t *testing.T) { field int );`, }), + withIgnoreSchemaDrift(), ) tx, err := db.Tx(context.Background()) require.NoError(t, err) @@ -89,6 +90,7 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), + withIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -120,6 +122,7 @@ func Test_Migration_Disabled(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), + withIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -154,6 +157,7 @@ func TestDatabaseSkipMigrations(t *testing.T) { db, err := Open("file:"+dbFile, WithDatabaseSchema(schema), WithForceMigrations(true), + withIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -177,6 +181,7 @@ func TestDatabaseVacuumState(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), + withIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -186,6 +191,7 @@ func TestDatabaseVacuumState(t *testing.T) { Migrations: MigrationList{migration1, migration2}, }), WithVacuumState(2), + withIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -196,7 +202,7 @@ func TestDatabaseVacuumState(t *testing.T) { } func TestQueryCount(t *testing.T) { - db := InMemory() + db := InMemory(withIgnoreSchemaDrift()) require.Equal(t, 0, db.QueryCount()) n, err := db.Exec("select 1", nil, nil) @@ -220,7 +226,7 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { migration2.EXPECT().Order().Return(2).AnyTimes() dbFile := filepath.Join(dir, "test.sql") - db, err := Open("file:"+dbFile, WithForceMigrations(true)) + db, err := Open("file:"+dbFile, WithForceMigrations(true), withIgnoreSchemaDrift()) require.NoError(t, err) _, err = db.Exec("PRAGMA user_version = 3", nil, nil) require.NoError(t, err) @@ -230,6 +236,7 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1, migration2}, }), + withIgnoreSchemaDrift(), ) require.ErrorIs(t, err, ErrTooNew) } @@ -267,11 +274,20 @@ func TestSchemaDrift(t *testing.T) { WithDatabaseSchema(schema), WithLogger(logger), ) + require.Error(t, err) + require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, err.Error()) + require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") + + db, err = Open("file:"+dbFile, + WithDatabaseSchema(schema), + WithLogger(logger), + WithAllowSchemaDrift(true), + ) require.NoError(t, db.Close()) require.NoError(t, err) require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") require.Equal(t, "database schema drift detected", observedLogs.All()[0].Message) - require.Regexp(t, `.*\n\s*\+\s*CREATE TABLE newtbl \(id int\);`, + require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, observedLogs.All()[0].ContextMap()["diff"]) } diff --git a/sql/rewards/rewards_test.go b/sql/rewards/rewards_test.go index af57f33c9a..45dc26d2d9 100644 --- a/sql/rewards/rewards_test.go +++ b/sql/rewards/rewards_test.go @@ -247,6 +247,7 @@ func Test_0008Migration(t *testing.T) { db := statesql.InMemory( sql.WithDatabaseSchema(schema), sql.WithForceMigrations(true), + sql.WithAllowSchemaDrift(true), ) // verify that the DB is empty diff --git a/sql/schema.go b/sql/schema.go index d0b1aa8d98..c2ffabc5b1 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -216,7 +216,7 @@ func (g *SchemaGen) Generate(outputFile string) error { WithLogger(g.logger), WithDatabaseSchema(g.schema), WithForceMigrations(true), - withIgnoreSchemaDrift(true)) + withIgnoreSchemaDrift()) if err != nil { return fmt.Errorf("error opening in-memory db: %w", err) } diff --git a/sql/vacuum_test.go b/sql/vacuum_test.go index b994516279..5017cad677 100644 --- a/sql/vacuum_test.go +++ b/sql/vacuum_test.go @@ -7,6 +7,6 @@ import ( ) func TestVacuumDB(t *testing.T) { - db := InMemory() + db := InMemory(withIgnoreSchemaDrift()) require.NoError(t, Vacuum(db)) } From c9d1a6c90aebef568f93f688d3ce81f586223c6f Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 12 Jun 2024 17:27:25 +0400 Subject: [PATCH 15/62] sql: remove db-ignore-schema-rx config option Also, add a migration that removes stray tables from quicksynced DBs --- config/config.go | 8 ++-- config/mainnet.go | 19 ++++----- config/presets/testnet.go | 1 - node/node.go | 1 - sql/database.go | 13 +----- sql/database_test.go | 42 +------------------ sql/localsql/localsql_test.go | 4 +- sql/schema.go | 31 +++++--------- .../schema/migrations/0019_schema_cleanup.sql | 2 + sql/statesql/schema/schema.sql | 2 +- sql/statesql/statesql_test.go | 4 +- 11 files changed, 32 insertions(+), 95 deletions(-) create mode 100644 sql/statesql/schema/migrations/0019_schema_cleanup.sql diff --git a/config/config.go b/config/config.go index e9e54b2ff0..78ef09c66a 100644 --- a/config/config.go +++ b/config/config.go @@ -117,7 +117,6 @@ type BaseConfig struct { DatabaseSkipMigrations []int `mapstructure:"db-skip-migrations"` DatabaseQueryCache bool `mapstructure:"db-query-cache"` DatabaseQueryCacheSizes DatabaseQueryCacheSizes `mapstructure:"db-query-cache-sizes"` - DatabaseSchemaIgnoreRx string `mapstructure:"db-ignore-schema-rx"` DatabaseSchemaAllowDrift bool `mapstructure:"db-allow-schema-drift"` PruneActivesetsFrom types.EpochID `mapstructure:"prune-activesets-from"` @@ -247,10 +246,9 @@ func defaultBaseConfig() BaseConfig { ATXBlob: 10000, ActiveSetBlob: 200, }, - DatabaseSchemaIgnoreRx: "^_litestream", - NetworkHRP: "sm", - ATXGradeDelay: 10 * time.Second, - PostValidDelay: 12 * time.Hour, + NetworkHRP: "sm", + ATXGradeDelay: 10 * time.Second, + PostValidDelay: 12 * time.Hour, PprofHTTPServerListener: "localhost:6060", } diff --git a/config/mainnet.go b/config/mainnet.go index 71c426f876..4a910bce6f 100644 --- a/config/mainnet.go +++ b/config/mainnet.go @@ -67,16 +67,15 @@ func MainnetConfig() Config { hare3conf.EnableLayer = 35117 return Config{ BaseConfig: BaseConfig{ - DataDirParent: defaultDataDir, - FileLock: filepath.Join(os.TempDir(), "spacemesh.lock"), - MetricsPort: 1010, - DatabaseConnections: 16, - DatabasePruneInterval: 30 * time.Minute, - DatabaseVacuumState: 15, - DatabaseSchemaIgnoreRx: "^_litestream", - PruneActivesetsFrom: 12, // starting from epoch 13 activesets below 12 will be pruned - ScanMalfeasantATXs: false, // opt-in - NetworkHRP: "sm", + DataDirParent: defaultDataDir, + FileLock: filepath.Join(os.TempDir(), "spacemesh.lock"), + MetricsPort: 1010, + DatabaseConnections: 16, + DatabasePruneInterval: 30 * time.Minute, + DatabaseVacuumState: 15, + PruneActivesetsFrom: 12, // starting from epoch 13 activesets below 12 will be pruned + ScanMalfeasantATXs: false, // opt-in + NetworkHRP: "sm", LayerDuration: 5 * time.Minute, LayerAvgSize: 50, diff --git a/config/presets/testnet.go b/config/presets/testnet.go index a9c4d028da..892d924aaf 100644 --- a/config/presets/testnet.go +++ b/config/presets/testnet.go @@ -65,7 +65,6 @@ func testnet() config.Config { DatabaseConnections: 16, DatabaseSizeMeteringInterval: 10 * time.Minute, DatabasePruneInterval: 30 * time.Minute, - DatabaseSchemaIgnoreRx: "^_litestream", NetworkHRP: "stest", LayerDuration: 5 * time.Minute, diff --git a/node/node.go b/node/node.go index 2f21f6d8f4..c7a9da2564 100644 --- a/node/node.go +++ b/node/node.go @@ -1933,7 +1933,6 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { sql.WithConnections(app.Config.DatabaseConnections), sql.WithLatencyMetering(app.Config.DatabaseLatencyMetering), sql.WithVacuumState(app.Config.DatabaseVacuumState), - sql.WithIgnoreTableRx(app.Config.DatabaseSchemaIgnoreRx), sql.WithAllowSchemaDrift(app.Config.DatabaseSchemaAllowDrift), sql.WithQueryCache(app.Config.DatabaseQueryCache), sql.WithQueryCacheSizes(map[sql.QueryCacheKind]int{ diff --git a/sql/database.go b/sql/database.go index 94391eaf77..78044cd159 100644 --- a/sql/database.go +++ b/sql/database.go @@ -81,7 +81,6 @@ type conf struct { cacheSizes map[QueryCacheKind]int logger *zap.Logger schema *Schema - ignoreTableRx string allowSchemaDrift bool ignoreSchemaDrift bool } @@ -157,15 +156,7 @@ func WithDatabaseSchema(schema *Schema) Opt { } } -// WithIgnoreTableRx specifies regular expression for table names to ignore during schema -// drift detection. -func WithIgnoreTableRx(rx string) Opt { - return func(c *conf) { - c.ignoreTableRx = rx - } -} - -// WithAllowSchemaDrift prevents Open from failing upon schema drift. A warning is printed instead +// WithAllowSchemaDrift prevents Open from failing upon schema drift. A warning is printed instead. func WithAllowSchemaDrift(allow bool) Opt { return func(c *conf) { c.allowSchemaDrift = allow @@ -253,7 +244,7 @@ func Open(uri string, opts ...Opt) (*Database, error) { } if !config.ignoreSchemaDrift { - loaded, err := LoadDBSchemaScript(db, config.ignoreTableRx) + loaded, err := LoadDBSchemaScript(db) if err != nil { return nil, fmt.Errorf("error loading database schema: %w", err) } diff --git a/sql/database_test.go b/sql/database_test.go index 29ceaca47d..c91f61fbac 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -270,7 +270,7 @@ func TestSchemaDrift(t *testing.T) { require.NoError(t, db.Close()) require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") - db, err = Open("file:"+dbFile, + _, err = Open("file:"+dbFile, WithDatabaseSchema(schema), WithLogger(logger), ) @@ -290,43 +290,3 @@ func TestSchemaDrift(t *testing.T) { require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, observedLogs.All()[0].ContextMap()["diff"]) } - -func TestSchemaDrift_IgnoredTables(t *testing.T) { - observer, observedLogs := observer.New(zapcore.WarnLevel) - logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.WrapCore( - func(core zapcore.Core) zapcore.Core { - return zapcore.NewTee(core, observer) - }, - ))) - dbFile := filepath.Join(t.TempDir(), "test.sql") - schema := &Schema{ - // Not using ` here to avoid schema drift warnings due to whitespace - // TODO: ignore whitespace and comments during schema comparison - Script: "PRAGMA user_version = 0;\n" + - "CREATE TABLE testing1 (\n" + - " id varchar primary key,\n" + - " field int\n" + - ");\n", - } - db, err := Open("file:"+dbFile, - WithDatabaseSchema(schema), - WithLogger(logger), - WithIgnoreTableRx("^_litestream"), - ) - require.NoError(t, err) - - _, err = db.Exec("create table _litestream_test (id int)", nil, nil) - require.NoError(t, err) - - require.NoError(t, db.Close()) - require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") - - db, err = Open("file:"+dbFile, - WithDatabaseSchema(schema), - WithLogger(logger), - WithIgnoreTableRx("^_litestream"), - ) - require.NoError(t, err) - require.NoError(t, db.Close()) - require.Equal(t, 0, observedLogs.Len(), "expected 0 log messages") -} diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go index f00f6be18d..85bdc4d418 100644 --- a/sql/localsql/localsql_test.go +++ b/sql/localsql/localsql_test.go @@ -33,7 +33,7 @@ func TestIdempotentMigration(t *testing.T) { }) require.NoError(t, err) - sql1, err := sql.LoadDBSchemaScript(db, "") + sql1, err := sql.LoadDBSchemaScript(db) require.NoError(t, err) require.NoError(t, db.Close()) @@ -45,7 +45,7 @@ func TestIdempotentMigration(t *testing.T) { db, err = Open("file:"+file, sql.WithLogger(logger)) require.NoError(t, err) - sql2, err := sql.LoadDBSchemaScript(db, "") + sql2, err := sql.LoadDBSchemaScript(db) require.NoError(t, err) require.Equal(t, sql1, sql2) diff --git a/sql/schema.go b/sql/schema.go index c2ffabc5b1..5614185567 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -9,7 +9,6 @@ import ( "io" "os" "path/filepath" - "regexp" "strings" "github.com/google/go-cmp/cmp" @@ -22,19 +21,11 @@ const ( ) // LoadDBSchemaScript retrieves the database schema as text. -func LoadDBSchemaScript(db Executor, ignoreRx string) (string, error) { +func LoadDBSchemaScript(db Executor) (string, error) { var ( - err error - ignRx *regexp.Regexp - sb strings.Builder + err error + sb strings.Builder ) - if ignoreRx != "" { - ignRx, err = regexp.Compile(ignoreRx) - if err != nil { - return "", fmt.Errorf("error compiling table ignore regexp %q: %w", - ignoreRx, err) - } - } version, err := version(db) if err != nil { return "", err @@ -46,9 +37,7 @@ func LoadDBSchemaScript(db Executor, ignoreRx string) (string, error) { where sql is not null order by tbl_name, type desc, name`, nil, func(st *Statement) bool { - if ignRx == nil || !ignRx.MatchString(st.ColumnText(0)) { - fmt.Fprintln(&sb, st.ColumnText(1)) - } + fmt.Fprintln(&sb, st.ColumnText(1)) return true }); err != nil { return "", fmt.Errorf("error retrieving DB schema: %w", err) @@ -184,7 +173,7 @@ func (s *Schema) Migrate(logger *zap.Logger, db *Database, vacuumState int, enab return nil } -// SchemaGenOpt represents a schema generator option +// SchemaGenOpt represents a schema generator option. type SchemaGenOpt func(g *SchemaGen) func withDefaultOut(w io.Writer) SchemaGenOpt { @@ -193,14 +182,14 @@ func withDefaultOut(w io.Writer) SchemaGenOpt { } } -// SchemaGen generates database schema files +// SchemaGen generates database schema files. type SchemaGen struct { logger *zap.Logger schema *Schema defaultOut io.Writer } -// NewSchemaGen creates a new SchemaGen instance +// NewSchemaGen creates a new SchemaGen instance. func NewSchemaGen(logger *zap.Logger, schema *Schema, opts ...SchemaGenOpt) *SchemaGen { g := &SchemaGen{logger: logger, schema: schema, defaultOut: os.Stdout} for _, opt := range opts { @@ -210,7 +199,7 @@ func NewSchemaGen(logger *zap.Logger, schema *Schema, opts ...SchemaGenOpt) *Sch } // Generate generates database schema and writes it to the specified file. -// If an empty string is specified as outputFile, the +// If an empty string is specified as outputFile, os.Stdout is used for output. func (g *SchemaGen) Generate(outputFile string) error { db, err := OpenInMemory( WithLogger(g.logger), @@ -225,7 +214,7 @@ func (g *SchemaGen) Generate(outputFile string) error { g.logger.Error("error closing in-memory db: %w", zap.Error(err)) } }() - loadedScript, err := LoadDBSchemaScript(db, "") + loadedScript, err := LoadDBSchemaScript(db) if err != nil { return fmt.Errorf("error loading DB schema script: %w", err) } @@ -233,7 +222,7 @@ func (g *SchemaGen) Generate(outputFile string) error { if _, err := io.WriteString(g.defaultOut, loadedScript); err != nil { return fmt.Errorf("error writing schema file: %w", err) } - } else if err := os.WriteFile(outputFile, []byte(loadedScript), 0777); err != nil { + } else if err := os.WriteFile(outputFile, []byte(loadedScript), 0o777); err != nil { return fmt.Errorf("error writing schema file %q: %w", outputFile, err) } return nil diff --git a/sql/statesql/schema/migrations/0019_schema_cleanup.sql b/sql/statesql/schema/migrations/0019_schema_cleanup.sql new file mode 100644 index 0000000000..215bca7adc --- /dev/null +++ b/sql/statesql/schema/migrations/0019_schema_cleanup.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS _litestream_seq; +DROP TABLE IF EXISTS _litestream_lock; diff --git a/sql/statesql/schema/schema.sql b/sql/statesql/schema/schema.sql index 85e381f87b..7a6873fd7d 100755 --- a/sql/statesql/schema/schema.sql +++ b/sql/statesql/schema/schema.sql @@ -1,4 +1,4 @@ -PRAGMA user_version = 18; +PRAGMA user_version = 19; CREATE TABLE accounts ( address CHAR(24), diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go index 66a55cca8e..f3140e2858 100644 --- a/sql/statesql/statesql_test.go +++ b/sql/statesql/statesql_test.go @@ -33,7 +33,7 @@ func TestIdempotentMigration(t *testing.T) { }) require.NoError(t, err) - sql1, err := sql.LoadDBSchemaScript(db, "") + sql1, err := sql.LoadDBSchemaScript(db) require.NoError(t, err) require.NoError(t, db.Close()) @@ -45,7 +45,7 @@ func TestIdempotentMigration(t *testing.T) { db, err = Open("file:"+file, sql.WithLogger(logger)) require.NoError(t, err) - sql2, err := sql.LoadDBSchemaScript(db, "") + sql2, err := sql.LoadDBSchemaScript(db) require.NoError(t, err) require.Equal(t, sql1, sql2) From beccdddaa44ebaceeae9a3db2caef8ea4bc9c18f Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 01:07:22 +0400 Subject: [PATCH 16/62] sql: avoid cyclic dependencies in future coded migrations --- activation/activation.go | 5 +- activation/activation_test.go | 2 +- activation/certifier.go | 9 +- activation/handler_v1.go | 4 +- activation/handler_v2.go | 4 +- activation/malfeasance_test.go | 13 +- activation/nipost.go | 5 +- activation/nipost_test.go | 2 +- activation/poetdb.go | 5 +- api/grpcserver/admin_service.go | 6 +- api/grpcserver/admin_service_test.go | 4 +- api/grpcserver/debug_service.go | 6 +- api/grpcserver/transaction_service.go | 6 +- api/grpcserver/transaction_service_test.go | 2 +- api/grpcserver/v2alpha1/layer_test.go | 3 +- api/grpcserver/v2alpha1/transaction_test.go | 2 +- atxsdata/warmup.go | 3 +- blocks/certifier.go | 7 +- blocks/certifier_test.go | 2 +- blocks/generator.go | 6 +- blocks/generator_test.go | 2 +- blocks/handler.go | 6 +- blocks/utils.go | 6 +- checkpoint/recovery.go | 22 +- checkpoint/recovery_test.go | 5 +- checkpoint/runner.go | 6 +- checkpoint/runner_test.go | 3 +- cmd/merge-nodes/internal/merge_action.go | 6 +- datastore/mocks/mocks.go | 146 --- datastore/store.go | 14 +- fetch/handler_test.go | 4 +- fetch/p2p_test.go | 4 +- genvm/vm.go | 5 +- hare3/eligibility/oracle_test.go | 2 +- hare3/hare.go | 5 +- hare3/hare_test.go | 3 +- hare3/malfeasance_test.go | 5 +- malfeasance/handler.go | 2 +- malfeasance/handler_test.go | 3 +- mesh/executor_test.go | 2 +- mesh/malfeasance_test.go | 5 +- mesh/mesh.go | 13 +- mesh/mesh_test.go | 2 +- miner/active_set_generator_test.go | 4 +- node/node.go | 4 +- proposals/handler.go | 5 +- proposals/handler_test.go | 2 +- prune/prune.go | 6 +- sql/atxs/atxs_test.go | 6 +- sql/database.go | 87 +- sql/localsql/localsql.go | 17 +- sql/metrics/prometheus.go | 5 +- sql/mocks/mocks.go | 1205 +++++++++++++++++++ sql/schema.go | 8 +- sql/statesql/statesql.go | 17 +- sql/transactions/iterator_test.go | 6 +- sql/transactions/transactions.go | 26 +- sql/transactions/transactions_test.go | 22 +- syncer/atxsync/atxsync.go | 6 +- syncer/atxsync/syncer.go | 7 +- syncer/atxsync/syncer_test.go | 5 +- syncer/find_fork_test.go | 5 +- syncer/malsync/syncer.go | 7 +- syncer/malsync/syncer_test.go | 5 +- tortoise/model/core.go | 2 +- tortoise/recover_test.go | 20 +- tortoise/sim/utils.go | 3 +- tortoise/tortoise_test.go | 2 +- tortoise/tracer_test.go | 2 +- txs/cache.go | 35 +- txs/cache_test.go | 10 +- txs/conservative_state.go | 6 +- txs/conservative_state_test.go | 3 +- 73 files changed, 1499 insertions(+), 406 deletions(-) diff --git a/activation/activation.go b/activation/activation.go index 3dd4604497..0feea56c24 100644 --- a/activation/activation.go +++ b/activation/activation.go @@ -27,7 +27,6 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -76,7 +75,7 @@ type Builder struct { conf Config db sql.Executor atxsdata *atxsdata.Data - localDB *localsql.Database + localDB sql.LocalDatabase publisher pubsub.Publisher nipostBuilder nipostBuilder validator nipostValidator @@ -164,7 +163,7 @@ func NewBuilder( conf Config, db sql.Executor, atxsdata *atxsdata.Data, - localDB *localsql.Database, + localDB sql.LocalDatabase, publisher pubsub.Publisher, nipostBuilder nipostBuilder, layerClock layerClock, diff --git a/activation/activation_test.go b/activation/activation_test.go index f43bdaa6eb..391383f442 100644 --- a/activation/activation_test.go +++ b/activation/activation_test.go @@ -56,7 +56,7 @@ func TestMain(m *testing.M) { type testAtxBuilder struct { *Builder db sql.Executor - localDb *localsql.Database + localDb sql.LocalDatabase goldenATXID types.ATXID observedLogs *observer.ObservedLogs diff --git a/activation/certifier.go b/activation/certifier.go index 4ac3c4aced..3487f16b85 100644 --- a/activation/certifier.go +++ b/activation/certifier.go @@ -21,7 +21,6 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" - "github.com/spacemeshos/go-spacemesh/sql/localsql" certifierdb "github.com/spacemeshos/go-spacemesh/sql/localsql/certifier" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -80,14 +79,14 @@ type CertifyResponse struct { type Certifier struct { logger *zap.Logger - db *localsql.Database + db sql.LocalDatabase client certifierClient certifications singleflight.Group } func NewCertifier( - db *localsql.Database, + db sql.LocalDatabase, logger *zap.Logger, client certifierClient, ) *Certifier { @@ -147,7 +146,7 @@ type CertifierClient struct { client *retryablehttp.Client logger *zap.Logger db sql.Executor - localDb *localsql.Database + localDb sql.LocalDatabase } type certifierClientOpts func(*CertifierClient) @@ -162,7 +161,7 @@ func WithCertifierClientConfig(cfg CertifierClientConfig) certifierClientOpts { func NewCertifierClient( db sql.Executor, - localDb *localsql.Database, + localDb sql.LocalDatabase, logger *zap.Logger, opts ...certifierClientOpts, ) *CertifierClient { diff --git a/activation/handler_v1.go b/activation/handler_v1.go index cefff79e1b..3d8c7ea419 100644 --- a/activation/handler_v1.go +++ b/activation/handler_v1.go @@ -555,7 +555,7 @@ func (h *HandlerV1) checkWrongPrevAtx( func (h *HandlerV1) checkMalicious( ctx context.Context, - tx *sql.Tx, + tx sql.Transaction, watx *wire.ActivationTxV1, ) (*mwire.MalfeasanceProof, error) { malicious, err := identities.IsMalicious(tx, watx.SmesherID) @@ -579,7 +579,7 @@ func (h *HandlerV1) storeAtx( watx *wire.ActivationTxV1, ) (*mwire.MalfeasanceProof, error) { var proof *mwire.MalfeasanceProof - if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { + if err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { var err error proof, err = h.checkMalicious(ctx, tx, watx) if err != nil { diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 5eab7a1606..3391790b08 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -582,7 +582,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( func (h *HandlerV2) checkMalicious( ctx context.Context, - tx *sql.Tx, + tx sql.Transaction, watx *wire.ActivationTxV2, ) (bool, *mwire.MalfeasanceProof, error) { malicious, err := identities.IsMalicious(tx, watx.SmesherID) @@ -614,7 +614,7 @@ func (h *HandlerV2) storeAtx( malicious bool proof *mwire.MalfeasanceProof ) - if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { + if err := h.cdb.WithTx(ctx, func(tx sql.Transaction) error { var err error malicious, proof, err = h.checkMalicious(ctx, tx, watx) if err != nil { diff --git a/activation/malfeasance_test.go b/activation/malfeasance_test.go index bc41510598..0f74a2cc82 100644 --- a/activation/malfeasance_test.go +++ b/activation/malfeasance_test.go @@ -21,6 +21,7 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func createIdentity(tb testing.TB, db sql.Executor, sig *signing.EdSigner) { @@ -40,11 +41,11 @@ type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { @@ -237,12 +238,12 @@ type testInvalidPostIndexHandler struct { *InvalidPostIndexHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase mockPostVerifier *MockPostVerifier } func newTestInvalidPostIndexHandler(tb testing.TB) *testInvalidPostIndexHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { @@ -428,11 +429,11 @@ type testInvalidPrevATXHandler struct { *InvalidPrevATXHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestInvalidPrevATXHandler(tb testing.TB) *testInvalidPrevATXHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/activation/nipost.go b/activation/nipost.go index e8ba0e714d..c55ac566e7 100644 --- a/activation/nipost.go +++ b/activation/nipost.go @@ -20,7 +20,6 @@ import ( "github.com/spacemeshos/go-spacemesh/metrics/public" "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" ) @@ -44,7 +43,7 @@ const ( // NIPostBuilder holds the required state and dependencies to create Non-Interactive Proofs of Space-Time (NIPost). type NIPostBuilder struct { - localDB *localsql.Database + localDB sql.LocalDatabase poetProvers map[string]PoetClient postService postService @@ -73,7 +72,7 @@ func NipostbuilderWithPostStates(ps PostStates) NIPostBuilderOption { // NewNIPostBuilder returns a NIPostBuilder. func NewNIPostBuilder( - db *localsql.Database, + db sql.LocalDatabase, postService postService, lg *zap.Logger, poetCfg PoetConfig, diff --git a/activation/nipost_test.go b/activation/nipost_test.go index 6856045c81..a4d2d56512 100644 --- a/activation/nipost_test.go +++ b/activation/nipost_test.go @@ -52,7 +52,7 @@ type testNIPostBuilder struct { observedLogs *observer.ObservedLogs eventSub <-chan events.UserEvent - mDb *localsql.Database + mDb sql.LocalDatabase mLogger *zap.Logger mPoetDb *MockpoetDbAPI mClock *MocklayerClock diff --git a/activation/poetdb.go b/activation/poetdb.go index bc5671854d..6ce0f691b5 100644 --- a/activation/poetdb.go +++ b/activation/poetdb.go @@ -17,19 +17,18 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/poets" - "github.com/spacemeshos/go-spacemesh/sql/statesql" ) var ErrObjectExists = sql.ErrObjectExists // PoetDb is a database for PoET proofs. type PoetDb struct { - sqlDB *statesql.Database + sqlDB sql.StateDatabase logger *zap.Logger } // NewPoetDb returns a new PoET handler. -func NewPoetDb(db *statesql.Database, log *zap.Logger) *PoetDb { +func NewPoetDb(db sql.StateDatabase, log *zap.Logger) *PoetDb { return &PoetDb{sqlDB: db, logger: log} } diff --git a/api/grpcserver/admin_service.go b/api/grpcserver/admin_service.go index 010dbf8487..afe1e8b1e2 100644 --- a/api/grpcserver/admin_service.go +++ b/api/grpcserver/admin_service.go @@ -22,7 +22,7 @@ import ( "github.com/spacemeshos/go-spacemesh/checkpoint" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" - "github.com/spacemeshos/go-spacemesh/sql/statesql" + "github.com/spacemeshos/go-spacemesh/sql" ) const ( @@ -32,14 +32,14 @@ const ( // AdminService exposes endpoints for node administration. type AdminService struct { - db *statesql.Database + db sql.StateDatabase dataDir string recover func() p peers } // NewAdminService creates a new admin grpc service. -func NewAdminService(db *statesql.Database, dataDir string, p peers) *AdminService { +func NewAdminService(db sql.StateDatabase, dataDir string, p peers) *AdminService { return &AdminService{ db: db, dataDir: dataDir, diff --git a/api/grpcserver/admin_service_test.go b/api/grpcserver/admin_service_test.go index f82c7f7f29..159b5edd78 100644 --- a/api/grpcserver/admin_service_test.go +++ b/api/grpcserver/admin_service_test.go @@ -20,7 +20,7 @@ import ( const snapshot uint32 = 15 -func newAtx(tb testing.TB, db *statesql.Database) { +func newAtx(tb testing.TB, db sql.StateDatabase) { atx := &types.ActivationTx{ PublishEpoch: types.EpochID(2), Sequence: 0, @@ -37,7 +37,7 @@ func newAtx(tb testing.TB, db *statesql.Database) { require.NoError(tb, atxs.Add(db, atx)) } -func createMesh(tb testing.TB, db *statesql.Database) { +func createMesh(tb testing.TB, db sql.StateDatabase) { for range 10 { newAtx(tb, db) } diff --git a/api/grpcserver/debug_service.go b/api/grpcserver/debug_service.go index c9b2637631..43861c561a 100644 --- a/api/grpcserver/debug_service.go +++ b/api/grpcserver/debug_service.go @@ -18,13 +18,13 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" - "github.com/spacemeshos/go-spacemesh/sql/statesql" ) // DebugService exposes global state data, output from the STF. type DebugService struct { - db *statesql.Database + db sql.StateDatabase conState conservativeState netInfo networkInfo oracle oracle @@ -46,7 +46,7 @@ func (d DebugService) String() string { } // NewDebugService creates a new grpc service using config data. -func NewDebugService(db *statesql.Database, conState conservativeState, host networkInfo, oracle oracle, +func NewDebugService(db sql.StateDatabase, conState conservativeState, host networkInfo, oracle oracle, loggers map[string]*zap.AtomicLevel, ) *DebugService { return &DebugService{ diff --git a/api/grpcserver/transaction_service.go b/api/grpcserver/transaction_service.go index f02b2b7983..80f5155c9a 100644 --- a/api/grpcserver/transaction_service.go +++ b/api/grpcserver/transaction_service.go @@ -22,13 +22,13 @@ import ( "github.com/spacemeshos/go-spacemesh/events" "github.com/spacemeshos/go-spacemesh/genvm/core" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" - "github.com/spacemeshos/go-spacemesh/sql/statesql" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) // TransactionService exposes transaction data, and a submit tx endpoint. type TransactionService struct { - db *statesql.Database + db sql.StateDatabase publisher pubsub.Publisher // P2P Swarm mesh meshAPI // Mesh conState conservativeState @@ -52,7 +52,7 @@ func (s TransactionService) String() string { // NewTransactionService creates a new grpc service using config data. func NewTransactionService( - db *statesql.Database, + db sql.StateDatabase, publisher pubsub.Publisher, msh meshAPI, conState conservativeState, diff --git a/api/grpcserver/transaction_service_test.go b/api/grpcserver/transaction_service_test.go index a2274a5bb1..9e97c0d538 100644 --- a/api/grpcserver/transaction_service_test.go +++ b/api/grpcserver/transaction_service_test.go @@ -37,7 +37,7 @@ func TestTransactionService_StreamResults(t *testing.T) { gen := fixture.NewTransactionResultGenerator(). WithAddresses(2) txs := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(ctx, func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(ctx, func(dtx sql.Transaction) error { for i := range txs { tx := gen.Next() diff --git a/api/grpcserver/v2alpha1/layer_test.go b/api/grpcserver/v2alpha1/layer_test.go index 9ae7f277af..2c21967c0d 100644 --- a/api/grpcserver/v2alpha1/layer_test.go +++ b/api/grpcserver/v2alpha1/layer_test.go @@ -16,6 +16,7 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/statesql" @@ -225,7 +226,7 @@ func layerGenWithBlock(withBlock bool) layerGenOpt { } } -func generateLayer(db *statesql.Database, id types.LayerID, opts ...layerGenOpt) (*layers.Layer, error) { +func generateLayer(db sql.StateDatabase, id types.LayerID, opts ...layerGenOpt) (*layers.Layer, error) { g := &layerGenOpts{} for _, opt := range opts { opt(g) diff --git a/api/grpcserver/v2alpha1/transaction_test.go b/api/grpcserver/v2alpha1/transaction_test.go index 2ab968e7ef..07deb08c34 100644 --- a/api/grpcserver/v2alpha1/transaction_test.go +++ b/api/grpcserver/v2alpha1/transaction_test.go @@ -37,7 +37,7 @@ func TestTransactionService_List(t *testing.T) { gen := fixture.NewTransactionResultGenerator().WithAddresses(2) txsList := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(ctx, func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(ctx, func(dtx sql.Transaction) error { for i := range txsList { tx := gen.Next() diff --git a/atxsdata/warmup.go b/atxsdata/warmup.go index a069a2717a..557618dcef 100644 --- a/atxsdata/warmup.go +++ b/atxsdata/warmup.go @@ -8,10 +8,9 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/layers" - "github.com/spacemeshos/go-spacemesh/sql/statesql" ) -func Warm(db *statesql.Database, keep types.EpochID) (*Data, error) { +func Warm(db sql.StateDatabase, keep types.EpochID) (*Data, error) { cache := New() tx, err := db.Tx(context.Background()) if err != nil { diff --git a/blocks/certifier.go b/blocks/certifier.go index db7b4627f6..b3461a98fc 100644 --- a/blocks/certifier.go +++ b/blocks/certifier.go @@ -20,7 +20,6 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/certificates" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -82,7 +81,7 @@ type Certifier struct { stop func() stopped atomic.Bool - db *statesql.Database + db sql.StateDatabase oracle eligibility.Rolacle signers map[types.NodeID]*signing.EdSigner edVerifier *signing.EdVerifier @@ -100,7 +99,7 @@ type Certifier struct { // NewCertifier creates new block certifier. func NewCertifier( - db *statesql.Database, + db sql.StateDatabase, o eligibility.Rolacle, v *signing.EdVerifier, @@ -568,7 +567,7 @@ func (c *Certifier) save( if len(valid)+len(invalid) == 0 { return certificates.Add(c.db, lid, cert) } - return c.db.WithTx(ctx, func(dbtx *sql.Tx) error { + return c.db.WithTx(ctx, func(dbtx sql.Transaction) error { if err := certificates.Add(dbtx, lid, cert); err != nil { return err } diff --git a/blocks/certifier_test.go b/blocks/certifier_test.go index a20365bbf0..d4a0c7faef 100644 --- a/blocks/certifier_test.go +++ b/blocks/certifier_test.go @@ -28,7 +28,7 @@ const defaultCnt = uint16(2) type testCertifier struct { *Certifier - db *statesql.Database + db sql.StateDatabase mOracle *eligibility.MockRolacle mPub *pubsubmock.MockPublisher mClk *mocks.MocklayerClock diff --git a/blocks/generator.go b/blocks/generator.go index 6651923e75..6c63db247f 100644 --- a/blocks/generator.go +++ b/blocks/generator.go @@ -17,8 +17,8 @@ import ( "github.com/spacemeshos/go-spacemesh/hare3/eligibility" "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/proposals/store" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -30,7 +30,7 @@ type Generator struct { eg errgroup.Group stop func() - db *statesql.Database + db sql.StateDatabase atxs *atxsdata.Data proposals *store.Store msh meshProvider @@ -84,7 +84,7 @@ func WithHareOutputChan(ch <-chan hare3.ConsensusOutput) GeneratorOpt { // NewGenerator creates new block generator. func NewGenerator( - db *statesql.Database, + db sql.StateDatabase, atxs *atxsdata.Data, proposals *store.Store, exec executor, diff --git a/blocks/generator_test.go b/blocks/generator_test.go index 0f967c08f8..1375dcf269 100644 --- a/blocks/generator_test.go +++ b/blocks/generator_test.go @@ -266,7 +266,7 @@ func Test_StopBeforeStart(t *testing.T) { func genData( t *testing.T, - db *statesql.Database, + db sql.StateDatabase, data *atxsdata.Data, store *store.Store, lid types.LayerID, diff --git a/blocks/handler.go b/blocks/handler.go index 2e749d389b..4ff6f9b30e 100644 --- a/blocks/handler.go +++ b/blocks/handler.go @@ -12,8 +12,8 @@ import ( "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/blocks" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -28,7 +28,7 @@ type Handler struct { logger *zap.Logger fetcher system.Fetcher - db *statesql.Database + db sql.StateDatabase tortoise tortoiseProvider mesh meshProvider } @@ -46,7 +46,7 @@ func WithLogger(logger *zap.Logger) Opt { // NewHandler creates new Handler. func NewHandler( f system.Fetcher, - db *statesql.Database, + db sql.StateDatabase, tortoise tortoiseProvider, m meshProvider, opts ...Opt, diff --git a/blocks/utils.go b/blocks/utils.go index 88de439ab3..4dfb25d1a1 100644 --- a/blocks/utils.go +++ b/blocks/utils.go @@ -18,9 +18,9 @@ import ( "github.com/spacemeshos/go-spacemesh/atxsdata" "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/layers" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/txs" ) @@ -50,7 +50,7 @@ type proposalMetadata struct { func getProposalMetadata( ctx context.Context, logger *zap.Logger, - db *statesql.Database, + db sql.StateDatabase, atxs *atxsdata.Data, cfg Config, lid types.LayerID, @@ -232,7 +232,7 @@ func toUint64Slice(b []byte) []uint64 { func rewardInfoAndHeight( cfg Config, - db *statesql.Database, + db sql.StateDatabase, atxs *atxsdata.Data, props []*types.Proposal, ) (uint64, []types.AnyReward, error) { diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 4a5d5923b7..36a050e39f 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -125,7 +125,7 @@ func Recover( } defer localDB.Close() logger.With().Info("clearing atx and malfeasance sync metadata from local database") - if err := localDB.WithTx(ctx, func(tx *sql.Tx) error { + if err := localDB.WithTx(ctx, func(tx sql.Transaction) error { if err := atxsync.Clear(tx); err != nil { return err } @@ -149,8 +149,8 @@ func Recover( func RecoverWithDb( ctx context.Context, logger log.Log, - db *statesql.Database, - localDB *localsql.Database, + db sql.StateDatabase, + localDB sql.LocalDatabase, fs afero.Fs, cfg *RecoverConfig, ) (*PreservedData, error) { @@ -181,8 +181,8 @@ type recoveryData struct { func recoverFromLocalFile( ctx context.Context, logger log.Log, - db *statesql.Database, - localDB *localsql.Database, + db sql.StateDatabase, + localDB sql.LocalDatabase, fs afero.Fs, cfg *RecoverConfig, file string, @@ -268,7 +268,7 @@ func recoverFromLocalFile( log.Int("num accounts", len(data.accounts)), log.Int("num atxs", len(data.atxs)), ) - if err = newDB.WithTx(ctx, func(tx *sql.Tx) error { + if err = newDB.WithTx(ctx, func(tx sql.Transaction) error { for _, acct := range data.accounts { if err = accounts.Update(tx, acct); err != nil { return fmt.Errorf("restore account snapshot: %w", err) @@ -367,8 +367,8 @@ func checkpointData(fs afero.Fs, file string, newGenesis types.LayerID) (*recove func collectOwnAtxDeps( logger log.Log, - db *statesql.Database, - localDB *localsql.Database, + db sql.StateDatabase, + localDB sql.LocalDatabase, nodeID types.NodeID, goldenATX types.ATXID, data *recoveryData, @@ -435,7 +435,7 @@ func collectOwnAtxDeps( } func collectDeps( - db *statesql.Database, + db sql.StateDatabase, ref types.ATXID, all map[types.ATXID]struct{}, ) (map[types.ATXID]*AtxDep, map[types.PoetProofRef]*types.PoetProofMessage, error) { @@ -451,7 +451,7 @@ func collectDeps( } func collect( - db *statesql.Database, + db sql.StateDatabase, ref types.ATXID, all map[types.ATXID]struct{}, deps map[types.ATXID]*AtxDep, @@ -506,7 +506,7 @@ func collect( } func poetProofs( - db *statesql.Database, + db sql.StateDatabase, atxIds map[types.ATXID]*AtxDep, ) (map[types.PoetProofRef]*types.PoetProofMessage, error) { proofs := make(map[types.PoetProofRef]*types.PoetProofMessage, len(atxIds)) diff --git a/checkpoint/recovery_test.go b/checkpoint/recovery_test.go index 02cf6d2624..435c30f48f 100644 --- a/checkpoint/recovery_test.go +++ b/checkpoint/recovery_test.go @@ -29,6 +29,7 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql" @@ -76,7 +77,7 @@ func accountEqual(tb testing.TB, cacct types.AccountSnapshot, acct *types.Accoun } } -func verifyDbContent(tb testing.TB, db *statesql.Database) { +func verifyDbContent(tb testing.TB, db sql.StateDatabase) { var expected types.Checkpoint require.NoError(tb, json.Unmarshal([]byte(checkpointData), &expected)) expAtx := map[types.ATXID]types.AtxSnapshot{} @@ -227,7 +228,7 @@ func TestRecover_SameRecoveryInfo(t *testing.T) { func validateAndPreserveData( tb testing.TB, - db *statesql.Database, + db sql.StateDatabase, deps []*checkpoint.AtxDep, ) { lg := zaptest.NewLogger(tb) diff --git a/checkpoint/runner.go b/checkpoint/runner.go index f40cfe391a..9741992775 100644 --- a/checkpoint/runner.go +++ b/checkpoint/runner.go @@ -10,10 +10,10 @@ import ( "github.com/spf13/afero" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" - "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -28,7 +28,7 @@ const ( func checkpointDB( ctx context.Context, - db *statesql.Database, + db sql.StateDatabase, snapshot types.LayerID, numAtxs int, ) (*types.Checkpoint, error) { @@ -115,7 +115,7 @@ func checkpointDB( func Generate( ctx context.Context, fs afero.Fs, - db *statesql.Database, + db sql.StateDatabase, dataDir string, snapshot types.LayerID, numAtxs int, diff --git a/checkpoint/runner_test.go b/checkpoint/runner_test.go index 7661f79f4b..bf18c1e583 100644 --- a/checkpoint/runner_test.go +++ b/checkpoint/runner_test.go @@ -18,6 +18,7 @@ import ( "github.com/spacemeshos/go-spacemesh/checkpoint" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/identities" @@ -233,7 +234,7 @@ func asAtxSnapshot(v *types.ActivationTx, cmt *types.ATXID) types.AtxSnapshot { func createMesh( t *testing.T, - db *statesql.Database, + db sql.StateDatabase, miners map[types.NodeID][]*types.ActivationTx, accts []*types.Account, ) { diff --git a/cmd/merge-nodes/internal/merge_action.go b/cmd/merge-nodes/internal/merge_action.go index 241c5bb0a9..1ebc90b00e 100644 --- a/cmd/merge-nodes/internal/merge_action.go +++ b/cmd/merge-nodes/internal/merge_action.go @@ -26,7 +26,7 @@ const ( func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { // Open the target database - var dstDB *localsql.Database + var dstDB sql.LocalDatabase var err error dstDB, err = openDB(dbLog, to) switch { @@ -150,7 +150,7 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { } dbLog.Info("merging databases", zap.String("from", from), zap.String("to", to)) - err = dstDB.WithTx(ctx, func(tx *sql.Tx) error { + err = dstDB.WithTx(ctx, func(tx sql.Transaction) error { enc := func(stmt *sql.Statement) { stmt.BindText(1, filepath.Join(from, localDbFile)) } @@ -183,7 +183,7 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { return nil } -func openDB(dbLog *zap.Logger, path string) (*localsql.Database, error) { +func openDB(dbLog *zap.Logger, path string) (sql.LocalDatabase, error) { dbPath := filepath.Join(path, localDbFile) if _, err := os.Stat(dbPath); err != nil { return nil, fmt.Errorf("stat source database %s: %w", dbPath, err) diff --git a/datastore/mocks/mocks.go b/datastore/mocks/mocks.go index bbb150d389..fc3ab85e8e 100644 --- a/datastore/mocks/mocks.go +++ b/datastore/mocks/mocks.go @@ -8,149 +8,3 @@ // Package mocks is a generated GoMock package. package mocks - -import ( - context "context" - reflect "reflect" - - sql "github.com/spacemeshos/go-spacemesh/sql" - gomock "go.uber.org/mock/gomock" -) - -// MockExecutor is a mock of Executor interface. -type MockExecutor struct { - ctrl *gomock.Controller - recorder *MockExecutorMockRecorder -} - -// MockExecutorMockRecorder is the mock recorder for MockExecutor. -type MockExecutorMockRecorder struct { - mock *MockExecutor -} - -// NewMockExecutor creates a new mock instance. -func NewMockExecutor(ctrl *gomock.Controller) *MockExecutor { - mock := &MockExecutor{ctrl: ctrl} - mock.recorder = &MockExecutorMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockExecutor) EXPECT() *MockExecutorMockRecorder { - return m.recorder -} - -// Exec mocks base method. -func (m *MockExecutor) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec. -func (mr *MockExecutorMockRecorder) Exec(arg0, arg1, arg2 any) *MockExecutorExecCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockExecutor)(nil).Exec), arg0, arg1, arg2) - return &MockExecutorExecCall{Call: call} -} - -// MockExecutorExecCall wrap *gomock.Call -type MockExecutorExecCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockExecutorExecCall) Return(arg0 int, arg1 error) *MockExecutorExecCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockExecutorExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockExecutorExecCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockExecutorExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockExecutorExecCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// QueryCache mocks base method. -func (m *MockExecutor) QueryCache() sql.QueryCache { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryCache") - ret0, _ := ret[0].(sql.QueryCache) - return ret0 -} - -// QueryCache indicates an expected call of QueryCache. -func (mr *MockExecutorMockRecorder) QueryCache() *MockExecutorQueryCacheCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCache", reflect.TypeOf((*MockExecutor)(nil).QueryCache)) - return &MockExecutorQueryCacheCall{Call: call} -} - -// MockExecutorQueryCacheCall wrap *gomock.Call -type MockExecutorQueryCacheCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockExecutorQueryCacheCall) Return(arg0 sql.QueryCache) *MockExecutorQueryCacheCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockExecutorQueryCacheCall) Do(f func() sql.QueryCache) *MockExecutorQueryCacheCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockExecutorQueryCacheCall) DoAndReturn(f func() sql.QueryCache) *MockExecutorQueryCacheCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// WithTx mocks base method. -func (m *MockExecutor) WithTx(arg0 context.Context, arg1 func(*sql.Tx) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTx", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// WithTx indicates an expected call of WithTx. -func (mr *MockExecutorMockRecorder) WithTx(arg0, arg1 any) *MockExecutorWithTxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTx", reflect.TypeOf((*MockExecutor)(nil).WithTx), arg0, arg1) - return &MockExecutorWithTxCall{Call: call} -} - -// MockExecutorWithTxCall wrap *gomock.Call -type MockExecutorWithTxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockExecutorWithTxCall) Return(arg0 error) *MockExecutorWithTxCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockExecutorWithTxCall) Do(f func(context.Context, func(*sql.Tx) error) error) *MockExecutorWithTxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockExecutorWithTxCall) DoAndReturn(f func(context.Context, func(*sql.Tx) error) error) *MockExecutorWithTxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} diff --git a/datastore/store.go b/datastore/store.go index 1d5180bfc5..8b1d37a584 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -32,15 +32,9 @@ type VrfNonceKey struct { //go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./store.go -type Executor interface { - sql.Executor - WithTx(context.Context, func(*sql.Tx) error) error - QueryCache() sql.QueryCache -} - // CachedDB is simply a database injected with cache. type CachedDB struct { - Executor + sql.Database sql.QueryCache logger *zap.Logger @@ -91,7 +85,7 @@ func WithConsensusCache(c *atxsdata.Data) Opt { } // NewCachedDB create an instance of a CachedDB. -func NewCachedDB(db Executor, lg *zap.Logger, opts ...Opt) *CachedDB { +func NewCachedDB(db sql.StateDatabase, lg *zap.Logger, opts ...Opt) *CachedDB { o := cacheOpts{cfg: DefaultConfig()} for _, opt := range opts { opt(&o) @@ -114,7 +108,7 @@ func NewCachedDB(db Executor, lg *zap.Logger, opts ...Opt) *CachedDB { } return &CachedDB{ - Executor: db, + Database: db, QueryCache: db.QueryCache(), logger: lg, atxsdata: o.atxsdata, @@ -169,7 +163,7 @@ func (db *CachedDB) GetMalfeasanceProof(id types.NodeID) (*wire.MalfeasanceProof return proof, nil } - proof, err := identities.GetMalfeasanceProof(db.Executor, id) + proof, err := identities.GetMalfeasanceProof(db.Database, id) if err != nil && err != sql.ErrNotFound { return nil, err } diff --git a/fetch/handler_test.go b/fetch/handler_test.go index 1fcd79e174..7cfe001d2a 100644 --- a/fetch/handler_test.go +++ b/fetch/handler_test.go @@ -29,7 +29,7 @@ import ( type testHandler struct { *handler - db *statesql.Database + db sql.StateDatabase cdb *datastore.CachedDB } @@ -360,7 +360,7 @@ func testHandleEpochInfoReqWithQueryCache( expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) } - qc := th.cdb.Executor.(interface{ QueryCount() int }) + qc := th.cdb.Database.(interface{ QueryCount() int }) require.Equal(t, 20, qc.QueryCount()) epochBytes, err := codec.Encode(epoch) require.NoError(t, err) diff --git a/fetch/p2p_test.go b/fetch/p2p_test.go index 9ab1a16d3a..700a632de4 100644 --- a/fetch/p2p_test.go +++ b/fetch/p2p_test.go @@ -38,13 +38,13 @@ type blobKey struct { type testP2PFetch struct { t *testing.T - clientDB *statesql.Database + clientDB sql.StateDatabase // client proposals clientPDB *store.Store clientCDB *datastore.CachedDB clientFetch *Fetch serverID peer.ID - serverDB *statesql.Database + serverDB sql.StateDatabase // server proposals serverPDB *store.Store serverCDB *datastore.CachedDB diff --git a/genvm/vm.go b/genvm/vm.go index 24a156f7d7..538dd54e45 100644 --- a/genvm/vm.go +++ b/genvm/vm.go @@ -23,7 +23,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/accounts" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/rewards" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/system" ) @@ -59,7 +58,7 @@ func WithConfig(cfg Config) Opt { } // New returns VM instance. -func New(db *statesql.Database, opts ...Opt) *VM { +func New(db sql.StateDatabase, opts ...Opt) *VM { vm := &VM{ logger: log.NewNop(), db: db, @@ -79,7 +78,7 @@ func New(db *statesql.Database, opts ...Opt) *VM { // VM handles modifications to the account state. type VM struct { logger log.Log - db *statesql.Database + db sql.StateDatabase cfg Config registry *registry.Registry } diff --git a/hare3/eligibility/oracle_test.go b/hare3/eligibility/oracle_test.go index 39b2ba4cf5..103d647be7 100644 --- a/hare3/eligibility/oracle_test.go +++ b/hare3/eligibility/oracle_test.go @@ -47,7 +47,7 @@ func TestMain(m *testing.M) { type testOracle struct { *Oracle tb testing.TB - db *statesql.Database + db sql.StateDatabase atxsdata *atxsdata.Data mBeacon *mocks.MockBeaconGetter mVerifier *MockvrfVerifier diff --git a/hare3/hare.go b/hare3/hare.go index c5f2bb3780..0f48d49255 100644 --- a/hare3/hare.go +++ b/hare3/hare.go @@ -28,7 +28,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/beacons" "github.com/spacemeshos/go-spacemesh/sql/identities" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -147,7 +146,7 @@ type nodeclock interface { func New( nodeclock nodeclock, pubsub pubsub.PublishSubsciber, - db *statesql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, proposals *store.Store, verifier *signing.EdVerifier, @@ -209,7 +208,7 @@ type Hare struct { // dependencies nodeclock nodeclock pubsub pubsub.PublishSubsciber - db *statesql.Database + db sql.StateDatabase atxsdata *atxsdata.Data proposals *store.Store verifier *signing.EdVerifier diff --git a/hare3/hare_test.go b/hare3/hare_test.go index 8221f0d584..2a75a1482b 100644 --- a/hare3/hare_test.go +++ b/hare3/hare_test.go @@ -24,6 +24,7 @@ import ( pmocks "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/proposals/store" "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/ballots" "github.com/spacemeshos/go-spacemesh/sql/beacons" @@ -114,7 +115,7 @@ type node struct { vrfsigner *signing.VRFSigner atx *types.ActivationTx oracle *eligibility.Oracle - db *statesql.Database + db sql.StateDatabase atxsdata *atxsdata.Data proposals *store.Store diff --git a/hare3/malfeasance_test.go b/hare3/malfeasance_test.go index e3d1547b7f..73d286ff3b 100644 --- a/hare3/malfeasance_test.go +++ b/hare3/malfeasance_test.go @@ -16,17 +16,18 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/malfeasance/handler.go b/malfeasance/handler.go index ff6590ecbb..15e4897253 100644 --- a/malfeasance/handler.go +++ b/malfeasance/handler.go @@ -167,7 +167,7 @@ func (h *Handler) validateAndSave(ctx context.Context, p *wire.MalfeasanceGossip h.countInvalidProof(&p.MalfeasanceProof) return types.EmptyNodeID, err } - if err := h.cdb.WithTx(ctx, func(dbtx *sql.Tx) error { + if err := h.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { malicious, err := identities.IsMalicious(dbtx, nodeID) if err != nil { return fmt.Errorf("check known malicious: %w", err) diff --git a/malfeasance/handler_test.go b/malfeasance/handler_test.go index 4aa2399cd1..e902712687 100644 --- a/malfeasance/handler_test.go +++ b/malfeasance/handler_test.go @@ -19,6 +19,7 @@ import ( "github.com/spacemeshos/go-spacemesh/datastore" "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) @@ -27,7 +28,7 @@ type testMalfeasanceHandler struct { *Handler observedLogs *observer.ObservedLogs - db *statesql.Database + db sql.StateDatabase mockTrt *Mocktortoise } diff --git a/mesh/executor_test.go b/mesh/executor_test.go index 790066d2b6..037891045a 100644 --- a/mesh/executor_test.go +++ b/mesh/executor_test.go @@ -34,7 +34,7 @@ func TestMain(m *testing.M) { type testExecutor struct { tb testing.TB exec *mesh.Executor - db *statesql.Database + db sql.StateDatabase atxsdata *atxsdata.Data mcs *mocks.MockconservativeState mvm *mocks.MockvmState diff --git a/mesh/malfeasance_test.go b/mesh/malfeasance_test.go index 1d673da5e1..535310abda 100644 --- a/mesh/malfeasance_test.go +++ b/mesh/malfeasance_test.go @@ -16,17 +16,18 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) type testMalfeasanceHandler struct { *MalfeasanceHandler observedLogs *observer.ObservedLogs - db *sql.Database + db sql.StateDatabase } func newTestMalfeasanceHandler(tb testing.TB) *testMalfeasanceHandler { - db := sql.InMemory() + db := statesql.InMemory() observer, observedLogs := observer.New(zapcore.WarnLevel) logger := zaptest.NewLogger(tb, zaptest.WrapOptions(zap.WrapCore( func(core zapcore.Core) zapcore.Core { diff --git a/mesh/mesh.go b/mesh/mesh.go index aeb460673b..94b8b2dc8c 100644 --- a/mesh/mesh.go +++ b/mesh/mesh.go @@ -29,14 +29,13 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/rewards" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) // Mesh is the logic layer above our mesh.DB database. type Mesh struct { logger log.Log - cdb *statesql.Database + cdb sql.StateDatabase atxsdata *atxsdata.Data clock layerClock @@ -59,7 +58,7 @@ type Mesh struct { // NewMesh creates a new instant of a mesh. func NewMesh( - db *statesql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, c layerClock, trtl system.Tortoise, @@ -93,7 +92,7 @@ func NewMesh( } genesis := types.GetEffectiveGenesis() - if err = db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + if err = db.WithTx(context.Background(), func(dbtx sql.Transaction) error { if err = layers.SetProcessed(dbtx, genesis); err != nil { return fmt.Errorf("mesh init: %w", err) } @@ -371,7 +370,7 @@ func (msh *Mesh) applyResults(ctx context.Context, results []result.Layer) error return fmt.Errorf("execute block %v/%v: %w", layer.Layer, target, err) } } - if err := msh.cdb.WithTx(ctx, func(dbtx *sql.Tx) error { + if err := msh.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { if err := layers.SetApplied(dbtx, layer.Layer, target); err != nil { return fmt.Errorf("set applied for %v/%v: %w", layer.Layer, target, err) } @@ -421,7 +420,7 @@ func (msh *Mesh) saveHareOutput(ctx context.Context, lid types.LayerID, bid type certs []certificates.CertValidity err error ) - if err = msh.cdb.WithTx(ctx, func(tx *sql.Tx) error { + if err = msh.cdb.WithTx(ctx, func(tx sql.Transaction) error { // check if a certificate has been generated or sync'ed. // - node generated the certificate when it collected enough certify messages // - hare outputs are processed in layer order. i.e. when hare fails for a previous layer N, @@ -543,7 +542,7 @@ func (msh *Mesh) AddBallot( var proof *wire.MalfeasanceProof // ballots.LayerBallotByNodeID and ballots.Add should be atomic // otherwise concurrent ballots.Add from the same smesher may not be noticed - if err := msh.cdb.WithTx(ctx, func(dbtx *sql.Tx) error { + if err := msh.cdb.WithTx(ctx, func(dbtx sql.Transaction) error { if !malicious { prev, err := ballots.LayerBallotByNodeID(dbtx, ballot.Layer, ballot.SmesherID) if err != nil && !errors.Is(err, sql.ErrNotFound) { diff --git a/mesh/mesh_test.go b/mesh/mesh_test.go index ec2aa9a6dc..795564a4aa 100644 --- a/mesh/mesh_test.go +++ b/mesh/mesh_test.go @@ -37,7 +37,7 @@ const ( type testMesh struct { *Mesh - db *statesql.Database + db sql.StateDatabase // it is used in malfeasence.Validate, which is called in the tests cdb *datastore.CachedDB atxsdata *atxsdata.Data diff --git a/miner/active_set_generator_test.go b/miner/active_set_generator_test.go index 6b4319a60e..8ae8021602 100644 --- a/miner/active_set_generator_test.go +++ b/miner/active_set_generator_test.go @@ -98,8 +98,8 @@ type testerActiveSetGenerator struct { tb testing.TB gen *activeSetGenerator - db *statesql.Database - localdb *localsql.Database + db sql.StateDatabase + localdb sql.LocalDatabase atxsdata *atxsdata.Data ctrl *gomock.Controller clock *mocks.MocklayerClock diff --git a/node/node.go b/node/node.go index c7a9da2564..fed10b3398 100644 --- a/node/node.go +++ b/node/node.go @@ -385,10 +385,10 @@ type App struct { fileLock *flock.Flock signers []*signing.EdSigner Config *config.Config - db *statesql.Database + db sql.StateDatabase cachedDB *datastore.CachedDB dbMetrics *dbmetrics.DBMetricsCollector - localDB *localsql.Database + localDB sql.LocalDatabase grpcPublicServer *grpcserver.Server grpcPrivateServer *grpcserver.Server grpcPostServer *grpcserver.Server diff --git a/proposals/handler.go b/proposals/handler.go index 5344465b35..adcb14f27b 100644 --- a/proposals/handler.go +++ b/proposals/handler.go @@ -26,7 +26,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/ballots" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" "github.com/spacemeshos/go-spacemesh/tortoise" ) @@ -50,7 +49,7 @@ type Handler struct { logger log.Log cfg Config - db *statesql.Database + db sql.StateDatabase atxsdata *atxsdata.Data activeSets *lru.Cache[types.Hash32, uint64] edVerifier *signing.EdVerifier @@ -109,7 +108,7 @@ func WithConfig(cfg Config) Opt { // NewHandler creates new Handler. func NewHandler( - db *statesql.Database, + db sql.StateDatabase, atxsdata *atxsdata.Data, proposals proposalsConsumer, edVerifier *signing.EdVerifier, diff --git a/proposals/handler_test.go b/proposals/handler_test.go index 0023ae113d..24c15e02d9 100644 --- a/proposals/handler_test.go +++ b/proposals/handler_test.go @@ -237,7 +237,7 @@ func createProposal(t *testing.T, opts ...any) *types.Proposal { return p } -func createAtx(t *testing.T, db *statesql.Database, epoch types.EpochID, atxID types.ATXID, nodeID types.NodeID) { +func createAtx(t *testing.T, db sql.StateDatabase, epoch types.EpochID, atxID types.ATXID, nodeID types.NodeID) { atx := &types.ActivationTx{ PublishEpoch: epoch, NumUnits: 1, diff --git a/prune/prune.go b/prune/prune.go index c5b4551ea9..2fe9ef20c4 100644 --- a/prune/prune.go +++ b/prune/prune.go @@ -7,9 +7,9 @@ import ( "go.uber.org/zap" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/activesets" "github.com/spacemeshos/go-spacemesh/sql/certificates" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -22,7 +22,7 @@ func WithLogger(logger *zap.Logger) Opt { } } -func New(db *statesql.Database, safeDist uint32, activesetEpoch types.EpochID, opts ...Opt) *Pruner { +func New(db sql.StateDatabase, safeDist uint32, activesetEpoch types.EpochID, opts ...Opt) *Pruner { p := &Pruner{ logger: zap.NewNop(), db: db, @@ -37,7 +37,7 @@ func New(db *statesql.Database, safeDist uint32, activesetEpoch types.EpochID, o type Pruner struct { logger *zap.Logger - db *statesql.Database + db sql.StateDatabase safeDist uint32 activesetEpoch types.EpochID } diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 1fa6c3b563..0d26be8b97 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -443,7 +443,7 @@ func TestGetIDsByEpochCached(t *testing.T) { require.Equal(t, 11, db.QueryCount()) } - require.NoError(t, db.WithTx(context.Background(), func(tx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(tx sql.Transaction) error { atxs.Add(tx, atx5) return nil })) @@ -455,7 +455,7 @@ func TestGetIDsByEpochCached(t *testing.T) { require.ElementsMatch(t, []types.ATXID{atx4.ID(), atx5.ID()}, ids3) require.Equal(t, 13, db.QueryCount()) // not incremented after Add - require.Error(t, db.WithTx(context.Background(), func(tx *sql.Tx) error { + require.Error(t, db.WithTx(context.Background(), func(tx sql.Transaction) error { atxs.Add(tx, atx6) return errors.New("fail") // rollback })) @@ -848,7 +848,7 @@ type header struct { filteredOut bool } -func createAtx(tb testing.TB, db *statesql.Database, hdr header) (types.ATXID, *signing.EdSigner) { +func createAtx(tb testing.TB, db sql.StateDatabase, hdr header) (types.ATXID, *signing.EdSigner) { sig, err := signing.NewEdSigner() require.NoError(tb, err) diff --git a/sql/database.go b/sql/database.go index 78044cd159..9cd979f606 100644 --- a/sql/database.go +++ b/sql/database.go @@ -179,14 +179,14 @@ func withForceFresh() Opt { type Opt func(c *conf) // OpenInMemory creates an in-memory database. -func OpenInMemory(opts ...Opt) (*Database, error) { +func OpenInMemory(opts ...Opt) (*sqliteDatabase, error) { opts = append(opts, WithConnections(1), withForceFresh()) return Open("file::memory:?mode=memory", opts...) } // InMemory creates an in-memory database for testing and panics if // there's an error. -func InMemory(opts ...Opt) *Database { +func InMemory(opts ...Opt) *sqliteDatabase { db, err := OpenInMemory(opts...) if err != nil { panic(err) @@ -199,7 +199,7 @@ func InMemory(opts ...Opt) *Database { // Database is opened in WAL mode and pragma synchronous=normal. // https://sqlite.org/wal.html // https://www.sqlite.org/pragma.html#pragma_synchronous -func Open(uri string, opts ...Opt) (*Database, error) { +func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { config := defaultConf() for _, opt := range opts { opt(config) @@ -224,7 +224,7 @@ func Open(uri string, opts ...Opt) (*Database, error) { return nil, fmt.Errorf("create db %s: %w", uri, err) } } - db := &Database{pool: pool} + db := &sqliteDatabase{pool: pool} if config.enableLatency { db.latency = newQueryLatency() } @@ -276,13 +276,31 @@ func Version(uri string) (int, error) { if err != nil { return 0, fmt.Errorf("open db %s: %w", uri, err) } - db := &Database{pool: pool} + db := &sqliteDatabase{pool: pool} defer db.Close() return version(db) } -// Database is an instance of sqlite database. -type Database struct { +// Database represents a database. +type Database interface { + Executor + Close() error + QueryCount() int + QueryCache() QueryCache + Tx(ctx context.Context) (Transaction, error) + WithTx(ctx context.Context, exec func(Transaction) error) error + TxImmediate(ctx context.Context) (Transaction, error) + WithTxImmediate(ctx context.Context, exec func(Transaction) error) error +} + +// Transaction represents a transaction. +type Transaction interface { + Executor + Commit() error + Release() error +} + +type sqliteDatabase struct { *queryCache pool *sqlitex.Pool @@ -293,7 +311,9 @@ type Database struct { queryCount atomic.Int64 } -func (db *Database) getConn(ctx context.Context) *sqlite.Conn { +var _ Database = &sqliteDatabase{} + +func (db *sqliteDatabase) getConn(ctx context.Context) *sqlite.Conn { start := time.Now() conn := db.pool.Get(ctx) if conn != nil { @@ -302,19 +322,19 @@ func (db *Database) getConn(ctx context.Context) *sqlite.Conn { return conn } -func (db *Database) getTx(ctx context.Context, initstmt string) (*Tx, error) { +func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx, error) { conn := db.getConn(ctx) if conn == nil { return nil, ErrNoConnection } - tx := &Tx{queryCache: db.queryCache, db: db, conn: conn} + tx := &sqliteTx{queryCache: db.queryCache, db: db, conn: conn} if err := tx.begin(initstmt); err != nil { return nil, err } return tx, nil } -func (db *Database) withTx(ctx context.Context, initstmt string, exec func(*Tx) error) error { +func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func(Transaction) error) error { tx, err := db.getTx(ctx, initstmt) if err != nil { return err @@ -334,13 +354,13 @@ func (db *Database) withTx(ctx context.Context, initstmt string, exec func(*Tx) // after one of the write statements. // // https://www.sqlite.org/lang_transaction.html -func (db *Database) Tx(ctx context.Context) (*Tx, error) { +func (db *sqliteDatabase) Tx(ctx context.Context) (Transaction, error) { return db.getTx(ctx, beginDefault) } // WithTx will pass initialized deferred transaction to exec callback. // Will commit only if error is nil. -func (db *Database) WithTx(ctx context.Context, exec func(*Tx) error) error { +func (db *sqliteDatabase) WithTx(ctx context.Context, exec func(Transaction) error) error { return db.withTx(ctx, beginImmediate, exec) } @@ -349,13 +369,16 @@ func (db *Database) WithTx(ctx context.Context, exec func(*Tx) error) error { // IMMEDIATE cause the database connection to start a new write immediately, without waiting // for a write statement. The BEGIN IMMEDIATE might fail with SQLITE_BUSY if another write // transaction is already active on another database connection. -func (db *Database) TxImmediate(ctx context.Context) (*Tx, error) { +func (db *sqliteDatabase) TxImmediate(ctx context.Context) (Transaction, error) { return db.getTx(ctx, beginImmediate) } // WithTxImmediate will pass initialized immediate transaction to exec callback. // Will commit only if error is nil. -func (db *Database) WithTxImmediate(ctx context.Context, exec func(*Tx) error) error { +func (db *sqliteDatabase) WithTxImmediate( + ctx context.Context, + exec func(Transaction) error, +) error { return db.withTx(ctx, beginImmediate, exec) } @@ -367,7 +390,7 @@ func (db *Database) WithTxImmediate(ctx context.Context, exec func(*Tx) error) e // // Note that Exec will block until database is closed or statement has finished. // If application needs to control statement execution lifetime use one of the transaction. -func (db *Database) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { +func (db *sqliteDatabase) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { db.queryCount.Add(1) conn := db.getConn(context.Background()) if conn == nil { @@ -384,7 +407,7 @@ func (db *Database) Exec(query string, encoder Encoder, decoder Decoder) (int, e } // Close closes all pooled connections. -func (db *Database) Close() error { +func (db *sqliteDatabase) Close() error { db.closeMux.Lock() defer db.closeMux.Unlock() if db.closed { @@ -399,12 +422,12 @@ func (db *Database) Close() error { // QueryCount returns the number of queries executed, including failed // queries, but not counting transaction start / commit / rollback. -func (db *Database) QueryCount() int { +func (db *sqliteDatabase) QueryCount() int { return int(db.queryCount.Load()) } // Return database's QueryCache. -func (db *Database) QueryCache() QueryCache { +func (db *sqliteDatabase) QueryCache() QueryCache { return db.queryCache } @@ -445,16 +468,16 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in } } -// Tx is wrapper for database transaction. -type Tx struct { +// sqliteTx is wrapper for database transaction. +type sqliteTx struct { *queryCache - db *Database + db *sqliteDatabase conn *sqlite.Conn committed bool err error } -func (tx *Tx) begin(initstmt string) error { +func (tx *sqliteTx) begin(initstmt string) error { stmt := tx.conn.Prep(initstmt) _, err := stmt.Step() if err != nil { @@ -464,7 +487,7 @@ func (tx *Tx) begin(initstmt string) error { } // Commit transaction. -func (tx *Tx) Commit() error { +func (tx *sqliteTx) Commit() error { stmt := tx.conn.Prep("COMMIT;") _, tx.err = stmt.Step() if tx.err != nil { @@ -475,7 +498,7 @@ func (tx *Tx) Commit() error { } // Release transaction. Every transaction that was created must be released. -func (tx *Tx) Release() error { +func (tx *sqliteTx) Release() error { defer tx.db.pool.Put(tx.conn) if tx.committed { return nil @@ -486,7 +509,7 @@ func (tx *Tx) Release() error { } // Exec query. -func (tx *Tx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { +func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { tx.db.queryCount.Add(1) if tx.db.latency != nil { start := time.Now() @@ -581,3 +604,15 @@ func LoadBlob(db Executor, cmd string, id []byte, blob *Blob) error { func IsNull(stmt *Statement, col int) bool { return stmt.ColumnType(col) == sqlite.SQLITE_NULL } + +// StateDatabase is a Database used for Spacemesh state. +type StateDatabase interface { + Database + IsStateDatabase() bool +} + +// LocalDatabase is a Database used for local node data. +type LocalDatabase interface { + Database + IsLocalDatabase() bool +} diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go index 5084a996bf..17cbc70765 100644 --- a/sql/localsql/localsql.go +++ b/sql/localsql/localsql.go @@ -14,11 +14,14 @@ var schemaScript string //go:embed schema/migrations/*.sql var migrations embed.FS -// Database represents a local database. -type Database struct { - *sql.Database +type database struct { + sql.Database } +var _ sql.LocalDatabase = &database{} + +func (d *database) IsLocalDatabase() bool { return true } + // Schema returns the schema for the local database. func Schema() (*sql.Schema, error) { sqlMigrations, err := sql.LoadSQLMigrations(migrations) @@ -31,7 +34,7 @@ func Schema() (*sql.Schema, error) { } // Open opens a local database. -func Open(uri string, opts ...sql.Opt) (*Database, error) { +func Open(uri string, opts ...sql.Opt) (*database, error) { schema, err := Schema() if err != nil { return nil, err @@ -45,11 +48,11 @@ func Open(uri string, opts ...sql.Opt) (*Database, error) { if err != nil { return nil, err } - return &Database{Database: db}, nil + return &database{Database: db}, nil } // Open opens an in-memory local database. -func InMemory(opts ...sql.Opt) *Database { +func InMemory(opts ...sql.Opt) *database { schema, err := Schema() if err != nil { panic(err) @@ -60,5 +63,5 @@ func InMemory(opts ...sql.Opt) *Database { } opts = append(defaultOpts, opts...) db := sql.InMemory(opts...) - return &Database{Database: db} + return &database{Database: db} } diff --git a/sql/metrics/prometheus.go b/sql/metrics/prometheus.go index acab0b4532..dcaa3206ff 100644 --- a/sql/metrics/prometheus.go +++ b/sql/metrics/prometheus.go @@ -11,7 +11,6 @@ import ( "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/metrics" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/statesql" ) const ( @@ -23,7 +22,7 @@ const ( type DBMetricsCollector struct { logger log.Logger checkInterval time.Duration - db *statesql.Database + db sql.StateDatabase tablesList map[string]struct{} eg errgroup.Group cancel context.CancelFunc @@ -36,7 +35,7 @@ type DBMetricsCollector struct { // NewDBMetricsCollector creates new DBMetricsCollector. func NewDBMetricsCollector( ctx context.Context, - db *statesql.Database, + db sql.StateDatabase, logger log.Logger, checkInterval time.Duration, ) *DBMetricsCollector { diff --git a/sql/mocks/mocks.go b/sql/mocks/mocks.go index c5133b3064..bb3021faa9 100644 --- a/sql/mocks/mocks.go +++ b/sql/mocks/mocks.go @@ -10,6 +10,7 @@ package mocks import ( + context "context" reflect "reflect" sql "github.com/spacemeshos/go-spacemesh/sql" @@ -77,3 +78,1207 @@ func (c *MockExecutorExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decod c.Call = c.Call.DoAndReturn(f) return c } + +// MockDatabase is a mock of Database interface. +type MockDatabase struct { + ctrl *gomock.Controller + recorder *MockDatabaseMockRecorder +} + +// MockDatabaseMockRecorder is the mock recorder for MockDatabase. +type MockDatabaseMockRecorder struct { + mock *MockDatabase +} + +// NewMockDatabase creates a new mock instance. +func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase { + mock := &MockDatabase{ctrl: ctrl} + mock.recorder = &MockDatabaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockDatabase) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockDatabaseMockRecorder) Close() *MockDatabaseCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDatabase)(nil).Close)) + return &MockDatabaseCloseCall{Call: call} +} + +// MockDatabaseCloseCall wrap *gomock.Call +type MockDatabaseCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseCloseCall) Return(arg0 error) *MockDatabaseCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseCloseCall) Do(f func() error) *MockDatabaseCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseCloseCall) DoAndReturn(f func() error) *MockDatabaseCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockDatabase) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockDatabaseMockRecorder) Exec(arg0, arg1, arg2 any) *MockDatabaseExecCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockDatabase)(nil).Exec), arg0, arg1, arg2) + return &MockDatabaseExecCall{Call: call} +} + +// MockDatabaseExecCall wrap *gomock.Call +type MockDatabaseExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseExecCall) Return(arg0 int, arg1 error) *MockDatabaseExecCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockDatabaseExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockDatabaseExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryCache mocks base method. +func (m *MockDatabase) QueryCache() sql.QueryCache { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryCache") + ret0, _ := ret[0].(sql.QueryCache) + return ret0 +} + +// QueryCache indicates an expected call of QueryCache. +func (mr *MockDatabaseMockRecorder) QueryCache() *MockDatabaseQueryCacheCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCache", reflect.TypeOf((*MockDatabase)(nil).QueryCache)) + return &MockDatabaseQueryCacheCall{Call: call} +} + +// MockDatabaseQueryCacheCall wrap *gomock.Call +type MockDatabaseQueryCacheCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseQueryCacheCall) Return(arg0 sql.QueryCache) *MockDatabaseQueryCacheCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseQueryCacheCall) Do(f func() sql.QueryCache) *MockDatabaseQueryCacheCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseQueryCacheCall) DoAndReturn(f func() sql.QueryCache) *MockDatabaseQueryCacheCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryCount mocks base method. +func (m *MockDatabase) QueryCount() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryCount") + ret0, _ := ret[0].(int) + return ret0 +} + +// QueryCount indicates an expected call of QueryCount. +func (mr *MockDatabaseMockRecorder) QueryCount() *MockDatabaseQueryCountCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCount", reflect.TypeOf((*MockDatabase)(nil).QueryCount)) + return &MockDatabaseQueryCountCall{Call: call} +} + +// MockDatabaseQueryCountCall wrap *gomock.Call +type MockDatabaseQueryCountCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseQueryCountCall) Return(arg0 int) *MockDatabaseQueryCountCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseQueryCountCall) Do(f func() int) *MockDatabaseQueryCountCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseQueryCountCall) DoAndReturn(f func() int) *MockDatabaseQueryCountCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Tx mocks base method. +func (m *MockDatabase) Tx(ctx context.Context) (sql.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Tx", ctx) + ret0, _ := ret[0].(sql.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Tx indicates an expected call of Tx. +func (mr *MockDatabaseMockRecorder) Tx(ctx any) *MockDatabaseTxCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tx", reflect.TypeOf((*MockDatabase)(nil).Tx), ctx) + return &MockDatabaseTxCall{Call: call} +} + +// MockDatabaseTxCall wrap *gomock.Call +type MockDatabaseTxCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseTxCall) Return(arg0 sql.Transaction, arg1 error) *MockDatabaseTxCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseTxCall) Do(f func(context.Context) (sql.Transaction, error)) *MockDatabaseTxCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseTxCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockDatabaseTxCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// TxImmediate mocks base method. +func (m *MockDatabase) TxImmediate(ctx context.Context) (sql.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TxImmediate", ctx) + ret0, _ := ret[0].(sql.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TxImmediate indicates an expected call of TxImmediate. +func (mr *MockDatabaseMockRecorder) TxImmediate(ctx any) *MockDatabaseTxImmediateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxImmediate", reflect.TypeOf((*MockDatabase)(nil).TxImmediate), ctx) + return &MockDatabaseTxImmediateCall{Call: call} +} + +// MockDatabaseTxImmediateCall wrap *gomock.Call +type MockDatabaseTxImmediateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseTxImmediateCall) Return(arg0 sql.Transaction, arg1 error) *MockDatabaseTxImmediateCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseTxImmediateCall) Do(f func(context.Context) (sql.Transaction, error)) *MockDatabaseTxImmediateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockDatabaseTxImmediateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// WithTx mocks base method. +func (m *MockDatabase) WithTx(ctx context.Context, exec func(sql.Transaction) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithTx", ctx, exec) + ret0, _ := ret[0].(error) + return ret0 +} + +// WithTx indicates an expected call of WithTx. +func (mr *MockDatabaseMockRecorder) WithTx(ctx, exec any) *MockDatabaseWithTxCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTx", reflect.TypeOf((*MockDatabase)(nil).WithTx), ctx, exec) + return &MockDatabaseWithTxCall{Call: call} +} + +// MockDatabaseWithTxCall wrap *gomock.Call +type MockDatabaseWithTxCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseWithTxCall) Return(arg0 error) *MockDatabaseWithTxCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseWithTxCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockDatabaseWithTxCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseWithTxCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockDatabaseWithTxCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// WithTxImmediate mocks base method. +func (m *MockDatabase) WithTxImmediate(ctx context.Context, exec func(sql.Transaction) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithTxImmediate", ctx, exec) + ret0, _ := ret[0].(error) + return ret0 +} + +// WithTxImmediate indicates an expected call of WithTxImmediate. +func (mr *MockDatabaseMockRecorder) WithTxImmediate(ctx, exec any) *MockDatabaseWithTxImmediateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTxImmediate", reflect.TypeOf((*MockDatabase)(nil).WithTxImmediate), ctx, exec) + return &MockDatabaseWithTxImmediateCall{Call: call} +} + +// MockDatabaseWithTxImmediateCall wrap *gomock.Call +type MockDatabaseWithTxImmediateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseWithTxImmediateCall) Return(arg0 error) *MockDatabaseWithTxImmediateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseWithTxImmediateCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockDatabaseWithTxImmediateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseWithTxImmediateCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockDatabaseWithTxImmediateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockTransaction is a mock of Transaction interface. +type MockTransaction struct { + ctrl *gomock.Controller + recorder *MockTransactionMockRecorder +} + +// MockTransactionMockRecorder is the mock recorder for MockTransaction. +type MockTransactionMockRecorder struct { + mock *MockTransaction +} + +// NewMockTransaction creates a new mock instance. +func NewMockTransaction(ctrl *gomock.Controller) *MockTransaction { + mock := &MockTransaction{ctrl: ctrl} + mock.recorder = &MockTransactionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder { + return m.recorder +} + +// Commit mocks base method. +func (m *MockTransaction) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTransactionMockRecorder) Commit() *MockTransactionCommitCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit)) + return &MockTransactionCommitCall{Call: call} +} + +// MockTransactionCommitCall wrap *gomock.Call +type MockTransactionCommitCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionCommitCall) Return(arg0 error) *MockTransactionCommitCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionCommitCall) Do(f func() error) *MockTransactionCommitCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionCommitCall) DoAndReturn(f func() error) *MockTransactionCommitCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockTransaction) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTransactionMockRecorder) Exec(arg0, arg1, arg2 any) *MockTransactionExecCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), arg0, arg1, arg2) + return &MockTransactionExecCall{Call: call} +} + +// MockTransactionExecCall wrap *gomock.Call +type MockTransactionExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionExecCall) Return(arg0 int, arg1 error) *MockTransactionExecCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockTransactionExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockTransactionExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Release mocks base method. +func (m *MockTransaction) Release() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Release") + ret0, _ := ret[0].(error) + return ret0 +} + +// Release indicates an expected call of Release. +func (mr *MockTransactionMockRecorder) Release() *MockTransactionReleaseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockTransaction)(nil).Release)) + return &MockTransactionReleaseCall{Call: call} +} + +// MockTransactionReleaseCall wrap *gomock.Call +type MockTransactionReleaseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionReleaseCall) Return(arg0 error) *MockTransactionReleaseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionReleaseCall) Do(f func() error) *MockTransactionReleaseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionReleaseCall) DoAndReturn(f func() error) *MockTransactionReleaseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockStateDatabase is a mock of StateDatabase interface. +type MockStateDatabase struct { + ctrl *gomock.Controller + recorder *MockStateDatabaseMockRecorder +} + +// MockStateDatabaseMockRecorder is the mock recorder for MockStateDatabase. +type MockStateDatabaseMockRecorder struct { + mock *MockStateDatabase +} + +// NewMockStateDatabase creates a new mock instance. +func NewMockStateDatabase(ctrl *gomock.Controller) *MockStateDatabase { + mock := &MockStateDatabase{ctrl: ctrl} + mock.recorder = &MockStateDatabaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStateDatabase) EXPECT() *MockStateDatabaseMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockStateDatabase) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockStateDatabaseMockRecorder) Close() *MockStateDatabaseCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStateDatabase)(nil).Close)) + return &MockStateDatabaseCloseCall{Call: call} +} + +// MockStateDatabaseCloseCall wrap *gomock.Call +type MockStateDatabaseCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseCloseCall) Return(arg0 error) *MockStateDatabaseCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseCloseCall) Do(f func() error) *MockStateDatabaseCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseCloseCall) DoAndReturn(f func() error) *MockStateDatabaseCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockStateDatabase) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockStateDatabaseMockRecorder) Exec(arg0, arg1, arg2 any) *MockStateDatabaseExecCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStateDatabase)(nil).Exec), arg0, arg1, arg2) + return &MockStateDatabaseExecCall{Call: call} +} + +// MockStateDatabaseExecCall wrap *gomock.Call +type MockStateDatabaseExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseExecCall) Return(arg0 int, arg1 error) *MockStateDatabaseExecCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockStateDatabaseExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockStateDatabaseExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// IsStateDatabase mocks base method. +func (m *MockStateDatabase) IsStateDatabase() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsStateDatabase") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsStateDatabase indicates an expected call of IsStateDatabase. +func (mr *MockStateDatabaseMockRecorder) IsStateDatabase() *MockStateDatabaseIsStateDatabaseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsStateDatabase", reflect.TypeOf((*MockStateDatabase)(nil).IsStateDatabase)) + return &MockStateDatabaseIsStateDatabaseCall{Call: call} +} + +// MockStateDatabaseIsStateDatabaseCall wrap *gomock.Call +type MockStateDatabaseIsStateDatabaseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseIsStateDatabaseCall) Return(arg0 bool) *MockStateDatabaseIsStateDatabaseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseIsStateDatabaseCall) Do(f func() bool) *MockStateDatabaseIsStateDatabaseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseIsStateDatabaseCall) DoAndReturn(f func() bool) *MockStateDatabaseIsStateDatabaseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryCache mocks base method. +func (m *MockStateDatabase) QueryCache() sql.QueryCache { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryCache") + ret0, _ := ret[0].(sql.QueryCache) + return ret0 +} + +// QueryCache indicates an expected call of QueryCache. +func (mr *MockStateDatabaseMockRecorder) QueryCache() *MockStateDatabaseQueryCacheCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCache", reflect.TypeOf((*MockStateDatabase)(nil).QueryCache)) + return &MockStateDatabaseQueryCacheCall{Call: call} +} + +// MockStateDatabaseQueryCacheCall wrap *gomock.Call +type MockStateDatabaseQueryCacheCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseQueryCacheCall) Return(arg0 sql.QueryCache) *MockStateDatabaseQueryCacheCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseQueryCacheCall) Do(f func() sql.QueryCache) *MockStateDatabaseQueryCacheCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseQueryCacheCall) DoAndReturn(f func() sql.QueryCache) *MockStateDatabaseQueryCacheCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryCount mocks base method. +func (m *MockStateDatabase) QueryCount() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryCount") + ret0, _ := ret[0].(int) + return ret0 +} + +// QueryCount indicates an expected call of QueryCount. +func (mr *MockStateDatabaseMockRecorder) QueryCount() *MockStateDatabaseQueryCountCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCount", reflect.TypeOf((*MockStateDatabase)(nil).QueryCount)) + return &MockStateDatabaseQueryCountCall{Call: call} +} + +// MockStateDatabaseQueryCountCall wrap *gomock.Call +type MockStateDatabaseQueryCountCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseQueryCountCall) Return(arg0 int) *MockStateDatabaseQueryCountCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseQueryCountCall) Do(f func() int) *MockStateDatabaseQueryCountCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseQueryCountCall) DoAndReturn(f func() int) *MockStateDatabaseQueryCountCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Tx mocks base method. +func (m *MockStateDatabase) Tx(ctx context.Context) (sql.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Tx", ctx) + ret0, _ := ret[0].(sql.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Tx indicates an expected call of Tx. +func (mr *MockStateDatabaseMockRecorder) Tx(ctx any) *MockStateDatabaseTxCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tx", reflect.TypeOf((*MockStateDatabase)(nil).Tx), ctx) + return &MockStateDatabaseTxCall{Call: call} +} + +// MockStateDatabaseTxCall wrap *gomock.Call +type MockStateDatabaseTxCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseTxCall) Return(arg0 sql.Transaction, arg1 error) *MockStateDatabaseTxCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseTxCall) Do(f func(context.Context) (sql.Transaction, error)) *MockStateDatabaseTxCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseTxCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockStateDatabaseTxCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// TxImmediate mocks base method. +func (m *MockStateDatabase) TxImmediate(ctx context.Context) (sql.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TxImmediate", ctx) + ret0, _ := ret[0].(sql.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TxImmediate indicates an expected call of TxImmediate. +func (mr *MockStateDatabaseMockRecorder) TxImmediate(ctx any) *MockStateDatabaseTxImmediateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxImmediate", reflect.TypeOf((*MockStateDatabase)(nil).TxImmediate), ctx) + return &MockStateDatabaseTxImmediateCall{Call: call} +} + +// MockStateDatabaseTxImmediateCall wrap *gomock.Call +type MockStateDatabaseTxImmediateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseTxImmediateCall) Return(arg0 sql.Transaction, arg1 error) *MockStateDatabaseTxImmediateCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseTxImmediateCall) Do(f func(context.Context) (sql.Transaction, error)) *MockStateDatabaseTxImmediateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockStateDatabaseTxImmediateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// WithTx mocks base method. +func (m *MockStateDatabase) WithTx(ctx context.Context, exec func(sql.Transaction) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithTx", ctx, exec) + ret0, _ := ret[0].(error) + return ret0 +} + +// WithTx indicates an expected call of WithTx. +func (mr *MockStateDatabaseMockRecorder) WithTx(ctx, exec any) *MockStateDatabaseWithTxCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTx", reflect.TypeOf((*MockStateDatabase)(nil).WithTx), ctx, exec) + return &MockStateDatabaseWithTxCall{Call: call} +} + +// MockStateDatabaseWithTxCall wrap *gomock.Call +type MockStateDatabaseWithTxCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseWithTxCall) Return(arg0 error) *MockStateDatabaseWithTxCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseWithTxCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockStateDatabaseWithTxCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseWithTxCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockStateDatabaseWithTxCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// WithTxImmediate mocks base method. +func (m *MockStateDatabase) WithTxImmediate(ctx context.Context, exec func(sql.Transaction) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithTxImmediate", ctx, exec) + ret0, _ := ret[0].(error) + return ret0 +} + +// WithTxImmediate indicates an expected call of WithTxImmediate. +func (mr *MockStateDatabaseMockRecorder) WithTxImmediate(ctx, exec any) *MockStateDatabaseWithTxImmediateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTxImmediate", reflect.TypeOf((*MockStateDatabase)(nil).WithTxImmediate), ctx, exec) + return &MockStateDatabaseWithTxImmediateCall{Call: call} +} + +// MockStateDatabaseWithTxImmediateCall wrap *gomock.Call +type MockStateDatabaseWithTxImmediateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseWithTxImmediateCall) Return(arg0 error) *MockStateDatabaseWithTxImmediateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseWithTxImmediateCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockStateDatabaseWithTxImmediateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseWithTxImmediateCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockStateDatabaseWithTxImmediateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockLocalDatabase is a mock of LocalDatabase interface. +type MockLocalDatabase struct { + ctrl *gomock.Controller + recorder *MockLocalDatabaseMockRecorder +} + +// MockLocalDatabaseMockRecorder is the mock recorder for MockLocalDatabase. +type MockLocalDatabaseMockRecorder struct { + mock *MockLocalDatabase +} + +// NewMockLocalDatabase creates a new mock instance. +func NewMockLocalDatabase(ctrl *gomock.Controller) *MockLocalDatabase { + mock := &MockLocalDatabase{ctrl: ctrl} + mock.recorder = &MockLocalDatabaseMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLocalDatabase) EXPECT() *MockLocalDatabaseMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockLocalDatabase) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockLocalDatabaseMockRecorder) Close() *MockLocalDatabaseCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockLocalDatabase)(nil).Close)) + return &MockLocalDatabaseCloseCall{Call: call} +} + +// MockLocalDatabaseCloseCall wrap *gomock.Call +type MockLocalDatabaseCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseCloseCall) Return(arg0 error) *MockLocalDatabaseCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseCloseCall) Do(f func() error) *MockLocalDatabaseCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseCloseCall) DoAndReturn(f func() error) *MockLocalDatabaseCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockLocalDatabase) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockLocalDatabaseMockRecorder) Exec(arg0, arg1, arg2 any) *MockLocalDatabaseExecCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockLocalDatabase)(nil).Exec), arg0, arg1, arg2) + return &MockLocalDatabaseExecCall{Call: call} +} + +// MockLocalDatabaseExecCall wrap *gomock.Call +type MockLocalDatabaseExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseExecCall) Return(arg0 int, arg1 error) *MockLocalDatabaseExecCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockLocalDatabaseExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockLocalDatabaseExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// IsLocalDatabase mocks base method. +func (m *MockLocalDatabase) IsLocalDatabase() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsLocalDatabase") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsLocalDatabase indicates an expected call of IsLocalDatabase. +func (mr *MockLocalDatabaseMockRecorder) IsLocalDatabase() *MockLocalDatabaseIsLocalDatabaseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsLocalDatabase", reflect.TypeOf((*MockLocalDatabase)(nil).IsLocalDatabase)) + return &MockLocalDatabaseIsLocalDatabaseCall{Call: call} +} + +// MockLocalDatabaseIsLocalDatabaseCall wrap *gomock.Call +type MockLocalDatabaseIsLocalDatabaseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseIsLocalDatabaseCall) Return(arg0 bool) *MockLocalDatabaseIsLocalDatabaseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseIsLocalDatabaseCall) Do(f func() bool) *MockLocalDatabaseIsLocalDatabaseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseIsLocalDatabaseCall) DoAndReturn(f func() bool) *MockLocalDatabaseIsLocalDatabaseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryCache mocks base method. +func (m *MockLocalDatabase) QueryCache() sql.QueryCache { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryCache") + ret0, _ := ret[0].(sql.QueryCache) + return ret0 +} + +// QueryCache indicates an expected call of QueryCache. +func (mr *MockLocalDatabaseMockRecorder) QueryCache() *MockLocalDatabaseQueryCacheCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCache", reflect.TypeOf((*MockLocalDatabase)(nil).QueryCache)) + return &MockLocalDatabaseQueryCacheCall{Call: call} +} + +// MockLocalDatabaseQueryCacheCall wrap *gomock.Call +type MockLocalDatabaseQueryCacheCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseQueryCacheCall) Return(arg0 sql.QueryCache) *MockLocalDatabaseQueryCacheCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseQueryCacheCall) Do(f func() sql.QueryCache) *MockLocalDatabaseQueryCacheCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseQueryCacheCall) DoAndReturn(f func() sql.QueryCache) *MockLocalDatabaseQueryCacheCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryCount mocks base method. +func (m *MockLocalDatabase) QueryCount() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueryCount") + ret0, _ := ret[0].(int) + return ret0 +} + +// QueryCount indicates an expected call of QueryCount. +func (mr *MockLocalDatabaseMockRecorder) QueryCount() *MockLocalDatabaseQueryCountCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCount", reflect.TypeOf((*MockLocalDatabase)(nil).QueryCount)) + return &MockLocalDatabaseQueryCountCall{Call: call} +} + +// MockLocalDatabaseQueryCountCall wrap *gomock.Call +type MockLocalDatabaseQueryCountCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseQueryCountCall) Return(arg0 int) *MockLocalDatabaseQueryCountCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseQueryCountCall) Do(f func() int) *MockLocalDatabaseQueryCountCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseQueryCountCall) DoAndReturn(f func() int) *MockLocalDatabaseQueryCountCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Tx mocks base method. +func (m *MockLocalDatabase) Tx(ctx context.Context) (sql.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Tx", ctx) + ret0, _ := ret[0].(sql.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Tx indicates an expected call of Tx. +func (mr *MockLocalDatabaseMockRecorder) Tx(ctx any) *MockLocalDatabaseTxCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tx", reflect.TypeOf((*MockLocalDatabase)(nil).Tx), ctx) + return &MockLocalDatabaseTxCall{Call: call} +} + +// MockLocalDatabaseTxCall wrap *gomock.Call +type MockLocalDatabaseTxCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseTxCall) Return(arg0 sql.Transaction, arg1 error) *MockLocalDatabaseTxCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseTxCall) Do(f func(context.Context) (sql.Transaction, error)) *MockLocalDatabaseTxCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseTxCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockLocalDatabaseTxCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// TxImmediate mocks base method. +func (m *MockLocalDatabase) TxImmediate(ctx context.Context) (sql.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TxImmediate", ctx) + ret0, _ := ret[0].(sql.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TxImmediate indicates an expected call of TxImmediate. +func (mr *MockLocalDatabaseMockRecorder) TxImmediate(ctx any) *MockLocalDatabaseTxImmediateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxImmediate", reflect.TypeOf((*MockLocalDatabase)(nil).TxImmediate), ctx) + return &MockLocalDatabaseTxImmediateCall{Call: call} +} + +// MockLocalDatabaseTxImmediateCall wrap *gomock.Call +type MockLocalDatabaseTxImmediateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseTxImmediateCall) Return(arg0 sql.Transaction, arg1 error) *MockLocalDatabaseTxImmediateCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseTxImmediateCall) Do(f func(context.Context) (sql.Transaction, error)) *MockLocalDatabaseTxImmediateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockLocalDatabaseTxImmediateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// WithTx mocks base method. +func (m *MockLocalDatabase) WithTx(ctx context.Context, exec func(sql.Transaction) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithTx", ctx, exec) + ret0, _ := ret[0].(error) + return ret0 +} + +// WithTx indicates an expected call of WithTx. +func (mr *MockLocalDatabaseMockRecorder) WithTx(ctx, exec any) *MockLocalDatabaseWithTxCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTx", reflect.TypeOf((*MockLocalDatabase)(nil).WithTx), ctx, exec) + return &MockLocalDatabaseWithTxCall{Call: call} +} + +// MockLocalDatabaseWithTxCall wrap *gomock.Call +type MockLocalDatabaseWithTxCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseWithTxCall) Return(arg0 error) *MockLocalDatabaseWithTxCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseWithTxCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockLocalDatabaseWithTxCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseWithTxCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockLocalDatabaseWithTxCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// WithTxImmediate mocks base method. +func (m *MockLocalDatabase) WithTxImmediate(ctx context.Context, exec func(sql.Transaction) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithTxImmediate", ctx, exec) + ret0, _ := ret[0].(error) + return ret0 +} + +// WithTxImmediate indicates an expected call of WithTxImmediate. +func (mr *MockLocalDatabaseMockRecorder) WithTxImmediate(ctx, exec any) *MockLocalDatabaseWithTxImmediateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTxImmediate", reflect.TypeOf((*MockLocalDatabase)(nil).WithTxImmediate), ctx, exec) + return &MockLocalDatabaseWithTxImmediateCall{Call: call} +} + +// MockLocalDatabaseWithTxImmediateCall wrap *gomock.Call +type MockLocalDatabaseWithTxImmediateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseWithTxImmediateCall) Return(arg0 error) *MockLocalDatabaseWithTxImmediateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseWithTxImmediateCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockLocalDatabaseWithTxImmediateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseWithTxImmediateCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockLocalDatabaseWithTxImmediateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/sql/schema.go b/sql/schema.go index 5614185567..6feec5d031 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -78,8 +78,8 @@ func (s *Schema) SkipMigrations(i ...int) { } // Apply applies the schema to the database. -func (s *Schema) Apply(db *Database) error { - return db.WithTx(context.Background(), func(tx *Tx) error { +func (s *Schema) Apply(db Database) error { + return db.WithTx(context.Background(), func(tx Transaction) error { scanner := bufio.NewScanner(strings.NewReader(s.Script)) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if i := bytes.Index(data, []byte(";")); i >= 0 { @@ -99,7 +99,7 @@ func (s *Schema) Apply(db *Database) error { // Migrate performs database migration. In case if migrations are disabled, the database // version is checked but no migrations are run, and if the database is too old and // migrations are disabled, an error is returned. -func (s *Schema) Migrate(logger *zap.Logger, db *Database, vacuumState int, enable bool) error { +func (s *Schema) Migrate(logger *zap.Logger, db Database, vacuumState int, enable bool) error { if len(s.Migrations) == 0 { return nil } @@ -139,7 +139,7 @@ func (s *Schema) Migrate(logger *zap.Logger, db *Database, vacuumState int, enab if m.Order() <= before { continue } - if err := db.WithTx(context.Background(), func(tx *Tx) error { + if err := db.WithTx(context.Background(), func(tx Transaction) error { if _, ok := s.skipMigration[m.Order()]; !ok { if err := m.Apply(tx); err != nil { for j := i; j >= 0 && s.Migrations[j].Order() > before; j-- { diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 76139fac5f..85de4f20bc 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -14,11 +14,14 @@ var schemaScript string //go:embed schema/migrations/*.sql var migrations embed.FS -// Database represents a state database. -type Database struct { - *sql.Database +type database struct { + sql.Database } +var _ sql.StateDatabase = &database{} + +func (db *database) IsStateDatabase() bool { return true } + // Schema returns the schema for the state database. func Schema() (*sql.Schema, error) { sqlMigrations, err := sql.LoadSQLMigrations(migrations) @@ -31,7 +34,7 @@ func Schema() (*sql.Schema, error) { } // Open opens a state database. -func Open(uri string, opts ...sql.Opt) (*Database, error) { +func Open(uri string, opts ...sql.Opt) (sql.StateDatabase, error) { schema, err := Schema() if err != nil { return nil, err @@ -41,11 +44,11 @@ func Open(uri string, opts ...sql.Opt) (*Database, error) { if err != nil { return nil, err } - return &Database{Database: db}, nil + return &database{Database: db}, nil } // Open opens an in-memory state database. -func InMemory(opts ...sql.Opt) *Database { +func InMemory(opts ...sql.Opt) sql.StateDatabase { schema, err := Schema() if err != nil { panic(err) @@ -55,5 +58,5 @@ func InMemory(opts ...sql.Opt) *Database { } opts = append(defaultOpts, opts...) db := sql.InMemory(opts...) - return &Database{Database: db} + return &database{Database: db} } diff --git a/sql/transactions/iterator_test.go b/sql/transactions/iterator_test.go index 4b9bdb0dcb..334cef2a97 100644 --- a/sql/transactions/iterator_test.go +++ b/sql/transactions/iterator_test.go @@ -64,7 +64,7 @@ func TestIterateResults(t *testing.T) { gen := fixture.NewTransactionResultGenerator() txs := make([]types.TransactionWithResult, 100) - require.NoError(t, db.WithTx(context.TODO(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.TODO(), func(dtx sql.Transaction) error { for i := range txs { tx := gen.Next() @@ -148,7 +148,7 @@ func TestIterateSnapshot(t *testing.T) { require.NoError(t, err) gen := fixture.NewTransactionResultGenerator() expect := 10 - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { for i := 0; i < expect; i++ { tx := gen.Next() @@ -176,7 +176,7 @@ func TestIterateSnapshot(t *testing.T) { }() <-initialized - require.NoError(t, db.WithTx(context.TODO(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.TODO(), func(dtx sql.Transaction) error { for i := 0; i < 10; i++ { tx := gen.Next() diff --git a/sql/transactions/transactions.go b/sql/transactions/transactions.go index f3757a8891..f2ab21847a 100644 --- a/sql/transactions/transactions.go +++ b/sql/transactions/transactions.go @@ -28,8 +28,8 @@ func Add(db sql.Executor, tx *types.Transaction, received time.Time) error { if _, err = db.Exec(` insert into transactions (id, tx, header, principal, nonce, timestamp) values (?1, ?2, ?3, ?4, ?5, ?6) - on conflict(id) do update set - header=?3, principal=?4, nonce=?5 + on conflict(id) do update set + header=?3, principal=?4, nonce=?5 where header is null;`, func(stmt *sql.Statement) { stmt.BindBytes(1, tx.ID.Bytes()) @@ -49,7 +49,7 @@ func Add(db sql.Executor, tx *types.Transaction, received time.Time) error { // AddToProposal associates a transaction with a proposal. func AddToProposal(db sql.Executor, tid types.TransactionID, lid types.LayerID, pid types.ProposalID) error { if _, err := db.Exec(` - insert into proposal_transactions (pid, tid, layer) values (?1, ?2, ?3) + insert into proposal_transactions (pid, tid, layer) values (?1, ?2, ?3) on conflict(tid, pid) do nothing;`, func(stmt *sql.Statement) { stmt.BindBytes(1, pid.Bytes()) @@ -136,8 +136,8 @@ func GetAppliedLayer(db sql.Executor, tid types.TransactionID) (types.LayerID, e } // UndoLayers unset all transactions to `statePending` from `from` layer to the max layer with applied transactions. -func UndoLayers(db *sql.Tx, from types.LayerID) error { - _, err := db.Exec(`delete from transactions_results_addresses +func UndoLayers(tx sql.Transaction, from types.LayerID) error { + _, err := tx.Exec(`delete from transactions_results_addresses where tid in (select id from transactions where layer >= ?1);`, func(stmt *sql.Statement) { stmt.BindInt64(1, int64(from)) @@ -145,8 +145,8 @@ func UndoLayers(db *sql.Tx, from types.LayerID) error { if err != nil { return fmt.Errorf("delete addresses mapping %w", err) } - _, err = db.Exec(`update transactions - set layer = null, block = null, result = null + _, err = tx.Exec(`update transactions + set layer = null, block = null, result = null where layer >= ?1`, func(stmt *sql.Statement) { stmt.BindInt64(1, int64(from)) @@ -287,7 +287,7 @@ func AddressesWithPendingTransactions(db sql.Executor) ([]types.AddressNonce, er // GetAcctPendingFromNonce get all pending transactions with nonce after `from` for the given address. func GetAcctPendingFromNonce(db sql.Executor, address types.Address, from uint64) ([]*types.MeshTransaction, error) { return queryPending(db, `select tx, header, layer, block, timestamp, id from transactions - where principal = ?1 and nonce >= ?2 and result is null + where principal = ?1 and nonce >= ?2 and result is null order by nonce asc, timestamp asc`, func(stmt *sql.Statement) { stmt.BindBytes(1, address.Bytes()) @@ -321,14 +321,14 @@ func queryPending( } // AddResult adds result for the transaction. -func AddResult(db *sql.Tx, id types.TransactionID, rst *types.TransactionResult) error { +func AddResult(tx sql.Transaction, id types.TransactionID, rst *types.TransactionResult) error { buf, err := codec.Encode(rst) if err != nil { return fmt.Errorf("encode %w", err) } - if rows, err := db.Exec(`update transactions - set result = ?2, layer = ?3, block = ?4 + if rows, err := tx.Exec(`update transactions + set result = ?2, layer = ?3, block = ?4 where id = ?1 and result is null returning id;`, func(stmt *sql.Statement) { stmt.BindBytes(1, id[:]) @@ -345,7 +345,7 @@ func AddResult(db *sql.Tx, id types.TransactionID, rst *types.TransactionResult) return fmt.Errorf("invalid state for %s", id) } for i := range rst.Addresses { - if _, err := db.Exec(`insert into transactions_results_addresses + if _, err := tx.Exec(`insert into transactions_results_addresses (address, tid) values (?1, ?2);`, func(stmt *sql.Statement) { stmt.BindBytes(1, rst.Addresses[i][:]) @@ -418,7 +418,7 @@ func IterateTransactionsOps( fn func(tx *types.MeshTransaction, result *types.TransactionResult) bool, ) error { var derr error - _, err := db.Exec(`select distinct tx, header, layer, block, timestamp, id, result + _, err := db.Exec(`select distinct tx, header, layer, block, timestamp, id, result from transactions left join transactions_results_addresses on id=tid where result is not null`+builder.FilterFrom(operations), diff --git a/sql/transactions/transactions_test.go b/sql/transactions/transactions_test.go index dc5d4714f8..4711e9fba2 100644 --- a/sql/transactions/transactions_test.go +++ b/sql/transactions/transactions_test.go @@ -232,17 +232,17 @@ func TestApply_AlreadyApplied(t *testing.T) { require.NoError(t, transactions.Add(db, tx, time.Now())) bid := types.RandomBlockID() - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) // same block applied again - require.Error(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.Error(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) // different block applied again - require.Error(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.Error(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult( dtx, tx.ID, @@ -254,7 +254,7 @@ func TestApply_AlreadyApplied(t *testing.T) { func TestUndoLayers_Empty(t *testing.T) { db := statesql.InMemory() - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, types.LayerID(199)) })) } @@ -273,7 +273,7 @@ func TestApplyAndUndoLayers(t *testing.T) { require.NoError(t, transactions.Add(db, tx, time.Now())) bid := types.RandomBlockID() - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, tx.ID, &types.TransactionResult{Layer: lid, Block: bid}) })) applied = append(applied, tx.ID) @@ -285,7 +285,7 @@ func TestApplyAndUndoLayers(t *testing.T) { require.Equal(t, types.APPLIED, mtx.State) } // revert to firstLayer - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, firstLayer.Add(1)) })) @@ -349,7 +349,7 @@ func TestGetByAddress(t *testing.T) { createTX(t, signer1, signer2Address, 1, 191, 1), } received := time.Now() - require.NoError(t, db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, tx := range txs { require.NoError(t, transactions.Add(dbtx, tx, received)) require.NoError(t, transactions.AddResult(dbtx, tx.ID, &types.TransactionResult{Layer: lid})) @@ -418,7 +418,7 @@ func TestAppliedLayer(t *testing.T) { for _, tx := range txs { require.NoError(t, transactions.Add(db, tx, time.Now())) } - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.AddResult(dtx, txs[0].ID, &types.TransactionResult{Layer: lid, Block: types.BlockID{1, 1}}) })) @@ -429,7 +429,7 @@ func TestAppliedLayer(t *testing.T) { _, err = transactions.GetAppliedLayer(db, txs[1].ID) require.ErrorIs(t, err, sql.ErrNotFound) - require.NoError(t, db.WithTx(context.Background(), func(dtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dtx sql.Transaction) error { return transactions.UndoLayers(dtx, lid) })) _, err = transactions.GetAppliedLayer(db, txs[0].ID) @@ -466,7 +466,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { {Address: principals[0], Nonce: txs[0].Nonce}, {Address: principals[1], Nonce: txs[2].Nonce}, }, rst) - require.NoError(t, db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { return transactions.AddResult(dbtx, txs[0].ID, &types.TransactionResult{Message: "hey"}) })) rst, err = transactions.AddressesWithPendingTransactions(db) @@ -475,7 +475,7 @@ func TestAddressesWithPendingTransactions(t *testing.T) { {Address: principals[0], Nonce: txs[1].Nonce}, {Address: principals[1], Nonce: txs[2].Nonce}, }, rst) - require.NoError(t, db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + require.NoError(t, db.WithTx(context.Background(), func(dbtx sql.Transaction) error { return transactions.AddResult(dbtx, txs[2].ID, &types.TransactionResult{Message: "hey"}) })) rst, err = transactions.AddressesWithPendingTransactions(db) diff --git a/syncer/atxsync/atxsync.go b/syncer/atxsync/atxsync.go index ac5309dcb4..ab93cb62bf 100644 --- a/syncer/atxsync/atxsync.go +++ b/syncer/atxsync/atxsync.go @@ -9,12 +9,12 @@ import ( "go.uber.org/zap/zapcore" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/system" ) -func getMissing(db *statesql.Database, set []types.ATXID) ([]types.ATXID, error) { +func getMissing(db sql.StateDatabase, set []types.ATXID) ([]types.ATXID, error) { missing := []types.ATXID{} for _, atx := range set { exist, err := atxs.Has(db, atx) @@ -35,7 +35,7 @@ func Download( ctx context.Context, retryInterval time.Duration, logger *zap.Logger, - db *statesql.Database, + db sql.StateDatabase, fetcher system.AtxFetcher, set []types.ATXID, ) error { diff --git a/syncer/atxsync/syncer.go b/syncer/atxsync/syncer.go index cea9d825f7..4d14043a4e 100644 --- a/syncer/atxsync/syncer.go +++ b/syncer/atxsync/syncer.go @@ -17,7 +17,6 @@ import ( "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/atxsync" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/system" ) @@ -76,7 +75,7 @@ func WithConfig(cfg Config) Opt { } } -func New(fetcher fetcher, db sql.Executor, localdb *localsql.Database, opts ...Opt) *Syncer { +func New(fetcher fetcher, db sql.Executor, localdb sql.LocalDatabase, opts ...Opt) *Syncer { s := &Syncer{ logger: zap.NewNop(), cfg: DefaultConfig(), @@ -95,7 +94,7 @@ type Syncer struct { cfg Config fetcher fetcher db sql.Executor - localdb *localsql.Database + localdb sql.LocalDatabase } func (s *Syncer) Download(parent context.Context, publish types.EpochID, downloadUntil time.Time) error { @@ -324,7 +323,7 @@ func (s *Syncer) downloadAtxs( } } - if err := s.localdb.WithTx(context.Background(), func(tx *sql.Tx) error { + if err := s.localdb.WithTx(context.Background(), func(tx sql.Transaction) error { err := atxsync.SaveRequest(tx, publish, lastSuccess, int64(len(state)), int64(len(downloaded))) if err != nil { return fmt.Errorf("failed to save request time: %w", err) diff --git a/syncer/atxsync/syncer_test.go b/syncer/atxsync/syncer_test.go index bfa3f53410..71530d6c26 100644 --- a/syncer/atxsync/syncer_test.go +++ b/syncer/atxsync/syncer_test.go @@ -14,6 +14,7 @@ import ( "github.com/spacemeshos/go-spacemesh/fetch" "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/atxsync" "github.com/spacemeshos/go-spacemesh/sql/localsql" @@ -60,8 +61,8 @@ func newTester(tb testing.TB, cfg Config) *tester { type tester struct { tb testing.TB syncer *Syncer - localdb *localsql.Database - db *statesql.Database + localdb sql.LocalDatabase + db sql.StateDatabase cfg Config ctrl *gomock.Controller fetcher *mocks.Mockfetcher diff --git a/syncer/find_fork_test.go b/syncer/find_fork_test.go index a80c8541cc..b48214db1b 100644 --- a/syncer/find_fork_test.go +++ b/syncer/find_fork_test.go @@ -17,6 +17,7 @@ import ( "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/log/logtest" "github.com/spacemeshos/go-spacemesh/p2p" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/syncer" @@ -25,7 +26,7 @@ import ( type testForkFinder struct { *syncer.ForkFinder - db *statesql.Database + db sql.StateDatabase mFetcher *mocks.Mockfetcher } @@ -88,7 +89,7 @@ func layerHash(layer int, good bool) types.Hash32 { return h2 } -func storeNodeHashes(t *testing.T, db *statesql.Database, diverge, max int) { +func storeNodeHashes(t *testing.T, db sql.StateDatabase, diverge, max int) { for lid := 0; lid <= max; lid++ { if lid < diverge { require.NoError(t, layers.SetMeshHash(db, types.LayerID(uint32(lid)), layerHash(lid, true))) diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go index a685f82f4f..86b09e2002 100644 --- a/syncer/malsync/syncer.go +++ b/syncer/malsync/syncer.go @@ -18,7 +18,6 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" - "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/malsync" "github.com/spacemeshos/go-spacemesh/system" ) @@ -218,12 +217,12 @@ type Syncer struct { cfg Config fetcher fetcher db sql.Executor - localdb *localsql.Database + localdb sql.LocalDatabase clock clockwork.Clock peerErrMetric counter } -func New(fetcher fetcher, db sql.Executor, localdb *localsql.Database, opts ...Opt) *Syncer { +func New(fetcher fetcher, db sql.Executor, localdb sql.LocalDatabase, opts ...Opt) *Syncer { s := &Syncer{ logger: zap.NewNop(), cfg: DefaultConfig(), @@ -342,7 +341,7 @@ func (s *Syncer) downloadNodeIDs(ctx context.Context, initial bool, updates chan } func (s *Syncer) updateState(ctx context.Context) error { - if err := s.localdb.WithTx(ctx, func(tx *sql.Tx) error { + if err := s.localdb.WithTx(ctx, func(tx sql.Transaction) error { return malsync.UpdateSyncState(tx, s.clock.Now()) }); err != nil { return fmt.Errorf("error updating malsync state: %w", err) diff --git a/syncer/malsync/syncer_test.go b/syncer/malsync/syncer_test.go index 4a9ec6f300..3bca66cc17 100644 --- a/syncer/malsync/syncer_test.go +++ b/syncer/malsync/syncer_test.go @@ -20,6 +20,7 @@ import ( "github.com/spacemeshos/go-spacemesh/malfeasance/wire" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/p2p/pubsub" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" "github.com/spacemeshos/go-spacemesh/sql/statesql" @@ -137,8 +138,8 @@ func malData(ids ...string) []types.NodeID { type tester struct { tb testing.TB syncer *Syncer - localdb *localsql.Database - db *statesql.Database + localdb sql.LocalDatabase + db sql.StateDatabase cfg Config ctrl *gomock.Controller fetcher *mocks.Mockfetcher diff --git a/tortoise/model/core.go b/tortoise/model/core.go index d8c622f7df..02f8612932 100644 --- a/tortoise/model/core.go +++ b/tortoise/model/core.go @@ -135,7 +135,7 @@ func (c *core) OnMessage(m Messenger, event Message) { if ev.LayerID.After(types.GetEffectiveGenesis()) { tortoise.RecoverLayer(context.Background(), c.tortoise, - c.cdb.Executor, + c.cdb.Database, c.atxdata, ev.LayerID, c.tortoise.OnBallot, diff --git a/tortoise/recover_test.go b/tortoise/recover_test.go index ecc4c6b51b..be41f02dd6 100644 --- a/tortoise/recover_test.go +++ b/tortoise/recover_test.go @@ -58,7 +58,7 @@ func TestRecoverState(t *testing.T) { tortoise2, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, simState.Atxdata, last, WithLogger(lg), @@ -82,7 +82,7 @@ func TestRecoverEmpty(t *testing.T) { cfg.LayerSize = size tortoise, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), 100, WithLogger(zaptest.NewLogger(t)), @@ -108,18 +108,18 @@ func TestRecoverWithOpinion(t *testing.T) { var last result.Layer for _, rst := range trt.Updates() { if rst.Verified { - require.NoError(t, layers.SetMeshHash(s.GetState(0).DB.Executor, rst.Layer, rst.Opinion)) + require.NoError(t, layers.SetMeshHash(s.GetState(0).DB.Database, rst.Layer, rst.Opinion)) } for _, block := range rst.Blocks { if block.Valid { - require.NoError(t, blocks.SetValid(s.GetState(0).DB.Executor, block.Header.ID)) + require.NoError(t, blocks.SetValid(s.GetState(0).DB.Database, block.Header.ID)) } } last = rst } tortoise, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), last.Layer, WithLogger(lg), @@ -156,14 +156,14 @@ func TestResetPending(t *testing.T) { require.NoError(t, layers.SetMeshHash(s.GetState(0).DB, rst.Layer, rst.Opinion)) for _, block := range rst.Blocks { if block.Valid { - require.NoError(t, blocks.SetValid(s.GetState(0).DB.Executor, block.Header.ID)) + require.NoError(t, blocks.SetValid(s.GetState(0).DB.Database, block.Header.ID)) } } } recovered, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), last, WithLogger(lg), @@ -203,14 +203,14 @@ func TestWindowRecovery(t *testing.T) { require.NoError(t, layers.SetMeshHash(s.GetState(0).DB, rst.Layer, rst.Opinion)) for _, block := range rst.Blocks { if block.Valid { - require.NoError(t, blocks.SetValid(s.GetState(0).DB.Executor, block.Header.ID)) + require.NoError(t, blocks.SetValid(s.GetState(0).DB.Database, block.Header.ID)) } } } recovered, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, atxsdata.New(), last, WithLogger(lg), @@ -239,7 +239,7 @@ func TestRecoverOnlyAtxs(t *testing.T) { trt.TallyVotes(context.Background(), lid) } future := last + 1000 - recovered, err := Recover(context.Background(), s.GetState(0).DB.Executor, s.GetState(0).Atxdata, future, + recovered, err := Recover(context.Background(), s.GetState(0).DB.Database, s.GetState(0).Atxdata, future, WithLogger(zaptest.NewLogger(t)), WithConfig(cfg), ) diff --git a/tortoise/sim/utils.go b/tortoise/sim/utils.go index c2385e4310..86337c4551 100644 --- a/tortoise/sim/utils.go +++ b/tortoise/sim/utils.go @@ -7,6 +7,7 @@ import ( "go.uber.org/zap" "github.com/spacemeshos/go-spacemesh/datastore" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) @@ -16,7 +17,7 @@ const ( func newCacheDB(logger *zap.Logger, conf config) *datastore.CachedDB { var ( - db *statesql.Database + db sql.StateDatabase err error ) if len(conf.Path) == 0 { diff --git a/tortoise/tortoise_test.go b/tortoise/tortoise_test.go index b34c25a225..2674dcdcd4 100644 --- a/tortoise/tortoise_test.go +++ b/tortoise/tortoise_test.go @@ -337,7 +337,7 @@ func tortoiseFromSimState(tb testing.TB, state sim.State, opts ...Opt) *recovery return &recoveryAdapter{ TB: tb, Tortoise: trtl, - db: state.DB.Executor, + db: state.DB.Database, atxdata: state.Atxdata, } } diff --git a/tortoise/tracer_test.go b/tortoise/tracer_test.go index b40cd7b26c..21e217ee5a 100644 --- a/tortoise/tracer_test.go +++ b/tortoise/tracer_test.go @@ -46,7 +46,7 @@ func TestTracer(t *testing.T) { path := filepath.Join(t.TempDir(), "tortoise.trace") trt, err := Recover( context.Background(), - s.GetState(0).DB.Executor, + s.GetState(0).DB.Database, s.GetState(0).Atxdata, last, WithTracer(WithOutput(path)), diff --git a/txs/cache.go b/txs/cache.go index 8a0594ea4f..d3d35a605c 100644 --- a/txs/cache.go +++ b/txs/cache.go @@ -19,7 +19,6 @@ import ( "github.com/spacemeshos/go-spacemesh/log" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" ) @@ -333,7 +332,7 @@ func (ac *accountCache) add(logger *zap.Logger, tx *types.Transaction, received func (ac *accountCache) addPendingFromNonce( logger *zap.Logger, - db *statesql.Database, + db sql.StateDatabase, nonce uint64, applied types.LayerID, ) error { @@ -424,7 +423,7 @@ func (ac *accountCache) getMempool(logger *zap.Logger) []*NanoTX { // because applying a layer changes the conservative balance in the cache. func (ac *accountCache) resetAfterApply( logger *zap.Logger, - db *statesql.Database, + db sql.StateDatabase, nextNonce, newBalance uint64, applied types.LayerID, ) error { @@ -490,7 +489,7 @@ func groupTXsByPrincipal(logger *zap.Logger, mtxs []*types.MeshTransaction) map[ } // buildFromScratch builds the cache from database. -func (c *Cache) buildFromScratch(db *statesql.Database) error { +func (c *Cache) buildFromScratch(db sql.StateDatabase) error { applied, err := layers.GetLastApplied(db) if err != nil { return fmt.Errorf("cache: get pending %w", err) @@ -607,7 +606,7 @@ func acceptable(err error) bool { func (c *Cache) Add( ctx context.Context, - db *statesql.Database, + db sql.StateDatabase, tx *types.Transaction, received time.Time, mustPersist bool, @@ -654,7 +653,7 @@ func (c *Cache) has(tid types.TransactionID) bool { // LinkTXsWithProposal associates the transactions to a proposal. func (c *Cache) LinkTXsWithProposal( - db *statesql.Database, + db sql.StateDatabase, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID, @@ -671,7 +670,7 @@ func (c *Cache) LinkTXsWithProposal( // LinkTXsWithBlock associates the transactions to a block. func (c *Cache) LinkTXsWithBlock( - db *statesql.Database, + db sql.StateDatabase, lid types.LayerID, bid types.BlockID, tids []types.TransactionID, @@ -703,7 +702,7 @@ func (c *Cache) updateLayer(lid types.LayerID, bid types.BlockID, tids []types.T return nil } -func (c *Cache) applyEmptyLayer(db *statesql.Database, lid types.LayerID) error { +func (c *Cache) applyEmptyLayer(db sql.StateDatabase, lid types.LayerID) error { c.mu.Lock() defer c.mu.Unlock() @@ -722,7 +721,7 @@ func (c *Cache) applyEmptyLayer(db *statesql.Database, lid types.LayerID) error // ApplyLayer retires the applied transactions from the cache and updates the balances. func (c *Cache) ApplyLayer( ctx context.Context, - db *statesql.Database, + db sql.StateDatabase, lid types.LayerID, bid types.BlockID, results []types.TransactionWithResult, @@ -750,7 +749,7 @@ func (c *Cache) ApplyLayer( // commit results before reporting them // TODO(dshulyak) save results in vm - if err := db.WithTx(context.Background(), func(dbtx *sql.Tx) error { + if err := db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, rst := range results { err := transactions.AddResult(dbtx, rst.ID, &rst.TransactionResult) if err != nil { @@ -839,7 +838,7 @@ func (c *Cache) ApplyLayer( return nil } -func (c *Cache) RevertToLayer(db *statesql.Database, revertTo types.LayerID) error { +func (c *Cache) RevertToLayer(db sql.StateDatabase, revertTo types.LayerID) error { if err := undoLayers(db, revertTo.Add(1)); err != nil { return err } @@ -880,7 +879,7 @@ func (c *Cache) GetMempool(logger *zap.Logger) map[types.Address][]*NanoTX { } // checkApplyOrder returns an error if layers were not applied in order. -func checkApplyOrder(logger *zap.Logger, db *statesql.Database, toApply types.LayerID) error { +func checkApplyOrder(logger *zap.Logger, db sql.StateDatabase, toApply types.LayerID) error { lastApplied, err := layers.GetLastApplied(db) if err != nil { logger.Error("failed to get last applied layer", zap.Error(err)) @@ -896,8 +895,8 @@ func checkApplyOrder(logger *zap.Logger, db *statesql.Database, toApply types.La return nil } -func addToProposal(db *statesql.Database, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID) error { - return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { +func addToProposal(db sql.StateDatabase, lid types.LayerID, pid types.ProposalID, tids []types.TransactionID) error { + return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, tid := range tids { if err := transactions.AddToProposal(dbtx, tid, lid, pid); err != nil { return fmt.Errorf("add2prop %w", err) @@ -907,8 +906,8 @@ func addToProposal(db *statesql.Database, lid types.LayerID, pid types.ProposalI }) } -func addToBlock(db *statesql.Database, lid types.LayerID, bid types.BlockID, tids []types.TransactionID) error { - return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { +func addToBlock(db sql.StateDatabase, lid types.LayerID, bid types.BlockID, tids []types.TransactionID) error { + return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { for _, tid := range tids { if err := transactions.AddToBlock(dbtx, tid, lid, bid); err != nil { return fmt.Errorf("add2block %w", err) @@ -918,8 +917,8 @@ func addToBlock(db *statesql.Database, lid types.LayerID, bid types.BlockID, tid }) } -func undoLayers(db *statesql.Database, from types.LayerID) error { - return db.WithTx(context.Background(), func(dbtx *sql.Tx) error { +func undoLayers(db sql.StateDatabase, from types.LayerID) error { + return db.WithTx(context.Background(), func(dbtx sql.Transaction) error { err := transactions.UndoLayers(dbtx, from) if err != nil { return fmt.Errorf("undo %w", err) diff --git a/txs/cache_test.go b/txs/cache_test.go index fce0897e8a..d0abfa34f9 100644 --- a/txs/cache_test.go +++ b/txs/cache_test.go @@ -19,7 +19,7 @@ import ( type testCache struct { *Cache - db *statesql.Database + db sql.StateDatabase } type testAcct struct { @@ -68,7 +68,7 @@ func newMeshTX( func genAndSaveTXs( t *testing.T, - db *statesql.Database, + db sql.StateDatabase, signer *signing.EdSigner, from, to uint64, startTime time.Time, @@ -89,14 +89,14 @@ func genTXs(t *testing.T, signer *signing.EdSigner, from, to uint64, startTime t return mtxs } -func saveTXs(t *testing.T, db *statesql.Database, mtxs []*types.MeshTransaction) { +func saveTXs(t *testing.T, db sql.StateDatabase, mtxs []*types.MeshTransaction) { t.Helper() for _, mtx := range mtxs { require.NoError(t, transactions.Add(db, &mtx.Transaction, mtx.Received)) } } -func checkTXStateFromDB(t *testing.T, db *statesql.Database, txs []*types.MeshTransaction, state types.TXState) { +func checkTXStateFromDB(t *testing.T, db sql.StateDatabase, txs []*types.MeshTransaction, state types.TXState) { for _, mtx := range txs { got, err := transactions.Get(db, mtx.ID) require.NoError(t, err) @@ -104,7 +104,7 @@ func checkTXStateFromDB(t *testing.T, db *statesql.Database, txs []*types.MeshTr } } -func checkTXNotInDB(t *testing.T, db *statesql.Database, tid types.TransactionID) { +func checkTXNotInDB(t *testing.T, db sql.StateDatabase, tid types.TransactionID) { _, err := transactions.Get(db, tid) require.ErrorIs(t, err, sql.ErrNotFound) } diff --git a/txs/conservative_state.go b/txs/conservative_state.go index 15f0aa2d47..def31171bf 100644 --- a/txs/conservative_state.go +++ b/txs/conservative_state.go @@ -11,8 +11,8 @@ import ( "github.com/spacemeshos/go-spacemesh/common/types" "github.com/spacemeshos/go-spacemesh/events" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" - "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" "github.com/spacemeshos/go-spacemesh/system" ) @@ -54,12 +54,12 @@ type ConservativeState struct { logger *zap.Logger cfg CSConfig - db *statesql.Database + db sql.StateDatabase cache *Cache } // NewConservativeState returns a ConservativeState. -func NewConservativeState(state vmState, db *statesql.Database, opts ...ConservativeStateOpt) *ConservativeState { +func NewConservativeState(state vmState, db sql.StateDatabase, opts ...ConservativeStateOpt) *ConservativeState { cs := &ConservativeState{ vmState: state, cfg: defaultCSConfig(), diff --git a/txs/conservative_state_test.go b/txs/conservative_state_test.go index e3648582e6..f9dd30478b 100644 --- a/txs/conservative_state_test.go +++ b/txs/conservative_state_test.go @@ -23,6 +23,7 @@ import ( "github.com/spacemeshos/go-spacemesh/genvm/sdk/wallet" "github.com/spacemeshos/go-spacemesh/p2p" "github.com/spacemeshos/go-spacemesh/signing" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/statesql" "github.com/spacemeshos/go-spacemesh/sql/transactions" @@ -73,7 +74,7 @@ func newTxWthRecipient( type testConState struct { *ConservativeState logger *zap.Logger - db *statesql.Database + db sql.StateDatabase mvm *MockvmState id peer.ID From 2e41e1a8e32cd98dfea2f5ae650791d80c828394 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 01:21:16 +0400 Subject: [PATCH 17/62] activation: fix test --- activation/e2e/builds_atx_v2_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/activation/e2e/builds_atx_v2_test.go b/activation/e2e/builds_atx_v2_test.go index 004656d301..1a64851939 100644 --- a/activation/e2e/builds_atx_v2_test.go +++ b/activation/e2e/builds_atx_v2_test.go @@ -23,9 +23,9 @@ import ( "github.com/spacemeshos/go-spacemesh/p2p/pubsub" "github.com/spacemeshos/go-spacemesh/p2p/pubsub/mocks" "github.com/spacemeshos/go-spacemesh/signing" - "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/localsql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" smocks "github.com/spacemeshos/go-spacemesh/system/mocks" "github.com/spacemeshos/go-spacemesh/timesync" ) @@ -51,7 +51,7 @@ func TestBuilder_SwitchesToBuildV2(t *testing.T) { require.NoError(t, err) cfg := activation.DefaultPostConfig() - db := sql.InMemory() + db := statesql.InMemory() cdb := datastore.NewCachedDB(db, logger) syncer := activation.NewMocksyncer(ctrl) From dd2ed3e01811eb77524106ef7684375f80b530b2 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 02:46:44 +0400 Subject: [PATCH 18/62] sql: fix tests --- cmd/merge-nodes/internal/merge_action_test.go | 2 ++ sql/activesets/activesets_test.go | 4 ++-- sql/atxs/atxs_test.go | 2 +- sql/database.go | 6 ++++-- sql/database_test.go | 18 +++++++++--------- sql/schema.go | 2 +- sql/vacuum_test.go | 2 +- 7 files changed, 20 insertions(+), 16 deletions(-) diff --git a/cmd/merge-nodes/internal/merge_action_test.go b/cmd/merge-nodes/internal/merge_action_test.go index 3548b64c31..3c262f08ba 100644 --- a/cmd/merge-nodes/internal/merge_action_test.go +++ b/cmd/merge-nodes/internal/merge_action_test.go @@ -37,6 +37,7 @@ func Test_MergeDBs_InvalidTargetScheme(t *testing.T) { db, err := localsql.Open("file:"+filepath.Join(tmpDst, localDbFile), sql.WithDatabaseSchema(oldSchema(t)), sql.WithForceMigrations(true), + sql.WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -95,6 +96,7 @@ func Test_MergeDBs_InvalidSourceScheme(t *testing.T) { db, err = localsql.Open("file:"+filepath.Join(tmpSrc, localDbFile), sql.WithDatabaseSchema(oldSchema(t)), sql.WithForceMigrations(true), + sql.WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) diff --git a/sql/activesets/activesets_test.go b/sql/activesets/activesets_test.go index 8aa3c9193b..acb8fe53ce 100644 --- a/sql/activesets/activesets_test.go +++ b/sql/activesets/activesets_test.go @@ -79,12 +79,12 @@ func TestCachedActiveSet(t *testing.T) { for i := 0; i < 3; i++ { require.NoError(t, LoadBlob(ctx, db, ids[0].Bytes(), &b)) require.Equal(t, codec.MustEncode(set0), b.Bytes) - require.Equal(t, 3, db.QueryCount()) + require.Equal(t, 3, db.QueryCount(), "ids[0]: QueryCount at i=%d", i) } for i := 0; i < 3; i++ { require.NoError(t, LoadBlob(ctx, db, ids[1].Bytes(), &b)) require.Equal(t, codec.MustEncode(set1), b.Bytes) - require.Equal(t, 4, db.QueryCount()) + require.Equal(t, 4, db.QueryCount(), "ids[1]: QueryCount at i=%d", i) } } diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 58627db8e1..e78f7eee6c 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -617,7 +617,7 @@ func TestLoadBlob(t *testing.T) { } func TestLoadBlob_DefaultsToV1(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) diff --git a/sql/database.go b/sql/database.go index 9cd979f606..625db3aa71 100644 --- a/sql/database.go +++ b/sql/database.go @@ -156,14 +156,15 @@ func WithDatabaseSchema(schema *Schema) Opt { } } -// WithAllowSchemaDrift prevents Open from failing upon schema drift. A warning is printed instead. +// WithAllowSchemaDrift prevents Open from failing upon schema +// drift. A warning is printed instead. func WithAllowSchemaDrift(allow bool) Opt { return func(c *conf) { c.allowSchemaDrift = allow } } -func withIgnoreSchemaDrift() Opt { +func WithIgnoreSchemaDrift() Opt { return func(c *conf) { c.ignoreSchemaDrift = true } @@ -284,6 +285,7 @@ func Version(uri string) (int, error) { // Database represents a database. type Database interface { Executor + QueryCache Close() error QueryCount() int QueryCache() QueryCache diff --git a/sql/database_test.go b/sql/database_test.go index c91f61fbac..12406899a3 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -25,7 +25,7 @@ func Test_Transaction_Isolation(t *testing.T) { field int );`, }), - withIgnoreSchemaDrift(), + WithIgnoreSchemaDrift(), ) tx, err := db.Tx(context.Background()) require.NoError(t, err) @@ -90,7 +90,7 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), - withIgnoreSchemaDrift(), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -122,7 +122,7 @@ func Test_Migration_Disabled(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), - withIgnoreSchemaDrift(), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -157,7 +157,7 @@ func TestDatabaseSkipMigrations(t *testing.T) { db, err := Open("file:"+dbFile, WithDatabaseSchema(schema), WithForceMigrations(true), - withIgnoreSchemaDrift(), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -181,7 +181,7 @@ func TestDatabaseVacuumState(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), - withIgnoreSchemaDrift(), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -191,7 +191,7 @@ func TestDatabaseVacuumState(t *testing.T) { Migrations: MigrationList{migration1, migration2}, }), WithVacuumState(2), - withIgnoreSchemaDrift(), + WithIgnoreSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -202,7 +202,7 @@ func TestDatabaseVacuumState(t *testing.T) { } func TestQueryCount(t *testing.T) { - db := InMemory(withIgnoreSchemaDrift()) + db := InMemory(WithIgnoreSchemaDrift()) require.Equal(t, 0, db.QueryCount()) n, err := db.Exec("select 1", nil, nil) @@ -226,7 +226,7 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { migration2.EXPECT().Order().Return(2).AnyTimes() dbFile := filepath.Join(dir, "test.sql") - db, err := Open("file:"+dbFile, WithForceMigrations(true), withIgnoreSchemaDrift()) + db, err := Open("file:"+dbFile, WithForceMigrations(true), WithIgnoreSchemaDrift()) require.NoError(t, err) _, err = db.Exec("PRAGMA user_version = 3", nil, nil) require.NoError(t, err) @@ -236,7 +236,7 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1, migration2}, }), - withIgnoreSchemaDrift(), + WithIgnoreSchemaDrift(), ) require.ErrorIs(t, err, ErrTooNew) } diff --git a/sql/schema.go b/sql/schema.go index 6feec5d031..5130aeea36 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -205,7 +205,7 @@ func (g *SchemaGen) Generate(outputFile string) error { WithLogger(g.logger), WithDatabaseSchema(g.schema), WithForceMigrations(true), - withIgnoreSchemaDrift()) + WithIgnoreSchemaDrift()) if err != nil { return fmt.Errorf("error opening in-memory db: %w", err) } diff --git a/sql/vacuum_test.go b/sql/vacuum_test.go index 5017cad677..1a89158c64 100644 --- a/sql/vacuum_test.go +++ b/sql/vacuum_test.go @@ -7,6 +7,6 @@ import ( ) func TestVacuumDB(t *testing.T) { - db := InMemory(withIgnoreSchemaDrift()) + db := InMemory(WithIgnoreSchemaDrift()) require.NoError(t, Vacuum(db)) } From 13e8c1410cd975e1de9fa7528ebdbb8e50c8dfe7 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 03:57:43 +0400 Subject: [PATCH 19/62] sql: fix query cache handling --- datastore/store.go | 2 -- fetch/handler_test.go | 11 ++++++----- node/node_version_check_test.go | 3 ++- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/datastore/store.go b/datastore/store.go index 8b1d37a584..3042764367 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -35,7 +35,6 @@ type VrfNonceKey struct { // CachedDB is simply a database injected with cache. type CachedDB struct { sql.Database - sql.QueryCache logger *zap.Logger // cache is optional in tests. It MUST be set for the 'App' @@ -109,7 +108,6 @@ func NewCachedDB(db sql.StateDatabase, lg *zap.Logger, opts ...Opt) *CachedDB { return &CachedDB{ Database: db, - QueryCache: db.QueryCache(), logger: lg, atxsdata: o.atxsdata, atxCache: atxHdrCache, diff --git a/fetch/handler_test.go b/fetch/handler_test.go index 7cfe001d2a..fab3ebc560 100644 --- a/fetch/handler_test.go +++ b/fetch/handler_test.go @@ -350,6 +350,8 @@ func testHandleEpochInfoReqWithQueryCache( getInfo func(th *testHandler, req []byte, ed *EpochData), ) { th := createTestHandler(t, sql.WithQueryCache(true)) + require.True(t, th.cdb.Database.IsCached()) + require.True(t, sql.IsCached(th.cdb)) epoch := types.EpochID(11) var expected EpochData @@ -360,8 +362,7 @@ func testHandleEpochInfoReqWithQueryCache( expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) } - qc := th.cdb.Database.(interface{ QueryCount() int }) - require.Equal(t, 20, qc.QueryCount()) + require.Equal(t, 20, th.cdb.Database.QueryCount()) epochBytes, err := codec.Encode(epoch) require.NoError(t, err) @@ -369,7 +370,7 @@ func testHandleEpochInfoReqWithQueryCache( for i := 0; i < 3; i++ { getInfo(th, epochBytes, &got) require.ElementsMatch(t, expected.AtxIDs, got.AtxIDs) - require.Equal(t, 21, qc.QueryCount()) + require.Equal(t, 21, th.cdb.Database.QueryCount(), "query count @ i = %d", i) } // Add another ATX which should be appended to the cached slice @@ -377,14 +378,14 @@ func testHandleEpochInfoReqWithQueryCache( require.NoError(t, atxs.Add(th.cdb, vatx)) atxs.AtxAdded(th.cdb, vatx) expected.AtxIDs = append(expected.AtxIDs, vatx.ID()) - require.Equal(t, 23, qc.QueryCount()) + require.Equal(t, 23, th.cdb.Database.QueryCount()) getInfo(th, epochBytes, &got) require.ElementsMatch(t, expected.AtxIDs, got.AtxIDs) // The query count is not incremented as the slice is still // cached and the new atx is just appended to it, even though // the response is re-serialized. - require.Equal(t, 23, qc.QueryCount()) + require.Equal(t, 23, th.cdb.Database.QueryCount()) } func TestHandleEpochInfoReqWithQueryCache(t *testing.T) { diff --git a/node/node_version_check_test.go b/node/node_version_check_test.go index 787a7b84e6..affc503f86 100644 --- a/node/node_version_check_test.go +++ b/node/node_version_check_test.go @@ -45,7 +45,8 @@ func TestUpgradeToV15(t *testing.T) { db, err := statesql.Open(uri, sql.WithDatabaseSchema(schema), - sql.WithForceMigrations(true)) + sql.WithForceMigrations(true), + sql.WithIgnoreSchemaDrift()) require.NoError(t, err) require.NoError(t, db.Close()) From fbf288033e62df5f9dfc9c8351e58a4cc060b146 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 04:07:29 +0400 Subject: [PATCH 20/62] sql: update mocks --- sql/mocks/mocks.go | 447 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 447 insertions(+) diff --git a/sql/mocks/mocks.go b/sql/mocks/mocks.go index bb3021faa9..0d6890a346 100644 --- a/sql/mocks/mocks.go +++ b/sql/mocks/mocks.go @@ -102,6 +102,42 @@ func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder { return m.recorder } +// ClearCache mocks base method. +func (m *MockDatabase) ClearCache() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClearCache") +} + +// ClearCache indicates an expected call of ClearCache. +func (mr *MockDatabaseMockRecorder) ClearCache() *MockDatabaseClearCacheCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearCache", reflect.TypeOf((*MockDatabase)(nil).ClearCache)) + return &MockDatabaseClearCacheCall{Call: call} +} + +// MockDatabaseClearCacheCall wrap *gomock.Call +type MockDatabaseClearCacheCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseClearCacheCall) Return() *MockDatabaseClearCacheCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseClearCacheCall) Do(f func()) *MockDatabaseClearCacheCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseClearCacheCall) DoAndReturn(f func()) *MockDatabaseClearCacheCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Close mocks base method. func (m *MockDatabase) Close() error { m.ctrl.T.Helper() @@ -179,6 +215,83 @@ func (c *MockDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decod return c } +// GetValue mocks base method. +func (m *MockDatabase) GetValue(ctx context.Context, key queryCacheKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValue", ctx, key, subKey, retrieve) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetValue indicates an expected call of GetValue. +func (mr *MockDatabaseMockRecorder) GetValue(ctx, key, subKey, retrieve any) *MockDatabaseGetValueCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockDatabase)(nil).GetValue), ctx, key, subKey, retrieve) + return &MockDatabaseGetValueCall{Call: call} +} + +// MockDatabaseGetValueCall wrap *gomock.Call +type MockDatabaseGetValueCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseGetValueCall) Return(arg0 any, arg1 error) *MockDatabaseGetValueCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseGetValueCall) Do(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockDatabaseGetValueCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseGetValueCall) DoAndReturn(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockDatabaseGetValueCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// IsCached mocks base method. +func (m *MockDatabase) IsCached() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsCached") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsCached indicates an expected call of IsCached. +func (mr *MockDatabaseMockRecorder) IsCached() *MockDatabaseIsCachedCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockDatabase)(nil).IsCached)) + return &MockDatabaseIsCachedCall{Call: call} +} + +// MockDatabaseIsCachedCall wrap *gomock.Call +type MockDatabaseIsCachedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseIsCachedCall) Return(arg0 bool) *MockDatabaseIsCachedCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseIsCachedCall) Do(f func() bool) *MockDatabaseIsCachedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseIsCachedCall) DoAndReturn(f func() bool) *MockDatabaseIsCachedCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // QueryCache mocks base method. func (m *MockDatabase) QueryCache() sql.QueryCache { m.ctrl.T.Helper() @@ -333,6 +446,42 @@ func (c *MockDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) (sql.T return c } +// UpdateSlice mocks base method. +func (m *MockDatabase) UpdateSlice(key queryCacheKey, update sql.SliceAppender) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateSlice", key, update) +} + +// UpdateSlice indicates an expected call of UpdateSlice. +func (mr *MockDatabaseMockRecorder) UpdateSlice(key, update any) *MockDatabaseUpdateSliceCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSlice", reflect.TypeOf((*MockDatabase)(nil).UpdateSlice), key, update) + return &MockDatabaseUpdateSliceCall{Call: call} +} + +// MockDatabaseUpdateSliceCall wrap *gomock.Call +type MockDatabaseUpdateSliceCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDatabaseUpdateSliceCall) Return() *MockDatabaseUpdateSliceCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDatabaseUpdateSliceCall) Do(f func(queryCacheKey, sql.SliceAppender)) *MockDatabaseUpdateSliceCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDatabaseUpdateSliceCall) DoAndReturn(f func(queryCacheKey, sql.SliceAppender)) *MockDatabaseUpdateSliceCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // WithTx mocks base method. func (m *MockDatabase) WithTx(ctx context.Context, exec func(sql.Transaction) error) error { m.ctrl.T.Helper() @@ -570,6 +719,42 @@ func (m *MockStateDatabase) EXPECT() *MockStateDatabaseMockRecorder { return m.recorder } +// ClearCache mocks base method. +func (m *MockStateDatabase) ClearCache() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClearCache") +} + +// ClearCache indicates an expected call of ClearCache. +func (mr *MockStateDatabaseMockRecorder) ClearCache() *MockStateDatabaseClearCacheCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearCache", reflect.TypeOf((*MockStateDatabase)(nil).ClearCache)) + return &MockStateDatabaseClearCacheCall{Call: call} +} + +// MockStateDatabaseClearCacheCall wrap *gomock.Call +type MockStateDatabaseClearCacheCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseClearCacheCall) Return() *MockStateDatabaseClearCacheCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseClearCacheCall) Do(f func()) *MockStateDatabaseClearCacheCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseClearCacheCall) DoAndReturn(f func()) *MockStateDatabaseClearCacheCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Close mocks base method. func (m *MockStateDatabase) Close() error { m.ctrl.T.Helper() @@ -647,6 +832,83 @@ func (c *MockStateDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql. return c } +// GetValue mocks base method. +func (m *MockStateDatabase) GetValue(ctx context.Context, key queryCacheKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValue", ctx, key, subKey, retrieve) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetValue indicates an expected call of GetValue. +func (mr *MockStateDatabaseMockRecorder) GetValue(ctx, key, subKey, retrieve any) *MockStateDatabaseGetValueCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockStateDatabase)(nil).GetValue), ctx, key, subKey, retrieve) + return &MockStateDatabaseGetValueCall{Call: call} +} + +// MockStateDatabaseGetValueCall wrap *gomock.Call +type MockStateDatabaseGetValueCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseGetValueCall) Return(arg0 any, arg1 error) *MockStateDatabaseGetValueCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseGetValueCall) Do(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockStateDatabaseGetValueCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseGetValueCall) DoAndReturn(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockStateDatabaseGetValueCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// IsCached mocks base method. +func (m *MockStateDatabase) IsCached() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsCached") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsCached indicates an expected call of IsCached. +func (mr *MockStateDatabaseMockRecorder) IsCached() *MockStateDatabaseIsCachedCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockStateDatabase)(nil).IsCached)) + return &MockStateDatabaseIsCachedCall{Call: call} +} + +// MockStateDatabaseIsCachedCall wrap *gomock.Call +type MockStateDatabaseIsCachedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseIsCachedCall) Return(arg0 bool) *MockStateDatabaseIsCachedCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseIsCachedCall) Do(f func() bool) *MockStateDatabaseIsCachedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseIsCachedCall) DoAndReturn(f func() bool) *MockStateDatabaseIsCachedCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // IsStateDatabase mocks base method. func (m *MockStateDatabase) IsStateDatabase() bool { m.ctrl.T.Helper() @@ -839,6 +1101,42 @@ func (c *MockStateDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) ( return c } +// UpdateSlice mocks base method. +func (m *MockStateDatabase) UpdateSlice(key queryCacheKey, update sql.SliceAppender) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateSlice", key, update) +} + +// UpdateSlice indicates an expected call of UpdateSlice. +func (mr *MockStateDatabaseMockRecorder) UpdateSlice(key, update any) *MockStateDatabaseUpdateSliceCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSlice", reflect.TypeOf((*MockStateDatabase)(nil).UpdateSlice), key, update) + return &MockStateDatabaseUpdateSliceCall{Call: call} +} + +// MockStateDatabaseUpdateSliceCall wrap *gomock.Call +type MockStateDatabaseUpdateSliceCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockStateDatabaseUpdateSliceCall) Return() *MockStateDatabaseUpdateSliceCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockStateDatabaseUpdateSliceCall) Do(f func(queryCacheKey, sql.SliceAppender)) *MockStateDatabaseUpdateSliceCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockStateDatabaseUpdateSliceCall) DoAndReturn(f func(queryCacheKey, sql.SliceAppender)) *MockStateDatabaseUpdateSliceCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // WithTx mocks base method. func (m *MockStateDatabase) WithTx(ctx context.Context, exec func(sql.Transaction) error) error { m.ctrl.T.Helper() @@ -938,6 +1236,42 @@ func (m *MockLocalDatabase) EXPECT() *MockLocalDatabaseMockRecorder { return m.recorder } +// ClearCache mocks base method. +func (m *MockLocalDatabase) ClearCache() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ClearCache") +} + +// ClearCache indicates an expected call of ClearCache. +func (mr *MockLocalDatabaseMockRecorder) ClearCache() *MockLocalDatabaseClearCacheCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearCache", reflect.TypeOf((*MockLocalDatabase)(nil).ClearCache)) + return &MockLocalDatabaseClearCacheCall{Call: call} +} + +// MockLocalDatabaseClearCacheCall wrap *gomock.Call +type MockLocalDatabaseClearCacheCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseClearCacheCall) Return() *MockLocalDatabaseClearCacheCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseClearCacheCall) Do(f func()) *MockLocalDatabaseClearCacheCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseClearCacheCall) DoAndReturn(f func()) *MockLocalDatabaseClearCacheCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Close mocks base method. func (m *MockLocalDatabase) Close() error { m.ctrl.T.Helper() @@ -1015,6 +1349,83 @@ func (c *MockLocalDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql. return c } +// GetValue mocks base method. +func (m *MockLocalDatabase) GetValue(ctx context.Context, key queryCacheKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValue", ctx, key, subKey, retrieve) + ret0, _ := ret[0].(any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetValue indicates an expected call of GetValue. +func (mr *MockLocalDatabaseMockRecorder) GetValue(ctx, key, subKey, retrieve any) *MockLocalDatabaseGetValueCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockLocalDatabase)(nil).GetValue), ctx, key, subKey, retrieve) + return &MockLocalDatabaseGetValueCall{Call: call} +} + +// MockLocalDatabaseGetValueCall wrap *gomock.Call +type MockLocalDatabaseGetValueCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseGetValueCall) Return(arg0 any, arg1 error) *MockLocalDatabaseGetValueCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseGetValueCall) Do(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockLocalDatabaseGetValueCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseGetValueCall) DoAndReturn(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockLocalDatabaseGetValueCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// IsCached mocks base method. +func (m *MockLocalDatabase) IsCached() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsCached") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsCached indicates an expected call of IsCached. +func (mr *MockLocalDatabaseMockRecorder) IsCached() *MockLocalDatabaseIsCachedCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockLocalDatabase)(nil).IsCached)) + return &MockLocalDatabaseIsCachedCall{Call: call} +} + +// MockLocalDatabaseIsCachedCall wrap *gomock.Call +type MockLocalDatabaseIsCachedCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseIsCachedCall) Return(arg0 bool) *MockLocalDatabaseIsCachedCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseIsCachedCall) Do(f func() bool) *MockLocalDatabaseIsCachedCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseIsCachedCall) DoAndReturn(f func() bool) *MockLocalDatabaseIsCachedCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // IsLocalDatabase mocks base method. func (m *MockLocalDatabase) IsLocalDatabase() bool { m.ctrl.T.Helper() @@ -1207,6 +1618,42 @@ func (c *MockLocalDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) ( return c } +// UpdateSlice mocks base method. +func (m *MockLocalDatabase) UpdateSlice(key queryCacheKey, update sql.SliceAppender) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateSlice", key, update) +} + +// UpdateSlice indicates an expected call of UpdateSlice. +func (mr *MockLocalDatabaseMockRecorder) UpdateSlice(key, update any) *MockLocalDatabaseUpdateSliceCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSlice", reflect.TypeOf((*MockLocalDatabase)(nil).UpdateSlice), key, update) + return &MockLocalDatabaseUpdateSliceCall{Call: call} +} + +// MockLocalDatabaseUpdateSliceCall wrap *gomock.Call +type MockLocalDatabaseUpdateSliceCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockLocalDatabaseUpdateSliceCall) Return() *MockLocalDatabaseUpdateSliceCall { + c.Call = c.Call.Return() + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockLocalDatabaseUpdateSliceCall) Do(f func(queryCacheKey, sql.SliceAppender)) *MockLocalDatabaseUpdateSliceCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockLocalDatabaseUpdateSliceCall) DoAndReturn(f func(queryCacheKey, sql.SliceAppender)) *MockLocalDatabaseUpdateSliceCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // WithTx mocks base method. func (m *MockLocalDatabase) WithTx(ctx context.Context, exec func(sql.Transaction) error) error { m.ctrl.T.Helper() From 6dbc902b90c6d6e1316700ec13cc9ae76edeaad2 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 04:22:47 +0400 Subject: [PATCH 21/62] sql: fix QueryCache related mocks --- sql/mocks/mocks.go | 36 ++++++++++++++++++------------------ sql/querycache.go | 32 ++++++++++++++++---------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/sql/mocks/mocks.go b/sql/mocks/mocks.go index 0d6890a346..1a40d1adb6 100644 --- a/sql/mocks/mocks.go +++ b/sql/mocks/mocks.go @@ -216,7 +216,7 @@ func (c *MockDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decod } // GetValue mocks base method. -func (m *MockDatabase) GetValue(ctx context.Context, key queryCacheKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { +func (m *MockDatabase) GetValue(ctx context.Context, key sql.QueryCacheItemKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetValue", ctx, key, subKey, retrieve) ret0, _ := ret[0].(any) @@ -243,13 +243,13 @@ func (c *MockDatabaseGetValueCall) Return(arg0 any, arg1 error) *MockDatabaseGet } // Do rewrite *gomock.Call.Do -func (c *MockDatabaseGetValueCall) Do(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockDatabaseGetValueCall { +func (c *MockDatabaseGetValueCall) Do(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockDatabaseGetValueCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseGetValueCall) DoAndReturn(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockDatabaseGetValueCall { +func (c *MockDatabaseGetValueCall) DoAndReturn(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockDatabaseGetValueCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -447,7 +447,7 @@ func (c *MockDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) (sql.T } // UpdateSlice mocks base method. -func (m *MockDatabase) UpdateSlice(key queryCacheKey, update sql.SliceAppender) { +func (m *MockDatabase) UpdateSlice(key sql.QueryCacheItemKey, update sql.SliceAppender) { m.ctrl.T.Helper() m.ctrl.Call(m, "UpdateSlice", key, update) } @@ -471,13 +471,13 @@ func (c *MockDatabaseUpdateSliceCall) Return() *MockDatabaseUpdateSliceCall { } // Do rewrite *gomock.Call.Do -func (c *MockDatabaseUpdateSliceCall) Do(f func(queryCacheKey, sql.SliceAppender)) *MockDatabaseUpdateSliceCall { +func (c *MockDatabaseUpdateSliceCall) Do(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockDatabaseUpdateSliceCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseUpdateSliceCall) DoAndReturn(f func(queryCacheKey, sql.SliceAppender)) *MockDatabaseUpdateSliceCall { +func (c *MockDatabaseUpdateSliceCall) DoAndReturn(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockDatabaseUpdateSliceCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -833,7 +833,7 @@ func (c *MockStateDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql. } // GetValue mocks base method. -func (m *MockStateDatabase) GetValue(ctx context.Context, key queryCacheKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { +func (m *MockStateDatabase) GetValue(ctx context.Context, key sql.QueryCacheItemKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetValue", ctx, key, subKey, retrieve) ret0, _ := ret[0].(any) @@ -860,13 +860,13 @@ func (c *MockStateDatabaseGetValueCall) Return(arg0 any, arg1 error) *MockStateD } // Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseGetValueCall) Do(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockStateDatabaseGetValueCall { +func (c *MockStateDatabaseGetValueCall) Do(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockStateDatabaseGetValueCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseGetValueCall) DoAndReturn(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockStateDatabaseGetValueCall { +func (c *MockStateDatabaseGetValueCall) DoAndReturn(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockStateDatabaseGetValueCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -1102,7 +1102,7 @@ func (c *MockStateDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) ( } // UpdateSlice mocks base method. -func (m *MockStateDatabase) UpdateSlice(key queryCacheKey, update sql.SliceAppender) { +func (m *MockStateDatabase) UpdateSlice(key sql.QueryCacheItemKey, update sql.SliceAppender) { m.ctrl.T.Helper() m.ctrl.Call(m, "UpdateSlice", key, update) } @@ -1126,13 +1126,13 @@ func (c *MockStateDatabaseUpdateSliceCall) Return() *MockStateDatabaseUpdateSlic } // Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseUpdateSliceCall) Do(f func(queryCacheKey, sql.SliceAppender)) *MockStateDatabaseUpdateSliceCall { +func (c *MockStateDatabaseUpdateSliceCall) Do(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockStateDatabaseUpdateSliceCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseUpdateSliceCall) DoAndReturn(f func(queryCacheKey, sql.SliceAppender)) *MockStateDatabaseUpdateSliceCall { +func (c *MockStateDatabaseUpdateSliceCall) DoAndReturn(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockStateDatabaseUpdateSliceCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -1350,7 +1350,7 @@ func (c *MockLocalDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql. } // GetValue mocks base method. -func (m *MockLocalDatabase) GetValue(ctx context.Context, key queryCacheKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { +func (m *MockLocalDatabase) GetValue(ctx context.Context, key sql.QueryCacheItemKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetValue", ctx, key, subKey, retrieve) ret0, _ := ret[0].(any) @@ -1377,13 +1377,13 @@ func (c *MockLocalDatabaseGetValueCall) Return(arg0 any, arg1 error) *MockLocalD } // Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseGetValueCall) Do(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockLocalDatabaseGetValueCall { +func (c *MockLocalDatabaseGetValueCall) Do(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockLocalDatabaseGetValueCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseGetValueCall) DoAndReturn(f func(context.Context, queryCacheKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockLocalDatabaseGetValueCall { +func (c *MockLocalDatabaseGetValueCall) DoAndReturn(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockLocalDatabaseGetValueCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -1619,7 +1619,7 @@ func (c *MockLocalDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) ( } // UpdateSlice mocks base method. -func (m *MockLocalDatabase) UpdateSlice(key queryCacheKey, update sql.SliceAppender) { +func (m *MockLocalDatabase) UpdateSlice(key sql.QueryCacheItemKey, update sql.SliceAppender) { m.ctrl.T.Helper() m.ctrl.Call(m, "UpdateSlice", key, update) } @@ -1643,13 +1643,13 @@ func (c *MockLocalDatabaseUpdateSliceCall) Return() *MockLocalDatabaseUpdateSlic } // Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseUpdateSliceCall) Do(f func(queryCacheKey, sql.SliceAppender)) *MockLocalDatabaseUpdateSliceCall { +func (c *MockLocalDatabaseUpdateSliceCall) Do(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockLocalDatabaseUpdateSliceCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseUpdateSliceCall) DoAndReturn(f func(queryCacheKey, sql.SliceAppender)) *MockLocalDatabaseUpdateSliceCall { +func (c *MockLocalDatabaseUpdateSliceCall) DoAndReturn(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockLocalDatabaseUpdateSliceCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/sql/querycache.go b/sql/querycache.go index e9840a2959..38db15f880 100644 --- a/sql/querycache.go +++ b/sql/querycache.go @@ -17,14 +17,14 @@ type ( var NullQueryCache QueryCache = (*queryCache)(nil) -type queryCacheKey struct { +type QueryCacheItemKey struct { Kind QueryCacheKind Key string } // QueryCacheKey creates a key for QueryCache. -func QueryCacheKey(kind QueryCacheKind, key string) queryCacheKey { - return queryCacheKey{Kind: kind, Key: key} +func QueryCacheKey(kind QueryCacheKind, key string) QueryCacheItemKey { + return QueryCacheItemKey{Kind: kind, Key: key} } // QueryCacheSubKey denotes a cache subkey. The empty subkey refers to the main @@ -60,14 +60,14 @@ type QueryCache interface { // called for this cache. GetValue( ctx context.Context, - key queryCacheKey, + key QueryCacheItemKey, subKey QueryCacheSubKey, retrieve UntypedRetrieveFunc, ) (any, error) // UpdateSlice updates the slice stored in the cache by invoking the // specified SliceAppender. If the entry is not cached, the method does // nothing. - UpdateSlice(key queryCacheKey, update SliceAppender) + UpdateSlice(key QueryCacheItemKey, update SliceAppender) // ClearCache empties the cache. ClearCache() } @@ -87,7 +87,7 @@ func IsCached(db any) bool { func WithCachedValue[T any]( ctx context.Context, db any, - key queryCacheKey, + key QueryCacheItemKey, retrieve func(ctx context.Context) (T, error), ) (T, error) { return WithCachedSubKey(ctx, db, key, mainSubKey, retrieve) @@ -100,7 +100,7 @@ func WithCachedValue[T any]( func WithCachedSubKey[T any]( ctx context.Context, db any, - key queryCacheKey, + key QueryCacheItemKey, subKey QueryCacheSubKey, retrieve func(ctx context.Context) (T, error), ) (T, error) { @@ -124,7 +124,7 @@ func WithCachedSubKey[T any]( // AppendToCachedSlice adds a value to the slice stored in the cache by invoking // the specified SliceAppender. If the entry is not cached, the function does // nothing. -func AppendToCachedSlice[T any](db any, key queryCacheKey, v T) { +func AppendToCachedSlice[T any](db any, key QueryCacheItemKey, v T) { if cache, ok := db.(QueryCache); ok { cache.UpdateSlice(key, func(s any) any { if s == nil { @@ -145,7 +145,7 @@ type lru = simplelru.LRU[lruCacheKey, any] type queryCache struct { sync.Mutex updateMtx sync.RWMutex - subKeyMap map[queryCacheKey][]QueryCacheSubKey + subKeyMap map[QueryCacheItemKey][]QueryCacheSubKey cacheSizesByKind map[QueryCacheKind]int caches map[QueryCacheKind]*lru } @@ -162,7 +162,7 @@ func (c *queryCache) ensureLRU(kind QueryCacheKind) *lru { } lruForKind, err := simplelru.NewLRU[lruCacheKey, any](size, func(k lruCacheKey, v any) { if k.subKey == mainSubKey { - c.clearSubKeys(queryCacheKey{Kind: kind, Key: k.key}) + c.clearSubKeys(QueryCacheItemKey{Kind: kind, Key: k.key}) } }) if err != nil { @@ -175,7 +175,7 @@ func (c *queryCache) ensureLRU(kind QueryCacheKind) *lru { return lruForKind } -func (c *queryCache) clearSubKeys(key queryCacheKey) { +func (c *queryCache) clearSubKeys(key QueryCacheItemKey) { lru, found := c.caches[key.Kind] if !found { return @@ -188,7 +188,7 @@ func (c *queryCache) clearSubKeys(key queryCacheKey) { } } -func (c *queryCache) get(key queryCacheKey, subKey QueryCacheSubKey) (any, bool) { +func (c *queryCache) get(key QueryCacheItemKey, subKey QueryCacheSubKey) (any, bool) { c.Lock() defer c.Unlock() lru, found := c.caches[key.Kind] @@ -202,14 +202,14 @@ func (c *queryCache) get(key queryCacheKey, subKey QueryCacheSubKey) (any, bool) }) } -func (c *queryCache) set(key queryCacheKey, subKey QueryCacheSubKey, v any) { +func (c *queryCache) set(key QueryCacheItemKey, subKey QueryCacheSubKey, v any) { c.Lock() defer c.Unlock() if subKey != mainSubKey { sks := c.subKeyMap[key] if slices.Index(sks, subKey) < 0 { if c.subKeyMap == nil { - c.subKeyMap = make(map[queryCacheKey][]QueryCacheSubKey) + c.subKeyMap = make(map[QueryCacheItemKey][]QueryCacheSubKey) } c.subKeyMap[key] = append(sks, subKey) } @@ -224,7 +224,7 @@ func (c *queryCache) IsCached() bool { func (c *queryCache) GetValue( ctx context.Context, - key queryCacheKey, + key QueryCacheItemKey, subKey QueryCacheSubKey, retrieve UntypedRetrieveFunc, ) (any, error) { @@ -251,7 +251,7 @@ func (c *queryCache) GetValue( return v, err } -func (c *queryCache) UpdateSlice(key queryCacheKey, update SliceAppender) { +func (c *queryCache) UpdateSlice(key QueryCacheItemKey, update SliceAppender) { if c == nil { return } From 09e36bb16864f40af0357e26656545b5eb9ec347 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 05:14:56 +0400 Subject: [PATCH 22/62] node: don't print error twice upon failure --- cmd/node/main.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/node/main.go b/cmd/node/main.go index 99ef2265e4..0fb230794d 100644 --- a/cmd/node/main.go +++ b/cmd/node/main.go @@ -3,7 +3,6 @@ package main import ( - "fmt" _ "net/http/pprof" "os" @@ -24,7 +23,8 @@ func main() { // run the app cmd.Branch = branch cmd.NoMainNet = noMainNet == "true" if err := node.GetCommand().Execute(); err != nil { - fmt.Fprintln(os.Stderr, err) + // Do not print error as cmd.SilenceErrors is false + // and the error was already printed os.Exit(1) } } From 9dc3f0c50b2bf366300a526ed65cd84ca6bdba1f Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 05:23:59 +0400 Subject: [PATCH 23/62] sql: fix schema drift on Windows --- sql/schema.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/schema.go b/sql/schema.go index 5130aeea36..ab03125194 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -42,7 +42,8 @@ func LoadDBSchemaScript(db Executor) (string, error) { }); err != nil { return "", fmt.Errorf("error retrieving DB schema: %w", err) } - return sb.String(), nil + // On Windows, the result contains extra carriage returns + return strings.ReplaceAll(sb.String(), "\r", ""), nil } // Schema represents database schema. From 8bbeeb5830ad5e9ea263cfdd7ed7612d411e5698 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 05:24:10 +0400 Subject: [PATCH 24/62] sql: update docs on schema handling --- CONTRIBUTING.md | 84 ++++++++++++++++++++++++++++++++++++++++++++++--- README.md | 80 ---------------------------------------------- 2 files changed, 80 insertions(+), 84 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 13c79edce5..de27fcaa54 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,7 +2,7 @@ Thank you for considering to contribute to the go-spacemesh open source project. We welcome contributions large and small and we actively accept contributions. - go-spacemesh is part of [The Spacemesh open source project](https://spacemesh.io), and is MIT licensed free open source software. -- Please make sure to scan the [open issues](https://github.com/spacemeshos/go-spacemesh/issues). +- Please make sure to scan the [open issues](https://github.com/spacemeshos/go-spacemesh/issues). - Search the closed ones before reporting things, and help us with the open ones. - Make sure that you are able to contribute to MIT licensed free open software (no legal issues please). - Introduce yourself, ask questions about issues or talk about things on our [discord server](https://chat.spacemesh.io/). @@ -39,7 +39,7 @@ Thank you for considering to contribute to the go-spacemesh open source project. # Code Guidelines Please follow these guidelines for your PR to be reviewed and be considered for merging into the project. -1. Document all methods and functions using [go commentary](https://golang.org/doc/effective_go.html#commentary). +1. Document all methods and functions using [go commentary](https://golang.org/doc/effective_go.html#commentary). 2. Provide at least one unit test for each function and method. 3. Provide at least one integration test for each feature with a flow which involves more than one function call. Your tests should reflect the main ways that your code should be used. 4. Run `go mod tidy`, `go fmt ./...` and `make lint` to format and lint your code before submitting your PR. @@ -49,7 +49,7 @@ Please follow these guidelines for your PR to be reviewed and be considered for - Check for existing 3rd-party packages in the vendor folder `./vendor` before adding a new dependency. - Use [govendor](https://github.com/kardianos/govendor) to add a new dependency. -# Working on a funded issue +# Working on a funded issue ## Step 1 - Discover :sunrise_over_mountains: - Browse the [open funded issues](https://github.com/spacemeshos/go-spacemesh/labels/funded%20%3Amoneybag%3A) in our github repo, or on our [gitcoin.io funded issues page](https://gitcoin.co/profile/spacemeshos). @@ -68,6 +68,82 @@ Please follow these guidelines for your PR to be reviewed and be considered for ## Step 3 - Get paid :moneybag: - When ready, submit your PR for review and go through the code review process with one of our maintainers. - Expect a review process that ensures that you have followed our code guidelines at that your design and implementation are solid. You are expected to revision your code based on reviewers comments. -- You should receive your bounty as soon as your PR is approved and merged by one of our maintainers. +- You should receive your bounty as soon as your PR is approved and merged by one of our maintainers. Please review our funded issues program [legal notes](https://github.com/spacemeshos/go-spacemesh/blob/master/legal.md). + +# Handling database schema changes + +go-spacemesh currently maintains 2 SQLite databases in the data folder: `state.sql` (state database) and `local.sql` (local database). It employs schema versioning for both databases, with a possibility to upgrade older versions of each database to the current schema version by means of running a series of migrations. Also, go-spacemesh tracks any schema drift (unexpected schema changes) in the databases. + +When a database is first created, the corresponding schema file embedded in go-spacemesh executable is used to initialize it: +* `sql/statesql/schema/schema.sql` for `state.sql` +* `sql/localsql/schema/schema.sql` for `local.sql` +The schema file includes `PRAGMA user_version = ...` which sets the version of the database schema. The version of the schema is equal to the number of migrations defined for the corresponding database (`state.sql` or `local.sql`). + +For an existing database, the `PRAGMA user_version` is checked against the expected version number. If the database's schema version is too new, go-spacemesh fails right away as an older go-spacemesh version cannot be expected to work with a database from a newer version. If the database version number is older than the expected version, go-spacemesh runs the necessary migration scripts embedded in go-spacemesh executable and updates `PRAGMA user_version = ...`. The migration scripts are located in the following folders: +* `sql/statesql/schema/migrations` for `state.sql` +* `sql/localsql/schema/migrations` for `local.sql` + +Additionally, some migrations ("coded migrations") can be implemented in Go code, in which case they reside in `.go` files located in `sql/statesql` and `sql/localsql` packages, respectively. It is worth noting that old coded migrations can be removed at some point, rendering database versions that are *too* old unusable with newer go-spacemesh versions. + +After all the migrations are run, go-spacemesh compares the schema of each database to the embedded schema scripts and if they differ, fails with an error message: +``` +Error: open sqlite db schema drift detected (uri file:data/state.sql): + ( + """ + ... // 82 identical lines + PRIMARY KEY (layer, block) + ); ++ CREATE TABLE foo(id int); + CREATE TABLE identities + ( + ... // 66 identical lines + """ + ) +``` + +In this case, a table named `foo` has somehow appeared in the database, causing go-spacemesh to fail due to the schema drift. The possible reasons for schema drift can be the following: +* running an unreleased version of go-spacemesh using your data folder. The unreleased version may contain migrations that may be changed before the release happens +* manual changes in the database +* external SQLite tooling used on the database that adds some tables, indices etc. + +In case if you want to run go-spacemesh with schema drift anyway, you can set `main.db-allow-schema-drift` to true. In this case, a warning with schema diff will be logged instead of failing. + +The schema changes in go-spacemesh code should be always done by means of adding migrations. Let's for example create a new migration (use zero-padded N+1 instead of 0010 with N being the number of the last migration for the local db): + +```console +$ echo 'CREATE TABLE foo(id int);' >sql/localsql/schema/migrations/0010_foo.sql +``` + +After that, we update the schema files +```console +$ make generate +$ # alternative: cd sql/localsql && go generate +$ git diff sql/localsql/schema/schema.sql +diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql +index 02c44d3cc..ebcdf4278 100755 +--- a/sql/localsql/schema/schema.sql ++++ b/sql/localsql/schema/schema.sql +@@ -1,4 +1,4 @@ +-PRAGMA user_version = 9; ++PRAGMA user_version = 10; + CREATE TABLE atx_sync_requests + ( + epoch INT NOT NULL, +@@ -24,6 +24,7 @@ CREATE TABLE "challenge" + post_indices VARCHAR, + post_pow UNSIGNED LONG INT + , poet_proof_ref CHAR(32), poet_proof_membership VARCHAR) WITHOUT ROWID; ++CREATE TABLE foo(id int); + CREATE TABLE malfeasance_sync_state + ( + id INT NOT NULL PRIMARY KEY, +``` + +Note that the changes include both the new table and an updated `PRAGMA user_version` line. +The changes in the schema file must be committed along with the migration we added. +```console +$ git add sql/localsql/schema/migrations/0010_foo.sql sql/localsql/schema.sql +$ git commit -m "sql: add a test migration" +``` diff --git a/README.md b/README.md index 3c637dbebe..82ea628f19 100644 --- a/README.md +++ b/README.md @@ -514,86 +514,6 @@ $ grpcurl -plaintext 127.0.0.1:9093 spacemesh.v1.DebugService.NetworkInfo } ``` -#### Handling database schema changes - -go-spacemesh currently maintains 2 SQLite databases in the data folder: `state.sql` (state database) and `local.sql` (local database). It employs schema versioning for both databases, with a possibility to upgrade older versions of each database to the current schema version by means of running a series of migrations. Also, go-spacemesh tracks any schema drift (unexpected schema changes) in the databases. - -When a database is first created, the corresponding schema file embedded in go-spacemesh executable is used to initialize it: -* `sql/statesql/schema/schema.sql` for `state.sql` -* `sql/localsql/schema/schema.sql` for `local.sql` -The schema file includes `PRAGMA user_version = ...` which sets the version of the database schema. The version of the schema is equal to the number of migrations defined for the corresponding database (`state.sql` or `local.sql`). - -For an existing database, the `PRAGMA user_version` is checked against the expected version number. If the database's schema version is too new, go-spacemesh fails right away as an older go-spacemesh version cannot be expected to work with a database from a newer version. If the database version number is older than the expected version, go-spacemesh runs the necessary migration scripts embedded in go-spacemesh executable and updates `PRAGMA user_version = ...`. The migration scripts are located in the following folders: -* `sql/statesql/schema/migrations` for `state.sql` -* `sql/localsql/schema/migrations` for `local.sql` - -Additionally, some migrations ("coded migrations") can be implemented in Go code, in which case they reside in `.go` files located in `sql/statesql` and `sql/localsql` packages, respectively. It is worth noting that old coded migrations can be removed at some point, rendering database versions that are *too* old unusable with newer go-spacemesh versions. - -After all the migrations are run, go-spacemesh compares the schema of each database to the embedded schema scripts and if they differ, warns the user about any differences: -``` -logger.go:146: 2024-06-05T05:39:32.247+0400 WARN database schema drift detected {"uri": "file:/var/folders/r0/4mks2v4n5ysbntnf3xq6h_q80000gn/T/TestSchemaidempotent_migration3425594786/001/test.db", "diff": " (\n \t\"\"\"\n \t... // 81 identical lines\n \t PRIMARY KEY (kind, epoch)\n \t) WITHOUT ROWID;\n- \t\n- \t-- some change\n \t\"\"\"\n )\n"} -``` - -In this case, an empty line and `-- some change` was added to `schema.sql` by hand. The pretty-printed diff looks like this: -``` - ( - """ - ... // 81 identical lines - PRIMARY KEY (kind, epoch) - ) WITHOUT ROWID; -- -- -- some change - """ - ) -``` - -The possible reasons for schema drift can be the following: -* running an unreleased version of go-spacemesh using your data folder. The unreleased version may contain migrations that may be changed before the release happens -* manual changes in the database -* external SQLite tooling used on the database that adds some tables, indices etc. - -In the latter case, it is possible to make go-spacemesh ignore certain objects (tables and indices) when checking for schema drift. For this, you can use `main.db-schema-ignore-rx` setting to set a regular expression that is used to ignore tables and indices in the database during schema drift checks. The setting defaults to `_litestream` to help with certain tooling. - -The schema changes in go-spacemesh code should be always done by means of adding migrations. After that, the schema tests in `sql/localsql` and `sql/statesql` will start failing. When the tests fail, they display the difference between the schema stored in `schema.sql` and the schema that is loaded from the database after running all the migrations. -If the schema changes shown in the diff are expected, the schema file needs to be updated. - -```console -$ # run the tests -$ eval $(make print-test-env) go test ./sql/localsql ./sql/statesql -... -=== RUN TestSchema/schema/force_migrations - test.go:106: updated schema written to schema/schema.sql.updated - test.go:108: - Error Trace: /Users/user/spacemesh/go-spacemesh/sql/test/test.go:108 - Error: Should be empty, but was ( - """ - ... // 81 identical lines - PRIMARY KEY (kind, epoch) - ) WITHOUT ROWID; - - -- some change - """ - ) - Test: TestSchema/schema/force_migrations - Messages: schema diff -FAIL -FAIL github.com/spacemeshos/go-spacemesh/sql/localsql 0.163s -ok github.com/spacemeshos/go-spacemesh/sql/statesql 0.286s -FAIL -$ git status -... -Untracked files: - (use "git add ..." to include in what will be committed) - sql/localsql/schema/schema.sql.updated - -$ # update the schema file -$ mv sql/localsql/schema/schema.sql{.updated,} - -$ # rerun the tests -$ eval $(make print-test-env) go test -count=1 ./sql/localsql ./sql/statesql -ok github.com/spacemeshos/go-spacemesh/sql/localsql 0.166s -ok github.com/spacemeshos/go-spacemesh/sql/statesql 0.293s -``` - #### Next Steps - Please visit our [wiki](https://github.com/spacemeshos/go-spacemesh/wiki) From 30dcf7188666d4a3bdf8f272c51e8217d6bcc2fd Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 06:04:26 +0400 Subject: [PATCH 25/62] sql, malsync: fix handling of context cancelation --- sql/database.go | 26 ++++++++++++++++++-------- syncer/malsync/syncer.go | 5 +++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/sql/database.go b/sql/database.go index 625db3aa71..6624f1c2a6 100644 --- a/sql/database.go +++ b/sql/database.go @@ -447,11 +447,7 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in for { row, err := stmt.Step() if err != nil { - code := sqlite.ErrCode(err) - if code == sqlite.SQLITE_CONSTRAINT_PRIMARYKEY || code == sqlite.SQLITE_CONSTRAINT_UNIQUE { - return 0, ErrObjectExists - } - return 0, fmt.Errorf("step %d: %w", rows, err) + return 0, fmt.Errorf("step %d: %w", rows, fixError(err)) } if !row { return rows, nil @@ -483,7 +479,7 @@ func (tx *sqliteTx) begin(initstmt string) error { stmt := tx.conn.Prep(initstmt) _, err := stmt.Step() if err != nil { - return fmt.Errorf("begin: %w", err) + return fmt.Errorf("begin: %w", fixError(err)) } return nil } @@ -493,7 +489,7 @@ func (tx *sqliteTx) Commit() error { stmt := tx.conn.Prep("COMMIT;") _, tx.err = stmt.Step() if tx.err != nil { - return tx.err + return fixError(tx.err) } tx.committed = true return nil @@ -507,7 +503,7 @@ func (tx *sqliteTx) Release() error { } stmt := tx.conn.Prep("ROLLBACK") _, tx.err = stmt.Step() - return tx.err + return fixError(tx.err) } // Exec query. @@ -522,6 +518,20 @@ func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, e return exec(tx.conn, query, encoder, decoder) } +func fixError(err error) error { + code := sqlite.ErrCode(err) + if code == sqlite.SQLITE_CONSTRAINT_PRIMARYKEY || code == sqlite.SQLITE_CONSTRAINT_UNIQUE { + return ErrObjectExists + } + if code == sqlite.SQLITE_INTERRUPT { + // TODO: we probably should check if there was indeed a context + // that was canceled. But we're likely to replace crawshaw library + // in future so this part should be rewritten anyway + return context.Canceled + } + return err +} + // Blob represents a binary blob data. It can be reused efficiently // across multiple data retrieval operations, minimizing reallocations // of the underlying byte slice. diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go index 86b09e2002..3a14388852 100644 --- a/syncer/malsync/syncer.go +++ b/syncer/malsync/syncer.go @@ -344,6 +344,11 @@ func (s *Syncer) updateState(ctx context.Context) error { if err := s.localdb.WithTx(ctx, func(tx sql.Transaction) error { return malsync.UpdateSyncState(tx, s.clock.Now()) }); err != nil { + if ctx.Err() != nil { + // FIXME: with crawshaw, canceling the context which has been used to get + // a connection from the pool may cause "database: no free connection" errors + err = ctx.Err() + } return fmt.Errorf("error updating malsync state: %w", err) } From 850611477459305d17bf4c339df8bc73fdf49b7a Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 06:26:58 +0400 Subject: [PATCH 26/62] sql: split Schema.Migrate() method --- sql/database.go | 28 ++++++++++++++++++++++------ sql/database_test.go | 16 ++++++++++++++-- sql/schema.go | 36 ++++++++++++------------------------ 3 files changed, 48 insertions(+), 32 deletions(-) diff --git a/sql/database.go b/sql/database.go index 6624f1c2a6..d1c8c1bc66 100644 --- a/sql/database.go +++ b/sql/database.go @@ -205,6 +205,7 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { for _, opt := range opts { opt(config) } + logger := config.logger.With(zap.String("uri", uri)) var flags sqlite.OpenFlags if !config.forceFresh { flags = sqlite.SQLITE_OPEN_READWRITE | @@ -236,11 +237,26 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { db.Close()) } } else { - if err := config.schema.Migrate( - config.logger.With(zap.String("uri", uri)), - db, config.vacuumState, config.enableMigrations, - ); err != nil { + before, after, err := config.schema.CheckDBVersion(logger, db) + switch { + case err != nil: return nil, errors.Join(err, db.Close()) + case before != after && config.enableMigrations: + logger.Info("running migrations", + zap.Int("current version", before), + zap.Int("target version", after), + ) + if err := config.schema.Migrate( + logger, db, before, config.vacuumState, + ); err != nil { + return nil, errors.Join(err, db.Close()) + } + case before != after: + logger.Error("database version is too old", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return nil, fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after) } } @@ -253,7 +269,7 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { switch { case diff == "": // ok case config.allowSchemaDrift: - config.logger.Warn("database schema drift detected", + logger.Warn("database schema drift detected", zap.String("uri", uri), zap.String("diff", diff), ) @@ -265,7 +281,7 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { } if config.cache { - config.logger.Debug("using query cache", zap.Any("sizes", config.cacheSizes)) + logger.Debug("using query cache", zap.Any("sizes", config.cacheSizes)) db.queryCache = &queryCache{cacheSizesByKind: config.cacheSizes} } db.queryCount.Store(0) diff --git a/sql/database_test.go b/sql/database_test.go index 12406899a3..3b899ff779 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -17,6 +17,7 @@ import ( func Test_Transaction_Isolation(t *testing.T) { db := InMemory( + WithLogger(zaptest.NewLogger(t)), WithConnections(10), WithLatencyMetering(true), WithDatabaseSchema(&Schema{ @@ -78,6 +79,7 @@ func Test_Migration_Rollback(t *testing.T) { } func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { + logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) migration1.EXPECT().Name().Return("test").AnyTimes() @@ -86,6 +88,7 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { dbFile := filepath.Join(t.TempDir(), "test.sql") db, err := Open("file:"+dbFile, + WithLogger(logger), WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1}, }), @@ -102,6 +105,7 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { migration2.EXPECT().Rollback().Return(nil) _, err = Open("file:"+dbFile, + WithLogger(logger), WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1, migration2}, }), @@ -165,6 +169,7 @@ func TestDatabaseSkipMigrations(t *testing.T) { func TestDatabaseVacuumState(t *testing.T) { dir := t.TempDir() + logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) @@ -177,6 +182,7 @@ func TestDatabaseVacuumState(t *testing.T) { dbFile := filepath.Join(dir, "test.sql") db, err := Open("file:"+dbFile, + WithLogger(logger), WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1}, }), @@ -187,6 +193,7 @@ func TestDatabaseVacuumState(t *testing.T) { require.NoError(t, db.Close()) db, err = Open("file:"+dbFile, + WithLogger(logger), WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1, migration2}, }), @@ -202,7 +209,7 @@ func TestDatabaseVacuumState(t *testing.T) { } func TestQueryCount(t *testing.T) { - db := InMemory(WithIgnoreSchemaDrift()) + db := InMemory(WithLogger(zaptest.NewLogger(t)), WithIgnoreSchemaDrift()) require.Equal(t, 0, db.QueryCount()) n, err := db.Exec("select 1", nil, nil) @@ -217,6 +224,7 @@ func TestQueryCount(t *testing.T) { func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { dir := t.TempDir() + logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) migration1 := NewMockMigration(ctrl) @@ -226,13 +234,17 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { migration2.EXPECT().Order().Return(2).AnyTimes() dbFile := filepath.Join(dir, "test.sql") - db, err := Open("file:"+dbFile, WithForceMigrations(true), WithIgnoreSchemaDrift()) + db, err := Open("file:"+dbFile, + WithLogger(logger), + WithForceMigrations(true), + WithIgnoreSchemaDrift()) require.NoError(t, err) _, err = db.Exec("PRAGMA user_version = 3", nil, nil) require.NoError(t, err) require.NoError(t, db.Close()) _, err = Open("file:"+dbFile, + WithLogger(logger), WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1, migration2}, }), diff --git a/sql/schema.go b/sql/schema.go index ab03125194..04c0d51fbd 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -97,18 +97,15 @@ func (s *Schema) Apply(db Database) error { }) } -// Migrate performs database migration. In case if migrations are disabled, the database -// version is checked but no migrations are run, and if the database is too old and -// migrations are disabled, an error is returned. -func (s *Schema) Migrate(logger *zap.Logger, db Database, vacuumState int, enable bool) error { +func (s *Schema) CheckDBVersion(logger *zap.Logger, db Database) (before, after int, err error) { if len(s.Migrations) == 0 { - return nil + return 0, 0, nil } - before, err := version(db) + before, err = version(db) if err != nil { - return err + return 0, 0, err } - after := 0 + after = 0 if len(s.Migrations) > 0 { after = s.Migrations.Version() } @@ -117,25 +114,16 @@ func (s *Schema) Migrate(logger *zap.Logger, db Database, vacuumState int, enabl zap.Int("current version", before), zap.Int("target version", after), ) - return fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) + return before, after, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) } - if before == after { - return nil - } - - if !enable { - logger.Error("database version is too old", - zap.Int("current version", before), - zap.Int("target version", after), - ) - return fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after) - } + return before, after, nil +} - logger.Info("running migrations", - zap.Int("current version", before), - zap.Int("target version", after), - ) +// Migrate performs database migration. In case if migrations are disabled, the database +// version is checked but no migrations are run, and if the database is too old and +// migrations are disabled, an error is returned. +func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState int) error { for i, m := range s.Migrations { if m.Order() <= before { continue From feb39fe1bded878228c45319fcd245c85ad5b5e3 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 07:00:18 +0400 Subject: [PATCH 27/62] sql: another fix for Windows newlines in the schema --- sql/localsql/localsql.go | 6 +++++- sql/statesql/statesql.go | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go index 17cbc70765..a257c6b19e 100644 --- a/sql/localsql/localsql.go +++ b/sql/localsql/localsql.go @@ -2,6 +2,7 @@ package localsql import ( "embed" + "strings" "github.com/spacemeshos/go-spacemesh/sql" ) @@ -30,7 +31,10 @@ func Schema() (*sql.Schema, error) { } // NOTE: coded state migrations can be added here // They can be a part of this localsql package - return &sql.Schema{Script: schemaScript, Migrations: sqlMigrations}, nil + return &sql.Schema{ + Script: strings.ReplaceAll(schemaScript, "\r", ""), + Migrations: sqlMigrations, + }, nil } // Open opens a local database. diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 85de4f20bc..7dad4a4b48 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -2,6 +2,7 @@ package statesql import ( "embed" + "strings" "github.com/spacemeshos/go-spacemesh/sql" ) @@ -30,7 +31,10 @@ func Schema() (*sql.Schema, error) { } // NOTE: coded state migrations can be added here // They can be a part of this localsql package - return &sql.Schema{Script: schemaScript, Migrations: sqlMigrations}, nil + return &sql.Schema{ + Script: strings.ReplaceAll(schemaScript, "\r", ""), + Migrations: sqlMigrations, + }, nil } // Open opens a state database. From 77b46d12b7b82b9b119f334267966a7f204ad38f Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 07:57:10 +0400 Subject: [PATCH 28/62] merge-nodes: fix test naming --- cmd/merge-nodes/internal/merge_action_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/merge-nodes/internal/merge_action_test.go b/cmd/merge-nodes/internal/merge_action_test.go index 3c262f08ba..a323a16b9e 100644 --- a/cmd/merge-nodes/internal/merge_action_test.go +++ b/cmd/merge-nodes/internal/merge_action_test.go @@ -31,7 +31,7 @@ func oldSchema(t *testing.T) *sql.Schema { return schema } -func Test_MergeDBs_InvalidTargetScheme(t *testing.T) { +func Test_MergeDBs_InvalidTargetSchema(t *testing.T) { tmpDst := t.TempDir() db, err := localsql.Open("file:"+filepath.Join(tmpDst, localDbFile), @@ -85,7 +85,7 @@ func Test_MergeDBs_InvalidSourcePath(t *testing.T) { require.ErrorIs(t, err, fs.ErrNotExist) } -func Test_MergeDBs_InvalidSourceScheme(t *testing.T) { +func Test_MergeDBs_InvalidSourceSchema(t *testing.T) { tmpDst := t.TempDir() db, err := localsql.Open("file:" + filepath.Join(tmpDst, localDbFile)) From ee36ee40c078cc36e7ec4463fc26645262fbf9d2 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 13 Jun 2024 07:57:27 +0400 Subject: [PATCH 29/62] sql: close db on schema errors --- sql/database.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/database.go b/sql/database.go index d1c8c1bc66..4c732e8b24 100644 --- a/sql/database.go +++ b/sql/database.go @@ -256,14 +256,18 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { zap.Int("current version", before), zap.Int("target version", after), ) - return nil, fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after) + return nil, errors.Join( + fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after), + db.Close()) } } if !config.ignoreSchemaDrift { loaded, err := LoadDBSchemaScript(db) if err != nil { - return nil, fmt.Errorf("error loading database schema: %w", err) + return nil, errors.Join( + fmt.Errorf("error loading database schema: %w", err), + db.Close()) } diff := config.schema.Diff(loaded) switch { From da6ab28a2bb6f4b0e688ffc35bcab248e028d23c Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 17 Jun 2024 00:17:51 +0400 Subject: [PATCH 30/62] sql, datastore: remove unneeded mocks --- datastore/mocks/mocks.go | 10 - datastore/store.go | 2 - sql/database.go | 2 +- sql/mocks/mocks.go | 1656 +------------------------------------- 4 files changed, 3 insertions(+), 1667 deletions(-) delete mode 100644 datastore/mocks/mocks.go diff --git a/datastore/mocks/mocks.go b/datastore/mocks/mocks.go deleted file mode 100644 index fc3ab85e8e..0000000000 --- a/datastore/mocks/mocks.go +++ /dev/null @@ -1,10 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ./store.go -// -// Generated by this command: -// -// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./store.go -// - -// Package mocks is a generated GoMock package. -package mocks diff --git a/datastore/store.go b/datastore/store.go index 3042764367..014f23a7e7 100644 --- a/datastore/store.go +++ b/datastore/store.go @@ -30,8 +30,6 @@ type VrfNonceKey struct { Epoch types.EpochID } -//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./store.go - // CachedDB is simply a database injected with cache. type CachedDB struct { sql.Database diff --git a/sql/database.go b/sql/database.go index 4c732e8b24..7d7fb8460c 100644 --- a/sql/database.go +++ b/sql/database.go @@ -38,7 +38,7 @@ const ( beginImmediate = "BEGIN IMMEDIATE;" ) -//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./database.go +//go:generate mockgen -typed -package=mocks -destination=./mocks/mocks.go github.com/spacemeshos/go-spacemesh/sql Executor // Executor is an interface for executing raw statement. type Executor interface { diff --git a/sql/mocks/mocks.go b/sql/mocks/mocks.go index 1a40d1adb6..8d93a4a118 100644 --- a/sql/mocks/mocks.go +++ b/sql/mocks/mocks.go @@ -1,16 +1,15 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: ./database.go +// Source: github.com/spacemeshos/go-spacemesh/sql (interfaces: Executor) // // Generated by this command: // -// mockgen -typed -package=mocks -destination=./mocks/mocks.go -source=./database.go +// mockgen -typed -package=mocks -destination=./mocks/mocks.go github.com/spacemeshos/go-spacemesh/sql Executor // // Package mocks is a generated GoMock package. package mocks import ( - context "context" reflect "reflect" sql "github.com/spacemeshos/go-spacemesh/sql" @@ -78,1654 +77,3 @@ func (c *MockExecutorExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decod c.Call = c.Call.DoAndReturn(f) return c } - -// MockDatabase is a mock of Database interface. -type MockDatabase struct { - ctrl *gomock.Controller - recorder *MockDatabaseMockRecorder -} - -// MockDatabaseMockRecorder is the mock recorder for MockDatabase. -type MockDatabaseMockRecorder struct { - mock *MockDatabase -} - -// NewMockDatabase creates a new mock instance. -func NewMockDatabase(ctrl *gomock.Controller) *MockDatabase { - mock := &MockDatabase{ctrl: ctrl} - mock.recorder = &MockDatabaseMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder { - return m.recorder -} - -// ClearCache mocks base method. -func (m *MockDatabase) ClearCache() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClearCache") -} - -// ClearCache indicates an expected call of ClearCache. -func (mr *MockDatabaseMockRecorder) ClearCache() *MockDatabaseClearCacheCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearCache", reflect.TypeOf((*MockDatabase)(nil).ClearCache)) - return &MockDatabaseClearCacheCall{Call: call} -} - -// MockDatabaseClearCacheCall wrap *gomock.Call -type MockDatabaseClearCacheCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseClearCacheCall) Return() *MockDatabaseClearCacheCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseClearCacheCall) Do(f func()) *MockDatabaseClearCacheCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseClearCacheCall) DoAndReturn(f func()) *MockDatabaseClearCacheCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Close mocks base method. -func (m *MockDatabase) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockDatabaseMockRecorder) Close() *MockDatabaseCloseCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDatabase)(nil).Close)) - return &MockDatabaseCloseCall{Call: call} -} - -// MockDatabaseCloseCall wrap *gomock.Call -type MockDatabaseCloseCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseCloseCall) Return(arg0 error) *MockDatabaseCloseCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseCloseCall) Do(f func() error) *MockDatabaseCloseCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseCloseCall) DoAndReturn(f func() error) *MockDatabaseCloseCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Exec mocks base method. -func (m *MockDatabase) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec. -func (mr *MockDatabaseMockRecorder) Exec(arg0, arg1, arg2 any) *MockDatabaseExecCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockDatabase)(nil).Exec), arg0, arg1, arg2) - return &MockDatabaseExecCall{Call: call} -} - -// MockDatabaseExecCall wrap *gomock.Call -type MockDatabaseExecCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseExecCall) Return(arg0 int, arg1 error) *MockDatabaseExecCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockDatabaseExecCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockDatabaseExecCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// GetValue mocks base method. -func (m *MockDatabase) GetValue(ctx context.Context, key sql.QueryCacheItemKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetValue", ctx, key, subKey, retrieve) - ret0, _ := ret[0].(any) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetValue indicates an expected call of GetValue. -func (mr *MockDatabaseMockRecorder) GetValue(ctx, key, subKey, retrieve any) *MockDatabaseGetValueCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockDatabase)(nil).GetValue), ctx, key, subKey, retrieve) - return &MockDatabaseGetValueCall{Call: call} -} - -// MockDatabaseGetValueCall wrap *gomock.Call -type MockDatabaseGetValueCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseGetValueCall) Return(arg0 any, arg1 error) *MockDatabaseGetValueCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseGetValueCall) Do(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockDatabaseGetValueCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseGetValueCall) DoAndReturn(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockDatabaseGetValueCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// IsCached mocks base method. -func (m *MockDatabase) IsCached() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsCached") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsCached indicates an expected call of IsCached. -func (mr *MockDatabaseMockRecorder) IsCached() *MockDatabaseIsCachedCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockDatabase)(nil).IsCached)) - return &MockDatabaseIsCachedCall{Call: call} -} - -// MockDatabaseIsCachedCall wrap *gomock.Call -type MockDatabaseIsCachedCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseIsCachedCall) Return(arg0 bool) *MockDatabaseIsCachedCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseIsCachedCall) Do(f func() bool) *MockDatabaseIsCachedCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseIsCachedCall) DoAndReturn(f func() bool) *MockDatabaseIsCachedCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// QueryCache mocks base method. -func (m *MockDatabase) QueryCache() sql.QueryCache { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryCache") - ret0, _ := ret[0].(sql.QueryCache) - return ret0 -} - -// QueryCache indicates an expected call of QueryCache. -func (mr *MockDatabaseMockRecorder) QueryCache() *MockDatabaseQueryCacheCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCache", reflect.TypeOf((*MockDatabase)(nil).QueryCache)) - return &MockDatabaseQueryCacheCall{Call: call} -} - -// MockDatabaseQueryCacheCall wrap *gomock.Call -type MockDatabaseQueryCacheCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseQueryCacheCall) Return(arg0 sql.QueryCache) *MockDatabaseQueryCacheCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseQueryCacheCall) Do(f func() sql.QueryCache) *MockDatabaseQueryCacheCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseQueryCacheCall) DoAndReturn(f func() sql.QueryCache) *MockDatabaseQueryCacheCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// QueryCount mocks base method. -func (m *MockDatabase) QueryCount() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryCount") - ret0, _ := ret[0].(int) - return ret0 -} - -// QueryCount indicates an expected call of QueryCount. -func (mr *MockDatabaseMockRecorder) QueryCount() *MockDatabaseQueryCountCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCount", reflect.TypeOf((*MockDatabase)(nil).QueryCount)) - return &MockDatabaseQueryCountCall{Call: call} -} - -// MockDatabaseQueryCountCall wrap *gomock.Call -type MockDatabaseQueryCountCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseQueryCountCall) Return(arg0 int) *MockDatabaseQueryCountCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseQueryCountCall) Do(f func() int) *MockDatabaseQueryCountCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseQueryCountCall) DoAndReturn(f func() int) *MockDatabaseQueryCountCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Tx mocks base method. -func (m *MockDatabase) Tx(ctx context.Context) (sql.Transaction, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Tx", ctx) - ret0, _ := ret[0].(sql.Transaction) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Tx indicates an expected call of Tx. -func (mr *MockDatabaseMockRecorder) Tx(ctx any) *MockDatabaseTxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tx", reflect.TypeOf((*MockDatabase)(nil).Tx), ctx) - return &MockDatabaseTxCall{Call: call} -} - -// MockDatabaseTxCall wrap *gomock.Call -type MockDatabaseTxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseTxCall) Return(arg0 sql.Transaction, arg1 error) *MockDatabaseTxCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseTxCall) Do(f func(context.Context) (sql.Transaction, error)) *MockDatabaseTxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseTxCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockDatabaseTxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// TxImmediate mocks base method. -func (m *MockDatabase) TxImmediate(ctx context.Context) (sql.Transaction, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TxImmediate", ctx) - ret0, _ := ret[0].(sql.Transaction) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// TxImmediate indicates an expected call of TxImmediate. -func (mr *MockDatabaseMockRecorder) TxImmediate(ctx any) *MockDatabaseTxImmediateCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxImmediate", reflect.TypeOf((*MockDatabase)(nil).TxImmediate), ctx) - return &MockDatabaseTxImmediateCall{Call: call} -} - -// MockDatabaseTxImmediateCall wrap *gomock.Call -type MockDatabaseTxImmediateCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseTxImmediateCall) Return(arg0 sql.Transaction, arg1 error) *MockDatabaseTxImmediateCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseTxImmediateCall) Do(f func(context.Context) (sql.Transaction, error)) *MockDatabaseTxImmediateCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockDatabaseTxImmediateCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// UpdateSlice mocks base method. -func (m *MockDatabase) UpdateSlice(key sql.QueryCacheItemKey, update sql.SliceAppender) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateSlice", key, update) -} - -// UpdateSlice indicates an expected call of UpdateSlice. -func (mr *MockDatabaseMockRecorder) UpdateSlice(key, update any) *MockDatabaseUpdateSliceCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSlice", reflect.TypeOf((*MockDatabase)(nil).UpdateSlice), key, update) - return &MockDatabaseUpdateSliceCall{Call: call} -} - -// MockDatabaseUpdateSliceCall wrap *gomock.Call -type MockDatabaseUpdateSliceCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseUpdateSliceCall) Return() *MockDatabaseUpdateSliceCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseUpdateSliceCall) Do(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockDatabaseUpdateSliceCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseUpdateSliceCall) DoAndReturn(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockDatabaseUpdateSliceCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// WithTx mocks base method. -func (m *MockDatabase) WithTx(ctx context.Context, exec func(sql.Transaction) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTx", ctx, exec) - ret0, _ := ret[0].(error) - return ret0 -} - -// WithTx indicates an expected call of WithTx. -func (mr *MockDatabaseMockRecorder) WithTx(ctx, exec any) *MockDatabaseWithTxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTx", reflect.TypeOf((*MockDatabase)(nil).WithTx), ctx, exec) - return &MockDatabaseWithTxCall{Call: call} -} - -// MockDatabaseWithTxCall wrap *gomock.Call -type MockDatabaseWithTxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseWithTxCall) Return(arg0 error) *MockDatabaseWithTxCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseWithTxCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockDatabaseWithTxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseWithTxCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockDatabaseWithTxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// WithTxImmediate mocks base method. -func (m *MockDatabase) WithTxImmediate(ctx context.Context, exec func(sql.Transaction) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTxImmediate", ctx, exec) - ret0, _ := ret[0].(error) - return ret0 -} - -// WithTxImmediate indicates an expected call of WithTxImmediate. -func (mr *MockDatabaseMockRecorder) WithTxImmediate(ctx, exec any) *MockDatabaseWithTxImmediateCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTxImmediate", reflect.TypeOf((*MockDatabase)(nil).WithTxImmediate), ctx, exec) - return &MockDatabaseWithTxImmediateCall{Call: call} -} - -// MockDatabaseWithTxImmediateCall wrap *gomock.Call -type MockDatabaseWithTxImmediateCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDatabaseWithTxImmediateCall) Return(arg0 error) *MockDatabaseWithTxImmediateCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDatabaseWithTxImmediateCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockDatabaseWithTxImmediateCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDatabaseWithTxImmediateCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockDatabaseWithTxImmediateCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// MockTransaction is a mock of Transaction interface. -type MockTransaction struct { - ctrl *gomock.Controller - recorder *MockTransactionMockRecorder -} - -// MockTransactionMockRecorder is the mock recorder for MockTransaction. -type MockTransactionMockRecorder struct { - mock *MockTransaction -} - -// NewMockTransaction creates a new mock instance. -func NewMockTransaction(ctrl *gomock.Controller) *MockTransaction { - mock := &MockTransaction{ctrl: ctrl} - mock.recorder = &MockTransactionMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder { - return m.recorder -} - -// Commit mocks base method. -func (m *MockTransaction) Commit() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit") - ret0, _ := ret[0].(error) - return ret0 -} - -// Commit indicates an expected call of Commit. -func (mr *MockTransactionMockRecorder) Commit() *MockTransactionCommitCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit)) - return &MockTransactionCommitCall{Call: call} -} - -// MockTransactionCommitCall wrap *gomock.Call -type MockTransactionCommitCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockTransactionCommitCall) Return(arg0 error) *MockTransactionCommitCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockTransactionCommitCall) Do(f func() error) *MockTransactionCommitCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockTransactionCommitCall) DoAndReturn(f func() error) *MockTransactionCommitCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Exec mocks base method. -func (m *MockTransaction) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec. -func (mr *MockTransactionMockRecorder) Exec(arg0, arg1, arg2 any) *MockTransactionExecCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), arg0, arg1, arg2) - return &MockTransactionExecCall{Call: call} -} - -// MockTransactionExecCall wrap *gomock.Call -type MockTransactionExecCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockTransactionExecCall) Return(arg0 int, arg1 error) *MockTransactionExecCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockTransactionExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockTransactionExecCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockTransactionExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockTransactionExecCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Release mocks base method. -func (m *MockTransaction) Release() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Release") - ret0, _ := ret[0].(error) - return ret0 -} - -// Release indicates an expected call of Release. -func (mr *MockTransactionMockRecorder) Release() *MockTransactionReleaseCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockTransaction)(nil).Release)) - return &MockTransactionReleaseCall{Call: call} -} - -// MockTransactionReleaseCall wrap *gomock.Call -type MockTransactionReleaseCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockTransactionReleaseCall) Return(arg0 error) *MockTransactionReleaseCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockTransactionReleaseCall) Do(f func() error) *MockTransactionReleaseCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockTransactionReleaseCall) DoAndReturn(f func() error) *MockTransactionReleaseCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// MockStateDatabase is a mock of StateDatabase interface. -type MockStateDatabase struct { - ctrl *gomock.Controller - recorder *MockStateDatabaseMockRecorder -} - -// MockStateDatabaseMockRecorder is the mock recorder for MockStateDatabase. -type MockStateDatabaseMockRecorder struct { - mock *MockStateDatabase -} - -// NewMockStateDatabase creates a new mock instance. -func NewMockStateDatabase(ctrl *gomock.Controller) *MockStateDatabase { - mock := &MockStateDatabase{ctrl: ctrl} - mock.recorder = &MockStateDatabaseMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStateDatabase) EXPECT() *MockStateDatabaseMockRecorder { - return m.recorder -} - -// ClearCache mocks base method. -func (m *MockStateDatabase) ClearCache() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClearCache") -} - -// ClearCache indicates an expected call of ClearCache. -func (mr *MockStateDatabaseMockRecorder) ClearCache() *MockStateDatabaseClearCacheCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearCache", reflect.TypeOf((*MockStateDatabase)(nil).ClearCache)) - return &MockStateDatabaseClearCacheCall{Call: call} -} - -// MockStateDatabaseClearCacheCall wrap *gomock.Call -type MockStateDatabaseClearCacheCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseClearCacheCall) Return() *MockStateDatabaseClearCacheCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseClearCacheCall) Do(f func()) *MockStateDatabaseClearCacheCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseClearCacheCall) DoAndReturn(f func()) *MockStateDatabaseClearCacheCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Close mocks base method. -func (m *MockStateDatabase) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockStateDatabaseMockRecorder) Close() *MockStateDatabaseCloseCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStateDatabase)(nil).Close)) - return &MockStateDatabaseCloseCall{Call: call} -} - -// MockStateDatabaseCloseCall wrap *gomock.Call -type MockStateDatabaseCloseCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseCloseCall) Return(arg0 error) *MockStateDatabaseCloseCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseCloseCall) Do(f func() error) *MockStateDatabaseCloseCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseCloseCall) DoAndReturn(f func() error) *MockStateDatabaseCloseCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Exec mocks base method. -func (m *MockStateDatabase) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec. -func (mr *MockStateDatabaseMockRecorder) Exec(arg0, arg1, arg2 any) *MockStateDatabaseExecCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStateDatabase)(nil).Exec), arg0, arg1, arg2) - return &MockStateDatabaseExecCall{Call: call} -} - -// MockStateDatabaseExecCall wrap *gomock.Call -type MockStateDatabaseExecCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseExecCall) Return(arg0 int, arg1 error) *MockStateDatabaseExecCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockStateDatabaseExecCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockStateDatabaseExecCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// GetValue mocks base method. -func (m *MockStateDatabase) GetValue(ctx context.Context, key sql.QueryCacheItemKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetValue", ctx, key, subKey, retrieve) - ret0, _ := ret[0].(any) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetValue indicates an expected call of GetValue. -func (mr *MockStateDatabaseMockRecorder) GetValue(ctx, key, subKey, retrieve any) *MockStateDatabaseGetValueCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockStateDatabase)(nil).GetValue), ctx, key, subKey, retrieve) - return &MockStateDatabaseGetValueCall{Call: call} -} - -// MockStateDatabaseGetValueCall wrap *gomock.Call -type MockStateDatabaseGetValueCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseGetValueCall) Return(arg0 any, arg1 error) *MockStateDatabaseGetValueCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseGetValueCall) Do(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockStateDatabaseGetValueCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseGetValueCall) DoAndReturn(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockStateDatabaseGetValueCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// IsCached mocks base method. -func (m *MockStateDatabase) IsCached() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsCached") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsCached indicates an expected call of IsCached. -func (mr *MockStateDatabaseMockRecorder) IsCached() *MockStateDatabaseIsCachedCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockStateDatabase)(nil).IsCached)) - return &MockStateDatabaseIsCachedCall{Call: call} -} - -// MockStateDatabaseIsCachedCall wrap *gomock.Call -type MockStateDatabaseIsCachedCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseIsCachedCall) Return(arg0 bool) *MockStateDatabaseIsCachedCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseIsCachedCall) Do(f func() bool) *MockStateDatabaseIsCachedCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseIsCachedCall) DoAndReturn(f func() bool) *MockStateDatabaseIsCachedCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// IsStateDatabase mocks base method. -func (m *MockStateDatabase) IsStateDatabase() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsStateDatabase") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsStateDatabase indicates an expected call of IsStateDatabase. -func (mr *MockStateDatabaseMockRecorder) IsStateDatabase() *MockStateDatabaseIsStateDatabaseCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsStateDatabase", reflect.TypeOf((*MockStateDatabase)(nil).IsStateDatabase)) - return &MockStateDatabaseIsStateDatabaseCall{Call: call} -} - -// MockStateDatabaseIsStateDatabaseCall wrap *gomock.Call -type MockStateDatabaseIsStateDatabaseCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseIsStateDatabaseCall) Return(arg0 bool) *MockStateDatabaseIsStateDatabaseCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseIsStateDatabaseCall) Do(f func() bool) *MockStateDatabaseIsStateDatabaseCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseIsStateDatabaseCall) DoAndReturn(f func() bool) *MockStateDatabaseIsStateDatabaseCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// QueryCache mocks base method. -func (m *MockStateDatabase) QueryCache() sql.QueryCache { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryCache") - ret0, _ := ret[0].(sql.QueryCache) - return ret0 -} - -// QueryCache indicates an expected call of QueryCache. -func (mr *MockStateDatabaseMockRecorder) QueryCache() *MockStateDatabaseQueryCacheCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCache", reflect.TypeOf((*MockStateDatabase)(nil).QueryCache)) - return &MockStateDatabaseQueryCacheCall{Call: call} -} - -// MockStateDatabaseQueryCacheCall wrap *gomock.Call -type MockStateDatabaseQueryCacheCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseQueryCacheCall) Return(arg0 sql.QueryCache) *MockStateDatabaseQueryCacheCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseQueryCacheCall) Do(f func() sql.QueryCache) *MockStateDatabaseQueryCacheCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseQueryCacheCall) DoAndReturn(f func() sql.QueryCache) *MockStateDatabaseQueryCacheCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// QueryCount mocks base method. -func (m *MockStateDatabase) QueryCount() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryCount") - ret0, _ := ret[0].(int) - return ret0 -} - -// QueryCount indicates an expected call of QueryCount. -func (mr *MockStateDatabaseMockRecorder) QueryCount() *MockStateDatabaseQueryCountCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCount", reflect.TypeOf((*MockStateDatabase)(nil).QueryCount)) - return &MockStateDatabaseQueryCountCall{Call: call} -} - -// MockStateDatabaseQueryCountCall wrap *gomock.Call -type MockStateDatabaseQueryCountCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseQueryCountCall) Return(arg0 int) *MockStateDatabaseQueryCountCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseQueryCountCall) Do(f func() int) *MockStateDatabaseQueryCountCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseQueryCountCall) DoAndReturn(f func() int) *MockStateDatabaseQueryCountCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Tx mocks base method. -func (m *MockStateDatabase) Tx(ctx context.Context) (sql.Transaction, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Tx", ctx) - ret0, _ := ret[0].(sql.Transaction) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Tx indicates an expected call of Tx. -func (mr *MockStateDatabaseMockRecorder) Tx(ctx any) *MockStateDatabaseTxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tx", reflect.TypeOf((*MockStateDatabase)(nil).Tx), ctx) - return &MockStateDatabaseTxCall{Call: call} -} - -// MockStateDatabaseTxCall wrap *gomock.Call -type MockStateDatabaseTxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseTxCall) Return(arg0 sql.Transaction, arg1 error) *MockStateDatabaseTxCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseTxCall) Do(f func(context.Context) (sql.Transaction, error)) *MockStateDatabaseTxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseTxCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockStateDatabaseTxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// TxImmediate mocks base method. -func (m *MockStateDatabase) TxImmediate(ctx context.Context) (sql.Transaction, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TxImmediate", ctx) - ret0, _ := ret[0].(sql.Transaction) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// TxImmediate indicates an expected call of TxImmediate. -func (mr *MockStateDatabaseMockRecorder) TxImmediate(ctx any) *MockStateDatabaseTxImmediateCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxImmediate", reflect.TypeOf((*MockStateDatabase)(nil).TxImmediate), ctx) - return &MockStateDatabaseTxImmediateCall{Call: call} -} - -// MockStateDatabaseTxImmediateCall wrap *gomock.Call -type MockStateDatabaseTxImmediateCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseTxImmediateCall) Return(arg0 sql.Transaction, arg1 error) *MockStateDatabaseTxImmediateCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseTxImmediateCall) Do(f func(context.Context) (sql.Transaction, error)) *MockStateDatabaseTxImmediateCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockStateDatabaseTxImmediateCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// UpdateSlice mocks base method. -func (m *MockStateDatabase) UpdateSlice(key sql.QueryCacheItemKey, update sql.SliceAppender) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateSlice", key, update) -} - -// UpdateSlice indicates an expected call of UpdateSlice. -func (mr *MockStateDatabaseMockRecorder) UpdateSlice(key, update any) *MockStateDatabaseUpdateSliceCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSlice", reflect.TypeOf((*MockStateDatabase)(nil).UpdateSlice), key, update) - return &MockStateDatabaseUpdateSliceCall{Call: call} -} - -// MockStateDatabaseUpdateSliceCall wrap *gomock.Call -type MockStateDatabaseUpdateSliceCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseUpdateSliceCall) Return() *MockStateDatabaseUpdateSliceCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseUpdateSliceCall) Do(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockStateDatabaseUpdateSliceCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseUpdateSliceCall) DoAndReturn(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockStateDatabaseUpdateSliceCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// WithTx mocks base method. -func (m *MockStateDatabase) WithTx(ctx context.Context, exec func(sql.Transaction) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTx", ctx, exec) - ret0, _ := ret[0].(error) - return ret0 -} - -// WithTx indicates an expected call of WithTx. -func (mr *MockStateDatabaseMockRecorder) WithTx(ctx, exec any) *MockStateDatabaseWithTxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTx", reflect.TypeOf((*MockStateDatabase)(nil).WithTx), ctx, exec) - return &MockStateDatabaseWithTxCall{Call: call} -} - -// MockStateDatabaseWithTxCall wrap *gomock.Call -type MockStateDatabaseWithTxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseWithTxCall) Return(arg0 error) *MockStateDatabaseWithTxCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseWithTxCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockStateDatabaseWithTxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseWithTxCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockStateDatabaseWithTxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// WithTxImmediate mocks base method. -func (m *MockStateDatabase) WithTxImmediate(ctx context.Context, exec func(sql.Transaction) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTxImmediate", ctx, exec) - ret0, _ := ret[0].(error) - return ret0 -} - -// WithTxImmediate indicates an expected call of WithTxImmediate. -func (mr *MockStateDatabaseMockRecorder) WithTxImmediate(ctx, exec any) *MockStateDatabaseWithTxImmediateCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTxImmediate", reflect.TypeOf((*MockStateDatabase)(nil).WithTxImmediate), ctx, exec) - return &MockStateDatabaseWithTxImmediateCall{Call: call} -} - -// MockStateDatabaseWithTxImmediateCall wrap *gomock.Call -type MockStateDatabaseWithTxImmediateCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockStateDatabaseWithTxImmediateCall) Return(arg0 error) *MockStateDatabaseWithTxImmediateCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockStateDatabaseWithTxImmediateCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockStateDatabaseWithTxImmediateCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateDatabaseWithTxImmediateCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockStateDatabaseWithTxImmediateCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// MockLocalDatabase is a mock of LocalDatabase interface. -type MockLocalDatabase struct { - ctrl *gomock.Controller - recorder *MockLocalDatabaseMockRecorder -} - -// MockLocalDatabaseMockRecorder is the mock recorder for MockLocalDatabase. -type MockLocalDatabaseMockRecorder struct { - mock *MockLocalDatabase -} - -// NewMockLocalDatabase creates a new mock instance. -func NewMockLocalDatabase(ctrl *gomock.Controller) *MockLocalDatabase { - mock := &MockLocalDatabase{ctrl: ctrl} - mock.recorder = &MockLocalDatabaseMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockLocalDatabase) EXPECT() *MockLocalDatabaseMockRecorder { - return m.recorder -} - -// ClearCache mocks base method. -func (m *MockLocalDatabase) ClearCache() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "ClearCache") -} - -// ClearCache indicates an expected call of ClearCache. -func (mr *MockLocalDatabaseMockRecorder) ClearCache() *MockLocalDatabaseClearCacheCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClearCache", reflect.TypeOf((*MockLocalDatabase)(nil).ClearCache)) - return &MockLocalDatabaseClearCacheCall{Call: call} -} - -// MockLocalDatabaseClearCacheCall wrap *gomock.Call -type MockLocalDatabaseClearCacheCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseClearCacheCall) Return() *MockLocalDatabaseClearCacheCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseClearCacheCall) Do(f func()) *MockLocalDatabaseClearCacheCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseClearCacheCall) DoAndReturn(f func()) *MockLocalDatabaseClearCacheCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Close mocks base method. -func (m *MockLocalDatabase) Close() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close. -func (mr *MockLocalDatabaseMockRecorder) Close() *MockLocalDatabaseCloseCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockLocalDatabase)(nil).Close)) - return &MockLocalDatabaseCloseCall{Call: call} -} - -// MockLocalDatabaseCloseCall wrap *gomock.Call -type MockLocalDatabaseCloseCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseCloseCall) Return(arg0 error) *MockLocalDatabaseCloseCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseCloseCall) Do(f func() error) *MockLocalDatabaseCloseCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseCloseCall) DoAndReturn(f func() error) *MockLocalDatabaseCloseCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Exec mocks base method. -func (m *MockLocalDatabase) Exec(arg0 string, arg1 sql.Encoder, arg2 sql.Decoder) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exec", arg0, arg1, arg2) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec. -func (mr *MockLocalDatabaseMockRecorder) Exec(arg0, arg1, arg2 any) *MockLocalDatabaseExecCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockLocalDatabase)(nil).Exec), arg0, arg1, arg2) - return &MockLocalDatabaseExecCall{Call: call} -} - -// MockLocalDatabaseExecCall wrap *gomock.Call -type MockLocalDatabaseExecCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseExecCall) Return(arg0 int, arg1 error) *MockLocalDatabaseExecCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseExecCall) Do(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockLocalDatabaseExecCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseExecCall) DoAndReturn(f func(string, sql.Encoder, sql.Decoder) (int, error)) *MockLocalDatabaseExecCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// GetValue mocks base method. -func (m *MockLocalDatabase) GetValue(ctx context.Context, key sql.QueryCacheItemKey, subKey sql.QueryCacheSubKey, retrieve sql.UntypedRetrieveFunc) (any, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetValue", ctx, key, subKey, retrieve) - ret0, _ := ret[0].(any) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetValue indicates an expected call of GetValue. -func (mr *MockLocalDatabaseMockRecorder) GetValue(ctx, key, subKey, retrieve any) *MockLocalDatabaseGetValueCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValue", reflect.TypeOf((*MockLocalDatabase)(nil).GetValue), ctx, key, subKey, retrieve) - return &MockLocalDatabaseGetValueCall{Call: call} -} - -// MockLocalDatabaseGetValueCall wrap *gomock.Call -type MockLocalDatabaseGetValueCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseGetValueCall) Return(arg0 any, arg1 error) *MockLocalDatabaseGetValueCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseGetValueCall) Do(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockLocalDatabaseGetValueCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseGetValueCall) DoAndReturn(f func(context.Context, sql.QueryCacheItemKey, sql.QueryCacheSubKey, sql.UntypedRetrieveFunc) (any, error)) *MockLocalDatabaseGetValueCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// IsCached mocks base method. -func (m *MockLocalDatabase) IsCached() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsCached") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsCached indicates an expected call of IsCached. -func (mr *MockLocalDatabaseMockRecorder) IsCached() *MockLocalDatabaseIsCachedCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsCached", reflect.TypeOf((*MockLocalDatabase)(nil).IsCached)) - return &MockLocalDatabaseIsCachedCall{Call: call} -} - -// MockLocalDatabaseIsCachedCall wrap *gomock.Call -type MockLocalDatabaseIsCachedCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseIsCachedCall) Return(arg0 bool) *MockLocalDatabaseIsCachedCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseIsCachedCall) Do(f func() bool) *MockLocalDatabaseIsCachedCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseIsCachedCall) DoAndReturn(f func() bool) *MockLocalDatabaseIsCachedCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// IsLocalDatabase mocks base method. -func (m *MockLocalDatabase) IsLocalDatabase() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsLocalDatabase") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsLocalDatabase indicates an expected call of IsLocalDatabase. -func (mr *MockLocalDatabaseMockRecorder) IsLocalDatabase() *MockLocalDatabaseIsLocalDatabaseCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsLocalDatabase", reflect.TypeOf((*MockLocalDatabase)(nil).IsLocalDatabase)) - return &MockLocalDatabaseIsLocalDatabaseCall{Call: call} -} - -// MockLocalDatabaseIsLocalDatabaseCall wrap *gomock.Call -type MockLocalDatabaseIsLocalDatabaseCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseIsLocalDatabaseCall) Return(arg0 bool) *MockLocalDatabaseIsLocalDatabaseCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseIsLocalDatabaseCall) Do(f func() bool) *MockLocalDatabaseIsLocalDatabaseCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseIsLocalDatabaseCall) DoAndReturn(f func() bool) *MockLocalDatabaseIsLocalDatabaseCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// QueryCache mocks base method. -func (m *MockLocalDatabase) QueryCache() sql.QueryCache { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryCache") - ret0, _ := ret[0].(sql.QueryCache) - return ret0 -} - -// QueryCache indicates an expected call of QueryCache. -func (mr *MockLocalDatabaseMockRecorder) QueryCache() *MockLocalDatabaseQueryCacheCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCache", reflect.TypeOf((*MockLocalDatabase)(nil).QueryCache)) - return &MockLocalDatabaseQueryCacheCall{Call: call} -} - -// MockLocalDatabaseQueryCacheCall wrap *gomock.Call -type MockLocalDatabaseQueryCacheCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseQueryCacheCall) Return(arg0 sql.QueryCache) *MockLocalDatabaseQueryCacheCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseQueryCacheCall) Do(f func() sql.QueryCache) *MockLocalDatabaseQueryCacheCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseQueryCacheCall) DoAndReturn(f func() sql.QueryCache) *MockLocalDatabaseQueryCacheCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// QueryCount mocks base method. -func (m *MockLocalDatabase) QueryCount() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "QueryCount") - ret0, _ := ret[0].(int) - return ret0 -} - -// QueryCount indicates an expected call of QueryCount. -func (mr *MockLocalDatabaseMockRecorder) QueryCount() *MockLocalDatabaseQueryCountCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryCount", reflect.TypeOf((*MockLocalDatabase)(nil).QueryCount)) - return &MockLocalDatabaseQueryCountCall{Call: call} -} - -// MockLocalDatabaseQueryCountCall wrap *gomock.Call -type MockLocalDatabaseQueryCountCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseQueryCountCall) Return(arg0 int) *MockLocalDatabaseQueryCountCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseQueryCountCall) Do(f func() int) *MockLocalDatabaseQueryCountCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseQueryCountCall) DoAndReturn(f func() int) *MockLocalDatabaseQueryCountCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// Tx mocks base method. -func (m *MockLocalDatabase) Tx(ctx context.Context) (sql.Transaction, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Tx", ctx) - ret0, _ := ret[0].(sql.Transaction) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Tx indicates an expected call of Tx. -func (mr *MockLocalDatabaseMockRecorder) Tx(ctx any) *MockLocalDatabaseTxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Tx", reflect.TypeOf((*MockLocalDatabase)(nil).Tx), ctx) - return &MockLocalDatabaseTxCall{Call: call} -} - -// MockLocalDatabaseTxCall wrap *gomock.Call -type MockLocalDatabaseTxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseTxCall) Return(arg0 sql.Transaction, arg1 error) *MockLocalDatabaseTxCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseTxCall) Do(f func(context.Context) (sql.Transaction, error)) *MockLocalDatabaseTxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseTxCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockLocalDatabaseTxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// TxImmediate mocks base method. -func (m *MockLocalDatabase) TxImmediate(ctx context.Context) (sql.Transaction, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TxImmediate", ctx) - ret0, _ := ret[0].(sql.Transaction) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// TxImmediate indicates an expected call of TxImmediate. -func (mr *MockLocalDatabaseMockRecorder) TxImmediate(ctx any) *MockLocalDatabaseTxImmediateCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TxImmediate", reflect.TypeOf((*MockLocalDatabase)(nil).TxImmediate), ctx) - return &MockLocalDatabaseTxImmediateCall{Call: call} -} - -// MockLocalDatabaseTxImmediateCall wrap *gomock.Call -type MockLocalDatabaseTxImmediateCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseTxImmediateCall) Return(arg0 sql.Transaction, arg1 error) *MockLocalDatabaseTxImmediateCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseTxImmediateCall) Do(f func(context.Context) (sql.Transaction, error)) *MockLocalDatabaseTxImmediateCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseTxImmediateCall) DoAndReturn(f func(context.Context) (sql.Transaction, error)) *MockLocalDatabaseTxImmediateCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// UpdateSlice mocks base method. -func (m *MockLocalDatabase) UpdateSlice(key sql.QueryCacheItemKey, update sql.SliceAppender) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "UpdateSlice", key, update) -} - -// UpdateSlice indicates an expected call of UpdateSlice. -func (mr *MockLocalDatabaseMockRecorder) UpdateSlice(key, update any) *MockLocalDatabaseUpdateSliceCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSlice", reflect.TypeOf((*MockLocalDatabase)(nil).UpdateSlice), key, update) - return &MockLocalDatabaseUpdateSliceCall{Call: call} -} - -// MockLocalDatabaseUpdateSliceCall wrap *gomock.Call -type MockLocalDatabaseUpdateSliceCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseUpdateSliceCall) Return() *MockLocalDatabaseUpdateSliceCall { - c.Call = c.Call.Return() - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseUpdateSliceCall) Do(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockLocalDatabaseUpdateSliceCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseUpdateSliceCall) DoAndReturn(f func(sql.QueryCacheItemKey, sql.SliceAppender)) *MockLocalDatabaseUpdateSliceCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// WithTx mocks base method. -func (m *MockLocalDatabase) WithTx(ctx context.Context, exec func(sql.Transaction) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTx", ctx, exec) - ret0, _ := ret[0].(error) - return ret0 -} - -// WithTx indicates an expected call of WithTx. -func (mr *MockLocalDatabaseMockRecorder) WithTx(ctx, exec any) *MockLocalDatabaseWithTxCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTx", reflect.TypeOf((*MockLocalDatabase)(nil).WithTx), ctx, exec) - return &MockLocalDatabaseWithTxCall{Call: call} -} - -// MockLocalDatabaseWithTxCall wrap *gomock.Call -type MockLocalDatabaseWithTxCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseWithTxCall) Return(arg0 error) *MockLocalDatabaseWithTxCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseWithTxCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockLocalDatabaseWithTxCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseWithTxCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockLocalDatabaseWithTxCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - -// WithTxImmediate mocks base method. -func (m *MockLocalDatabase) WithTxImmediate(ctx context.Context, exec func(sql.Transaction) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "WithTxImmediate", ctx, exec) - ret0, _ := ret[0].(error) - return ret0 -} - -// WithTxImmediate indicates an expected call of WithTxImmediate. -func (mr *MockLocalDatabaseMockRecorder) WithTxImmediate(ctx, exec any) *MockLocalDatabaseWithTxImmediateCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTxImmediate", reflect.TypeOf((*MockLocalDatabase)(nil).WithTxImmediate), ctx, exec) - return &MockLocalDatabaseWithTxImmediateCall{Call: call} -} - -// MockLocalDatabaseWithTxImmediateCall wrap *gomock.Call -type MockLocalDatabaseWithTxImmediateCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockLocalDatabaseWithTxImmediateCall) Return(arg0 error) *MockLocalDatabaseWithTxImmediateCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockLocalDatabaseWithTxImmediateCall) Do(f func(context.Context, func(sql.Transaction) error) error) *MockLocalDatabaseWithTxImmediateCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockLocalDatabaseWithTxImmediateCall) DoAndReturn(f func(context.Context, func(sql.Transaction) error) error) *MockLocalDatabaseWithTxImmediateCall { - c.Call = c.Call.DoAndReturn(f) - return c -} From cbf419542f1f6ca3ba7b661cc1f01a26fbdf6766 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 17 Jun 2024 00:21:35 +0400 Subject: [PATCH 31/62] sql: fix naming --- sql/database.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/database.go b/sql/database.go index 7d7fb8460c..73d7257277 100644 --- a/sql/database.go +++ b/sql/database.go @@ -467,7 +467,7 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in for { row, err := stmt.Step() if err != nil { - return 0, fmt.Errorf("step %d: %w", rows, fixError(err)) + return 0, fmt.Errorf("step %d: %w", rows, mapSqliteError(err)) } if !row { return rows, nil @@ -499,7 +499,7 @@ func (tx *sqliteTx) begin(initstmt string) error { stmt := tx.conn.Prep(initstmt) _, err := stmt.Step() if err != nil { - return fmt.Errorf("begin: %w", fixError(err)) + return fmt.Errorf("begin: %w", mapSqliteError(err)) } return nil } @@ -509,7 +509,7 @@ func (tx *sqliteTx) Commit() error { stmt := tx.conn.Prep("COMMIT;") _, tx.err = stmt.Step() if tx.err != nil { - return fixError(tx.err) + return mapSqliteError(tx.err) } tx.committed = true return nil @@ -523,7 +523,7 @@ func (tx *sqliteTx) Release() error { } stmt := tx.conn.Prep("ROLLBACK") _, tx.err = stmt.Step() - return fixError(tx.err) + return mapSqliteError(tx.err) } // Exec query. @@ -538,7 +538,7 @@ func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, e return exec(tx.conn, query, encoder, decoder) } -func fixError(err error) error { +func mapSqliteError(err error) error { code := sqlite.ErrCode(err) if code == sqlite.SQLITE_CONSTRAINT_PRIMARYKEY || code == sqlite.SQLITE_CONSTRAINT_UNIQUE { return ErrObjectExists From b56dc59d27180e00ba3576fe484c103d780cb371 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 17 Jun 2024 00:23:13 +0400 Subject: [PATCH 32/62] sql: remove unneeded assertions from tests --- sql/localsql/localsql_test.go | 1 - sql/statesql/statesql_test.go | 1 - 2 files changed, 2 deletions(-) diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go index 85bdc4d418..320702d7eb 100644 --- a/sql/localsql/localsql_test.go +++ b/sql/localsql/localsql_test.go @@ -57,7 +57,6 @@ func TestIdempotentMigration(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, err) schema, err := Schema() require.NoError(t, err) expectedVersion := slices.MaxFunc( diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go index f3140e2858..e3814ba5e1 100644 --- a/sql/statesql/statesql_test.go +++ b/sql/statesql/statesql_test.go @@ -57,7 +57,6 @@ func TestIdempotentMigration(t *testing.T) { }) require.NoError(t, err) - require.NoError(t, err) schema, err := Schema() require.NoError(t, err) expectedVersion := slices.MaxFunc( From f1f11fe417a8c4cbd96bee1c5e685ff2bd88f11c Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 17 Jun 2024 00:23:26 +0400 Subject: [PATCH 33/62] sql: schemagen: fix help --- sql/schemagen/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/schemagen/main.go b/sql/schemagen/main.go index db064ce60b..c55d2b3ad4 100644 --- a/sql/schemagen/main.go +++ b/sql/schemagen/main.go @@ -15,7 +15,7 @@ import ( var ( level = zap.LevelFlag("level", zapcore.ErrorLevel, "set log verbosity level") dbType = flag.String("dbtype", "state", "database type (state, local, default state)") - output = flag.String("output", "", "output file (defaults to stdin)") + output = flag.String("output", "", "output file (defaults to stdout)") ) func main() { From c2c7e6f8105181f5a0f5785c7a581d16161d34ba Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 25 Jun 2024 13:44:47 +0400 Subject: [PATCH 34/62] Moved database schema handling docs to CODING.md --- CODING.md | 76 +++++++++++++++++++++++++++++++++++++++++++++++++ CONTRIBUTING.md | 76 ------------------------------------------------- 2 files changed, 76 insertions(+), 76 deletions(-) diff --git a/CODING.md b/CODING.md index fbb458d693..0e373b3746 100644 --- a/CODING.md +++ b/CODING.md @@ -105,3 +105,79 @@ Some useful logging recommendations for cleaner output: ## Commit Messages For commit messages, follow [this guideline](https://www.conventionalcommits.org/en/v1.0.0/). Use reasonable length for the subject and body, ideally no longer than 72 characters. Use the imperative mood for subject lines. + +## Handling database schema changes + +go-spacemesh currently maintains 2 SQLite databases in the data folder: `state.sql` (state database) and `local.sql` (local database). It employs schema versioning for both databases, with a possibility to upgrade older versions of each database to the current schema version by means of running a series of migrations. Also, go-spacemesh tracks any schema drift (unexpected schema changes) in the databases. + +When a database is first created, the corresponding schema file embedded in go-spacemesh executable is used to initialize it: +* `sql/statesql/schema/schema.sql` for `state.sql` +* `sql/localsql/schema/schema.sql` for `local.sql` +The schema file includes `PRAGMA user_version = ...` which sets the version of the database schema. The version of the schema is equal to the number of migrations defined for the corresponding database (`state.sql` or `local.sql`). + +For an existing database, the `PRAGMA user_version` is checked against the expected version number. If the database's schema version is too new, go-spacemesh fails right away as an older go-spacemesh version cannot be expected to work with a database from a newer version. If the database version number is older than the expected version, go-spacemesh runs the necessary migration scripts embedded in go-spacemesh executable and updates `PRAGMA user_version = ...`. The migration scripts are located in the following folders: +* `sql/statesql/schema/migrations` for `state.sql` +* `sql/localsql/schema/migrations` for `local.sql` + +Additionally, some migrations ("coded migrations") can be implemented in Go code, in which case they reside in `.go` files located in `sql/statesql` and `sql/localsql` packages, respectively. It is worth noting that old coded migrations can be removed at some point, rendering database versions that are *too* old unusable with newer go-spacemesh versions. + +After all the migrations are run, go-spacemesh compares the schema of each database to the embedded schema scripts and if they differ, fails with an error message: +``` +Error: open sqlite db schema drift detected (uri file:data/state.sql): + ( + """ + ... // 82 identical lines + PRIMARY KEY (layer, block) + ); ++ CREATE TABLE foo(id int); + CREATE TABLE identities + ( + ... // 66 identical lines + """ + ) +``` + +In this case, a table named `foo` has somehow appeared in the database, causing go-spacemesh to fail due to the schema drift. The possible reasons for schema drift can be the following: +* running an unreleased version of go-spacemesh using your data folder. The unreleased version may contain migrations that may be changed before the release happens +* manual changes in the database +* external SQLite tooling used on the database that adds some tables, indices etc. + +In case if you want to run go-spacemesh with schema drift anyway, you can set `main.db-allow-schema-drift` to true. In this case, a warning with schema diff will be logged instead of failing. + +The schema changes in go-spacemesh code should be always done by means of adding migrations. Let's for example create a new migration (use zero-padded N+1 instead of 0010 with N being the number of the last migration for the local db): + +```console +$ echo 'CREATE TABLE foo(id int);' >sql/localsql/schema/migrations/0010_foo.sql +``` + +After that, we update the schema files +```console +$ make generate +$ # alternative: cd sql/localsql && go generate +$ git diff sql/localsql/schema/schema.sql +diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql +index 02c44d3cc..ebcdf4278 100755 +--- a/sql/localsql/schema/schema.sql ++++ b/sql/localsql/schema/schema.sql +@@ -1,4 +1,4 @@ +-PRAGMA user_version = 9; ++PRAGMA user_version = 10; + CREATE TABLE atx_sync_requests + ( + epoch INT NOT NULL, +@@ -24,6 +24,7 @@ CREATE TABLE "challenge" + post_indices VARCHAR, + post_pow UNSIGNED LONG INT + , poet_proof_ref CHAR(32), poet_proof_membership VARCHAR) WITHOUT ROWID; ++CREATE TABLE foo(id int); + CREATE TABLE malfeasance_sync_state + ( + id INT NOT NULL PRIMARY KEY, +``` + +Note that the changes include both the new table and an updated `PRAGMA user_version` line. +The changes in the schema file must be committed along with the migration we added. +```console +$ git add sql/localsql/schema/migrations/0010_foo.sql sql/localsql/schema.sql +$ git commit -m "sql: add a test migration" +``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index de27fcaa54..d5e4d4f134 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -71,79 +71,3 @@ Please follow these guidelines for your PR to be reviewed and be considered for - You should receive your bounty as soon as your PR is approved and merged by one of our maintainers. Please review our funded issues program [legal notes](https://github.com/spacemeshos/go-spacemesh/blob/master/legal.md). - -# Handling database schema changes - -go-spacemesh currently maintains 2 SQLite databases in the data folder: `state.sql` (state database) and `local.sql` (local database). It employs schema versioning for both databases, with a possibility to upgrade older versions of each database to the current schema version by means of running a series of migrations. Also, go-spacemesh tracks any schema drift (unexpected schema changes) in the databases. - -When a database is first created, the corresponding schema file embedded in go-spacemesh executable is used to initialize it: -* `sql/statesql/schema/schema.sql` for `state.sql` -* `sql/localsql/schema/schema.sql` for `local.sql` -The schema file includes `PRAGMA user_version = ...` which sets the version of the database schema. The version of the schema is equal to the number of migrations defined for the corresponding database (`state.sql` or `local.sql`). - -For an existing database, the `PRAGMA user_version` is checked against the expected version number. If the database's schema version is too new, go-spacemesh fails right away as an older go-spacemesh version cannot be expected to work with a database from a newer version. If the database version number is older than the expected version, go-spacemesh runs the necessary migration scripts embedded in go-spacemesh executable and updates `PRAGMA user_version = ...`. The migration scripts are located in the following folders: -* `sql/statesql/schema/migrations` for `state.sql` -* `sql/localsql/schema/migrations` for `local.sql` - -Additionally, some migrations ("coded migrations") can be implemented in Go code, in which case they reside in `.go` files located in `sql/statesql` and `sql/localsql` packages, respectively. It is worth noting that old coded migrations can be removed at some point, rendering database versions that are *too* old unusable with newer go-spacemesh versions. - -After all the migrations are run, go-spacemesh compares the schema of each database to the embedded schema scripts and if they differ, fails with an error message: -``` -Error: open sqlite db schema drift detected (uri file:data/state.sql): - ( - """ - ... // 82 identical lines - PRIMARY KEY (layer, block) - ); -+ CREATE TABLE foo(id int); - CREATE TABLE identities - ( - ... // 66 identical lines - """ - ) -``` - -In this case, a table named `foo` has somehow appeared in the database, causing go-spacemesh to fail due to the schema drift. The possible reasons for schema drift can be the following: -* running an unreleased version of go-spacemesh using your data folder. The unreleased version may contain migrations that may be changed before the release happens -* manual changes in the database -* external SQLite tooling used on the database that adds some tables, indices etc. - -In case if you want to run go-spacemesh with schema drift anyway, you can set `main.db-allow-schema-drift` to true. In this case, a warning with schema diff will be logged instead of failing. - -The schema changes in go-spacemesh code should be always done by means of adding migrations. Let's for example create a new migration (use zero-padded N+1 instead of 0010 with N being the number of the last migration for the local db): - -```console -$ echo 'CREATE TABLE foo(id int);' >sql/localsql/schema/migrations/0010_foo.sql -``` - -After that, we update the schema files -```console -$ make generate -$ # alternative: cd sql/localsql && go generate -$ git diff sql/localsql/schema/schema.sql -diff --git a/sql/localsql/schema/schema.sql b/sql/localsql/schema/schema.sql -index 02c44d3cc..ebcdf4278 100755 ---- a/sql/localsql/schema/schema.sql -+++ b/sql/localsql/schema/schema.sql -@@ -1,4 +1,4 @@ --PRAGMA user_version = 9; -+PRAGMA user_version = 10; - CREATE TABLE atx_sync_requests - ( - epoch INT NOT NULL, -@@ -24,6 +24,7 @@ CREATE TABLE "challenge" - post_indices VARCHAR, - post_pow UNSIGNED LONG INT - , poet_proof_ref CHAR(32), poet_proof_membership VARCHAR) WITHOUT ROWID; -+CREATE TABLE foo(id int); - CREATE TABLE malfeasance_sync_state - ( - id INT NOT NULL PRIMARY KEY, -``` - -Note that the changes include both the new table and an updated `PRAGMA user_version` line. -The changes in the schema file must be committed along with the migration we added. -```console -$ git add sql/localsql/schema/migrations/0010_foo.sql sql/localsql/schema.sql -$ git commit -m "sql: add a test migration" -``` From 0400d33e84e680a2129b0dd5863fb0f91535d106 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 25 Jun 2024 15:42:46 +0400 Subject: [PATCH 35/62] api: fix database handling in the test --- api/grpcserver/admin_service_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/grpcserver/admin_service_test.go b/api/grpcserver/admin_service_test.go index 4ec02f5479..5c9e7d3d53 100644 --- a/api/grpcserver/admin_service_test.go +++ b/api/grpcserver/admin_service_test.go @@ -133,7 +133,7 @@ func TestAdminService_PeerInfo(t *testing.T) { ctrl := gomock.NewController(t) p := NewMockpeers(ctrl) - db := sql.InMemory() + db := statesql.InMemory() svc := NewAdminService(db, t.TempDir(), p) cfg, cleanup := launchServer(t, svc) From a2679a5a47751766925de48c363fbe321cae431a Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 25 Jun 2024 16:08:35 +0400 Subject: [PATCH 36/62] sql: fix identities test --- sql/identities/identities_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index 9f146e087b..4f5af8c04e 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -124,7 +124,7 @@ func TestMarried(t *testing.T) { t.Parallel() t.Run("identity not in DB", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() married, err := Married(db, id) @@ -140,7 +140,7 @@ func TestMarried(t *testing.T) { }) t.Run("identity in DB", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() // add ID in the DB @@ -162,7 +162,7 @@ func TestEquivocationSet(t *testing.T) { t.Parallel() t.Run("equivocation set of married IDs", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() ids := []types.NodeID{ @@ -185,7 +185,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("equivocation set for unmarried ID contains itself only", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() set, err := EquivocationSet(db, id) require.NoError(t, err) @@ -193,7 +193,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("can't escape the marriage", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() ids := []types.NodeID{ types.RandomNodeID(), @@ -219,7 +219,7 @@ func TestEquivocationSet(t *testing.T) { } }) t.Run("married doesn't become malicious immediately", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() id := types.RandomNodeID() require.NoError(t, SetMarriage(db, id, atx)) @@ -238,7 +238,7 @@ func TestEquivocationSet(t *testing.T) { }) t.Run("all IDs in equivocation set are malicious if one is", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atx := types.RandomATXID() ids := []types.NodeID{ types.RandomNodeID(), From 33b9501fb3bab79625065242be0d11efa32689d5 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 26 Jun 2024 18:37:50 +0400 Subject: [PATCH 37/62] sql: simplify StateDatabase and LocalDatabase interfaces --- sql/database.go | 4 ++-- sql/localsql/localsql.go | 2 +- sql/statesql/statesql.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/database.go b/sql/database.go index 73d7257277..1a75ab86b8 100644 --- a/sql/database.go +++ b/sql/database.go @@ -640,11 +640,11 @@ func IsNull(stmt *Statement, col int) bool { // StateDatabase is a Database used for Spacemesh state. type StateDatabase interface { Database - IsStateDatabase() bool + IsStateDatabase() } // LocalDatabase is a Database used for local node data. type LocalDatabase interface { Database - IsLocalDatabase() bool + IsLocalDatabase() } diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go index a257c6b19e..0b79bc546f 100644 --- a/sql/localsql/localsql.go +++ b/sql/localsql/localsql.go @@ -21,7 +21,7 @@ type database struct { var _ sql.LocalDatabase = &database{} -func (d *database) IsLocalDatabase() bool { return true } +func (d *database) IsLocalDatabase() {} // Schema returns the schema for the local database. func Schema() (*sql.Schema, error) { diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 7dad4a4b48..9e6d997eed 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -21,7 +21,7 @@ type database struct { var _ sql.StateDatabase = &database{} -func (db *database) IsStateDatabase() bool { return true } +func (db *database) IsStateDatabase() {} // Schema returns the schema for the state database. func Schema() (*sql.Schema, error) { From ff35334d05bd338a126236a4381bc72a511641c7 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 26 Jun 2024 18:38:13 +0400 Subject: [PATCH 38/62] node: make it possibe to allow localsql schema drift --- node/node.go | 1 + 1 file changed, 1 insertion(+) diff --git a/node/node.go b/node/node.go index 978e10f861..7e7fd6b730 100644 --- a/node/node.go +++ b/node/node.go @@ -1987,6 +1987,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { localDB, err := localsql.Open("file:"+filepath.Join(dbPath, localDbFile), sql.WithLogger(dbLog.Zap()), sql.WithConnections(app.Config.DatabaseConnections), + sql.WithAllowSchemaDrift(app.Config.DatabaseSchemaAllowDrift), ) if err != nil { return fmt.Errorf("open sqlite db %w", err) From 3d38b8fd884ee4817ef7c56bde1a388bb3e114be Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 26 Jun 2024 18:34:33 +0400 Subject: [PATCH 39/62] sql: add copy-based migrations and VACUUM INTO When vacuuming is required, copy the source database using VACUUM INTO to a temporary database, perform migrations on the temporary database with journal_mode=OFF and synchronous=OFF, and then copy back the migrated temporary database over the original one using VACUUM INTO. Fixes #6069 --- checkpoint/recovery.go | 2 +- node/node.go | 4 +- sql/database.go | 513 ++++++++++++++++++++++++++++++---- sql/database_test.go | 289 ++++++++++++++++++- sql/localsql/localsql_test.go | 2 +- sql/migrations.go | 4 - sql/schema.go | 88 +++++- sql/statesql/statesql_test.go | 2 +- 8 files changed, 823 insertions(+), 81 deletions(-) diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 3c5d3f6d88..e10ec32cb3 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -247,7 +247,7 @@ func recoverFromLocalFile( newDB, err := statesql.Open("file:" + filepath.Join(cfg.DataDir, cfg.DbFile)) if err != nil { - return nil, fmt.Errorf("open sqlite db %w", err) + return nil, fmt.Errorf("open sqlite db: %w", err) } defer newDB.Close() logger.Info("populating new database", diff --git a/node/node.go b/node/node.go index 978e10f861..3e4b8971b9 100644 --- a/node/node.go +++ b/node/node.go @@ -1944,7 +1944,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { } sqlDB, err := statesql.Open("file:"+filepath.Join(dbPath, dbFile), dbopts...) if err != nil { - return fmt.Errorf("open sqlite db %w", err) + return fmt.Errorf("open sqlite db: %w", err) } app.db = sqlDB if app.Config.CollectMetrics && app.Config.DatabaseSizeMeteringInterval != 0 { @@ -1989,7 +1989,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { sql.WithConnections(app.Config.DatabaseConnections), ) if err != nil { - return fmt.Errorf("open sqlite db %w", err) + return fmt.Errorf("open sqlite db: %w", err) } app.localDB = localDB return nil diff --git a/sql/database.go b/sql/database.go index 73d7257277..964d07176c 100644 --- a/sql/database.go +++ b/sql/database.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "maps" + "net/url" + "os" "strings" "sync" "sync/atomic" @@ -20,6 +22,8 @@ import ( ) var ( + // ErrClosed is returned if database is closed. + ErrClosed = errors.New("database closed") // ErrNoConnection is returned if pooled connection is not available. ErrNoConnection = errors.New("database: no free connection") // ErrNotFound is returned if requested record is not found. @@ -63,26 +67,30 @@ type Decoder func(*Statement) bool func defaultConf() *conf { return &conf{ - enableMigrations: true, - connections: 16, - logger: zap.NewNop(), - schema: &Schema{}, + enableMigrations: true, + connections: 16, + logger: zap.NewNop(), + schema: &Schema{}, + handleIncompleteMigrations: true, } } type conf struct { - enableMigrations bool - forceFresh bool - forceMigrations bool - connections int - vacuumState int - enableLatency bool - cache bool - cacheSizes map[QueryCacheKind]int - logger *zap.Logger - schema *Schema - allowSchemaDrift bool - ignoreSchemaDrift bool + uri string + enableMigrations bool + forceFresh bool + forceMigrations bool + connections int + vacuumState int + enableLatency bool + cache bool + cacheSizes map[QueryCacheKind]int + logger *zap.Logger + schema *Schema + allowSchemaDrift bool + ignoreSchemaDrift bool + temp bool + handleIncompleteMigrations bool } // WithConnections overwrites number of pooled connections. @@ -176,6 +184,21 @@ func withForceFresh() Opt { } } +// WithTemp specifies temporary database mode. +// For the temporary database, the migrations are always run in place, and vacuuming is +// nover done. PRAGMA journal_mode=OFF and PRAGMA synchronous=OFF are used. +func WithTemp() Opt { + return func(c *conf) { + c.temp = true + } +} + +func withDisableIncompleteMigrationHandling() Opt { + return func(c *conf) { + c.handleIncompleteMigrations = false + } +} + // Opt for configuring database. type Opt func(c *conf) @@ -202,64 +225,79 @@ func InMemory(opts ...Opt) *sqliteDatabase { // https://www.sqlite.org/pragma.html#pragma_synchronous func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { config := defaultConf() + config.uri = uri for _, opt := range opts { opt(config) } - logger := config.logger.With(zap.String("uri", uri)) + if !config.temp && config.handleIncompleteMigrations { + if err := handleIncompleteCopyMigration(config); err != nil { + return nil, err + } + } + return openDB(config) +} + +func openDB(config *conf) (*sqliteDatabase, error) { + logger := config.logger.With(zap.String("uri", config.uri)) var flags sqlite.OpenFlags if !config.forceFresh { flags = sqlite.SQLITE_OPEN_READWRITE | - sqlite.SQLITE_OPEN_WAL | sqlite.SQLITE_OPEN_URI | sqlite.SQLITE_OPEN_NOMUTEX + if !config.temp { + // Note that SQLITE_OPEN_WAL is not handled by SQLITE api itself, + // but rather by the crawshaw library which executes + // PRAGMA journal_mode=WAL in this case. + // We don't want it for temporary databases as they're not + // using any journal + flags |= sqlite.SQLITE_OPEN_WAL + } } freshDB := config.forceFresh - pool, err := sqlitex.Open(uri, flags, config.connections) + pool, err := sqlitex.Open(config.uri, flags, config.connections) if err != nil { if config.forceFresh || sqlite.ErrCode(err) != sqlite.SQLITE_CANTOPEN { - return nil, fmt.Errorf("open db %s: %w", uri, err) + return nil, fmt.Errorf("open db %s: %w", config.uri, err) } flags |= sqlite.SQLITE_OPEN_CREATE freshDB = true - pool, err = sqlitex.Open(uri, flags, config.connections) + pool, err = sqlitex.Open(config.uri, flags, config.connections) if err != nil { - return nil, fmt.Errorf("create db %s: %w", uri, err) + return nil, fmt.Errorf("create db %s: %w", config.uri, err) } } db := &sqliteDatabase{pool: pool} if config.enableLatency { db.latency = newQueryLatency() } + + if config.temp { + // Temporary database is used for migration and is deleted if migrations + // fail, so we make it faster by disabling journaling and synchronous + // writes. + if _, err := db.Exec("PRAGMA journal_mode=OFF", nil, nil); err != nil { + return nil, errors.Join( + fmt.Errorf("PRAGMA journal_mode=OFF: %w", err), + db.Close()) + } + if _, err := db.Exec("PRAGMA synchronous=OFF", nil, nil); err != nil { + return nil, errors.Join( + fmt.Errorf("PRAGMA journal_mode=OFF: %w", err), + db.Close()) + } + } + if freshDB && !config.forceMigrations { if err := config.schema.Apply(db); err != nil { return nil, errors.Join( fmt.Errorf("error running schema script: %w", err), db.Close()) } - } else { - before, after, err := config.schema.CheckDBVersion(logger, db) - switch { - case err != nil: - return nil, errors.Join(err, db.Close()) - case before != after && config.enableMigrations: - logger.Info("running migrations", - zap.Int("current version", before), - zap.Int("target version", after), - ) - if err := config.schema.Migrate( - logger, db, before, config.vacuumState, - ); err != nil { - return nil, errors.Join(err, db.Close()) - } - case before != after: - logger.Error("database version is too old", - zap.Int("current version", before), - zap.Int("target version", after), - ) - return nil, errors.Join( - fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after), - db.Close()) + } else if db, err = ensureDBSchemaUpToDate(logger, db, config); err != nil { + if db != nil { + err = errors.Join(err, db.Close()) } + return nil, err } if !config.ignoreSchemaDrift { @@ -274,12 +312,12 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { case diff == "": // ok case config.allowSchemaDrift: logger.Warn("database schema drift detected", - zap.String("uri", uri), + zap.String("uri", config.uri), zap.String("diff", diff), ) default: return nil, errors.Join( - fmt.Errorf("schema drift detected (uri %s):\n%s", uri, diff), + fmt.Errorf("schema drift detected (uri %s):\n%s", config.uri, diff), db.Close()) } } @@ -292,16 +330,194 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { return db, nil } +func ensureDBSchemaUpToDate(logger *zap.Logger, db *sqliteDatabase, config *conf) (*sqliteDatabase, error) { + before, after, err := config.schema.CheckDBVersion(logger, db) + switch { + case err != nil: + return db, err + case before == after: + return db, nil + case before > after: + // TODO: this should be logged by the caller + logger.Error("database version is newer than expected - downgrade is not supported", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return db, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) + case !config.enableMigrations: + // TODO: this should be logged by the caller + logger.Error("database version is too old", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return db, fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after) + case config.temp: + // Temporary database, do migrations without transactions + // and sync afterwards + return db, config.schema.MigrateTempDB(logger, db, before) + case config.vacuumState != 0 && + before <= config.vacuumState && + strings.HasPrefix(config.uri, "file:"): + logger.Info("running migrations", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return db.copyMigrateDB(config) + default: + // Do not produce extra "running migrations" log message for the + // temporary DB, as it was already logged + if !config.temp { + logger.Info("running migrations in-place", + zap.Int("current version", before), + zap.Int("target version", after), + ) + } else { + logger.Info("applying migrations to temporary DB", + zap.Int("current version", before), + zap.Int("target version", after), + ) + } + return db, config.schema.Migrate(logger, db, before, config.vacuumState) + } +} + func Version(uri string) (int, error) { pool, err := sqlitex.Open(uri, sqlite.SQLITE_OPEN_READONLY, 1) if err != nil { return 0, fmt.Errorf("open db %s: %w", uri, err) } db := &sqliteDatabase{pool: pool} - defer db.Close() - return version(db) + v, err := version(db) + return v, errors.Join(err, db.Close()) +} + +// deleteDB deletes the database at the specified path by removing /path/to/DB* files. +// If the database doesn't exist, no error is returned. +// In addition to what DROP DATABASE does, this also removes the migration marker file. +func deleteDB(path string) error { + // https://www.sqlite.org/tempfiles.html plus marker *_done + for _, suffix := range []string{"", "-journal", "-wal", "-shm", "_done"} { + file := path + suffix + if err := os.Remove(file); err != nil { + if errors.Is(err, os.ErrNotExist) { + continue + } + return fmt.Errorf("remove %s: %w", file, err) + } + } + return nil +} + +// moveMigratedDB runs "VACUUM INTO" on the database at fromPath and +// replaces the database at toPath with the vacuumed one. The database +// at fromPath is deleted after the operation. +func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { + config.logger.Warn("finalizing migration by moving the migrated DB to the original path", + zap.String("fromPath", fromPath), + zap.String("toPath", toPath)) + // Try to open the migrated DB before deleting the original one. + // If the migrated DB is being copied to the original path by another + // process, this will fail and the original database will not be deleted. + // We don't use the proper database schema here because the migrated DB + // may have been created with a different set of migrations. + db, err := Open("file:"+fromPath, + WithLogger(config.logger), + WithConnections(1), + WithTemp(), + WithIgnoreSchemaDrift()) + if err != nil { + return fmt.Errorf("open migrated DB %s: %w", fromPath, err) + } + if err := deleteDB(toPath); err != nil { + return err + } + if err := db.vacuumInto(toPath); err != nil { + return errors.Join(err, db.Close()) + } + // Open the vacuumed DB to avoid race condition when another process + // also tries to vacuum the migrated DB into the original path after + // we close the migrated DB. + origDB, err := Open("file:"+toPath, + WithLogger(config.logger), + WithConnections(1), + WithMigrationsDisabled(), + WithIgnoreSchemaDrift(), + withDisableIncompleteMigrationHandling()) + if err != nil { + return fmt.Errorf("open vacuumed DB %s: %w", toPath, err) + } + defer func() { + err = errors.Join(err, origDB.Close()) + }() + if err := db.Close(); err != nil { + return fmt.Errorf("close migrated DB %s: %w", fromPath, err) + } + if err := deleteDB(fromPath); err != nil { + return err + } + if err := origDB.Close(); err != nil { + return fmt.Errorf("close vacuumed DB %s: %w", toPath, err) + } + return nil } +func dbMigrationPaths(uri string) (dbPath, migratedPath string, err error) { + url, err := url.Parse(uri) + if err != nil { + return "", "", fmt.Errorf("parse uri: %w", err) + } + if url.Scheme != "file" { + return "", "", nil + } + path := url.Opaque + if path == "" { + path = url.Path + } + return path, path + "_migrate", nil +} + +// handleIncompleteCopyMigration handles incomplete copy-based migrations. +// It only works for 'file:' URIs, doing nothing for other URIs. +// It first checks if there's a copy of the database with "_migrate" suffix. +// If it's there, it checks if the migration is complete by checking if +// DBNAME_migrate_done file exists. It it doesn't, the migration is considered +// incomplete and the migrated database is removed. If DBNAME_migrate_done +// file exists, the migration is finalized by running "VACUUM INTO" on the +// migrated database and replacing the original, after which the migrated +// database is deleted. +func handleIncompleteCopyMigration(config *conf) error { + dbPath, migratedPath, err := dbMigrationPaths(config.uri) + if err != nil { + return err + } + if migratedPath == "" { + return nil + } + if _, err := os.Stat(migratedPath); err != nil { + if errors.Is(err, os.ErrNotExist) { + // no migration in progress + return nil + } + return fmt.Errorf("stat %s: %w", migratedPath, err) + } + if _, err := os.Stat(migratedPath + "_done"); err != nil { + if errors.Is(err, os.ErrNotExist) { + // incomplete migration, delete the migrated DB to start over + // after that + config.logger.Warn("incomplete migration detected, deleting temporary migrated DB", + zap.String("path", migratedPath)) + return deleteDB(migratedPath) + } + } + + // the migration is complete except for the last step + return moveMigratedDB(config, migratedPath, dbPath) +} + +// Interceptor is invoked on every query after it's added to a database using +// PushIntercept. The query will fail if Interceptor returns an error. +type Interceptor func(query string) error + // Database represents a database. type Database interface { Executor @@ -313,6 +529,8 @@ type Database interface { WithTx(ctx context.Context, exec func(Transaction) error) error TxImmediate(ctx context.Context) (Transaction, error) WithTxImmediate(ctx context.Context, exec func(Transaction) error) error + Intercept(key string, fn Interceptor) + RemoveInterceptor(key string) } // Transaction represents a transaction. @@ -331,6 +549,9 @@ type sqliteDatabase struct { latency *prometheus.HistogramVec queryCount atomic.Int64 + + interceptMtx sync.Mutex + interceptors map[string]Interceptor } var _ Database = &sqliteDatabase{} @@ -345,6 +566,9 @@ func (db *sqliteDatabase) getConn(ctx context.Context) *sqlite.Conn { } func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx, error) { + if db.closed { + return nil, ErrClosed + } conn := db.getConn(ctx) if conn == nil { return nil, ErrNoConnection @@ -356,12 +580,14 @@ func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx return tx, nil } -func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func(Transaction) error) error { +func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func(Transaction) error) (err error) { tx, err := db.getTx(ctx, initstmt) if err != nil { return err } - defer tx.Release() + defer func() { + err = errors.Join(err, tx.Release()) + }() if err := exec(tx); err != nil { tx.queryCache.ClearCache() return err @@ -404,6 +630,17 @@ func (db *sqliteDatabase) WithTxImmediate( return db.withTx(ctx, beginImmediate, exec) } +func (db *sqliteDatabase) runInterceptors(query string) error { + db.interceptMtx.Lock() + defer db.interceptMtx.Unlock() + for _, interceptFn := range db.interceptors { + if err := interceptFn(query); err != nil { + return err + } + } + return nil +} + // Exec statement using one of the connection from the pool. // // If you care about atomicity of the operation (for example writing rewards to multiple accounts) @@ -413,6 +650,13 @@ func (db *sqliteDatabase) WithTxImmediate( // Note that Exec will block until database is closed or statement has finished. // If application needs to control statement execution lifetime use one of the transaction. func (db *sqliteDatabase) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { + if err := db.runInterceptors(query); err != nil { + return 0, err + } + + if db.closed { + return 0, ErrClosed + } db.queryCount.Add(1) conn := db.getConn(context.Background()) if conn == nil { @@ -442,6 +686,163 @@ func (db *sqliteDatabase) Close() error { return nil } +// Intercept adds an interceptor function to the database. The interceptor functions +// are invoked upon each query. The query will fail if the interceptor returns an error. +// The interceptor can later be removed using RemoveInterceptor with the same key. +func (db *sqliteDatabase) Intercept(key string, fn Interceptor) { + db.interceptMtx.Lock() + defer db.interceptMtx.Unlock() + if db.interceptors == nil { + db.interceptors = make(map[string]Interceptor) + } + db.interceptors[key] = fn +} + +// PopIntercept removes the interceptor function with specified key from the database. +// If there's no such interceptor, the function does nothing. +func (db *sqliteDatabase) RemoveInterceptor(key string) { + db.interceptMtx.Lock() + defer db.interceptMtx.Unlock() + delete(db.interceptors, key) +} + +// vacuumInto runs VACUUM INTO on the database and saves the vacuumed +// database at toPath +func (db *sqliteDatabase) vacuumInto(toPath string) error { + if _, err := db.Exec("VACUUM INTO ?1", func(stmt *Statement) { + stmt.BindText(1, toPath) + }, nil); err != nil { + return fmt.Errorf("vacuum into %s: %w", toPath, err) + } + return nil +} + +// copyMigrateDB performs a copy-based migration of the database. +// The source database is always closed by this function. +// Upon success, the migrated database is opened. +func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, err error) { + defer func() { + err = errors.Join(err, db.Close()) + }() + + dbPath, migratedPath, err := dbMigrationPaths(config.uri) + if err != nil { + return nil, err + } + if migratedPath == "" { + return nil, fmt.Errorf("cannot migrate database, only file DBs are supported: %s", config.uri) + } + + // Instead of just copying the source database to the temporary migration DB, use VACUUM INTO. + // This is somewhat slower but achieves two goals: + // 1. The lock is held on the source database while it's being copied + // 2. If the source database has a lot of free pages for whatever reason, those + // are not copied, saving disk space + config.logger.Info("making a temporary copy of the database", + zap.String("path", dbPath), + zap.String("target", migratedPath)) + if err := db.vacuumInto(migratedPath); err != nil { + return nil, errors.Join(err, deleteDB(migratedPath)) + } + + // Opening the migrated DB runs the actual migrations on it. + // We disable vacuuming here because we're going to vacuum the migrated DB + // into the original one. + opts := []Opt{ + WithLogger(config.logger), + WithConnections(1), + WithTemp(), + WithDatabaseSchema(config.schema), + } + if config.ignoreSchemaDrift { + opts = append(opts, WithIgnoreSchemaDrift()) + } + migratedDB, err := Open("file:"+migratedPath, opts...) + if err != nil { + return nil, errors.Join( + fmt.Errorf("process migrated DB %s: %w", migratedPath, err), + deleteDB(migratedPath)) + } + tempDBReady := false + defer func() { + err = errors.Join(err, migratedDB.Close()) + if !tempDBReady { + err = errors.Join(err, deleteDB(migratedPath)) + } + }() + + // Make sure the migrated DB is fully synced to the disk before creating the marker file. + // We don't need wal_checkpoint(TRUNCATE) here as we're going to delete the migrated DB. + if _, err := migratedDB.Exec("PRAGMA wal_checkpoint(FULL)", nil, nil); err != nil { + return nil, fmt.Errorf("checkpoint migrated DB %s: %w", migratedPath, err) + } + + // Create the marker file to indicate that the migration is complete. + // Make sure the file is written to the disk before closing the database. + // We could create a table in the temporary database instead of the marker file, + // but as the temporary database is opened without PRAGMA journal_mode=OFF + // and PRAGMA synchronous=OFF, it may become corrupt in case of a crash or power + // outage, so we avoid trying to open it. + markerPath := migratedPath + "_done" + if f, err := os.Create(markerPath); err != nil { + return nil, fmt.Errorf("create marker file %s_done: %w", migratedPath, err) + } else { + if err := f.Sync(); err != nil { + return nil, fmt.Errorf("sync/close marker file %s_done: %w", migratedPath, + errors.Join(err, f.Close())) + } + if err := f.Close(); err != nil { + return nil, fmt.Errorf("close marker file %s: %w", markerPath, err) + } + // The temporary database is complete and should not be deleted + // until we copy it to the original database location. + tempDBReady = true + } + + // We only close the source database at the end of the migration process + // so that the lock is held. There's a possibility that right after we + // close the source database, another process will see the migrated database + // and the marker file and will try to open the migrated database. If the + if err := db.Close(); err != nil { + return nil, fmt.Errorf("close db: %w", err) + } + + // Delete the original database. VACUUM INTO will fail if the destination + // database exists. + if err := deleteDB(dbPath); err != nil { + return nil, fmt.Errorf("delete original DB %s: %w", dbPath, err) + } + + // Overwrite the original database with the migrated one. + // The lock is held on the migrated DB during this, preventing concurrent + // go-spacemesh instances to attempt the same operation. + config.logger.Info("moving migrated DB to original location", zap.String("path", dbPath)) + if err := migratedDB.vacuumInto(dbPath); err != nil { + return nil, err + } + + // Open the final DB before deleting the source DB, so one of the locks + // is always held. The migrations are already run, so we're disabling them. + config.enableMigrations = false + finalDB, err = openDB(config) + if err != nil { + return nil, fmt.Errorf("open final DB %s: %w", config.uri, err) + } + + if err := migratedDB.Close(); err != nil { + finalDB.Close() + return nil, fmt.Errorf("close migrated DB %s: %w", migratedPath, err) + } + + // Now we can delete the migrated DB and the marker file. + if err := deleteDB(migratedPath); err != nil { + finalDB.Close() + return nil, err + } + + return finalDB, err +} + // QueryCount returns the number of queries executed, including failed // queries, but not counting transaction start / commit / rollback. func (db *sqliteDatabase) QueryCount() int { @@ -453,7 +854,7 @@ func (db *sqliteDatabase) QueryCache() QueryCache { return db.queryCache } -func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (int, error) { +func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (nRows int, err error) { stmt, err := conn.Prepare(query) if err != nil { return 0, fmt.Errorf("prepare %s: %w", query, err) @@ -461,7 +862,9 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in if encoder != nil { encoder(stmt) } - defer stmt.ClearBindings() + defer func() { + err = errors.Join(err, stmt.ClearBindings()) + }() rows := 0 for { @@ -528,6 +931,10 @@ func (tx *sqliteTx) Release() error { // Exec query. func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { + if err := tx.db.runInterceptors(query); err != nil { + return 0, err + } + tx.db.queryCount.Add(1) if tx.db.latency != nil { start := time.Now() diff --git a/sql/database_test.go b/sql/database_test.go index 3b899ff779..27aeb390ec 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -5,6 +5,7 @@ import ( "errors" "os" "path/filepath" + "strings" "testing" "github.com/stretchr/testify/require" @@ -109,6 +110,7 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1, migration2}, }), + WithIgnoreSchemaDrift(), ) require.ErrorContains(t, err, "migration 2 failed") } @@ -167,47 +169,321 @@ func TestDatabaseSkipMigrations(t *testing.T) { require.NoError(t, db.Close()) } +func execSQL(t *testing.T, db Executor, sql string, col int) (result string) { + _, err := db.Exec(sql, nil, func(stmt *Statement) bool { + if col >= 0 { + result = stmt.ColumnText(col) + } + return true + }) + require.NoError(t, err) + return result +} + func TestDatabaseVacuumState(t *testing.T) { dir := t.TempDir() logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) + + // The first migration is done without vacuuming and thus it is performed + // in-place. migration1 := NewMockMigration(ctrl) migration1.EXPECT().Order().Return(1).AnyTimes() - migration1.EXPECT().Apply(gomock.Any()).Return(nil).Times(1) + migration1.EXPECT().Apply(gomock.Any()).DoAndReturn(func(db Executor) error { + require.NotContains(t, execSQL(t, db, "PRAGMA database_list", 2), "_migrate") + require.Equal(t, execSQL(t, db, "PRAGMA journal_mode", 0), "wal") + require.Equal(t, execSQL(t, db, "PRAGMA synchronous", 0), "1") // NORMAL + execSQL(t, db, "create table foo(x int)", -1) + return nil + }).Times(1) migration2 := NewMockMigration(ctrl) migration2.EXPECT().Order().Return(2).AnyTimes() - migration2.EXPECT().Apply(gomock.Any()).Return(nil).Times(1) + migration2.EXPECT().Apply(gomock.Any()).DoAndReturn(func(db Executor) error { + // We must be operating on a temp database. + require.Contains(t, execSQL(t, db, "PRAGMA database_list", 2), "_migrate") + // Journaling is off for the temp database as it is deleted in case + // of migration failure. + require.Equal(t, execSQL(t, db, "PRAGMA journal_mode", 0), "off") + // Synchronous is off for the temp database as it is deleted in case + // of migration failure. + require.Equal(t, execSQL(t, db, "PRAGMA synchronous", 0), "0") // OFF + execSQL(t, db, "create table bar(y int)", -1) + return nil + }).Times(1) dbFile := filepath.Join(dir, "test.sql") db, err := Open("file:"+dbFile, WithLogger(logger), WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 1;\n" + + "CREATE TABLE foo(x int);\n", Migrations: MigrationList{migration1}, }), WithForceMigrations(true), - WithIgnoreSchemaDrift(), ) require.NoError(t, err) + execSQL(t, db, "select * from foo", -1) // ensure table exists require.NoError(t, db.Close()) db, err = Open("file:"+dbFile, WithLogger(logger), WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 2;\n" + + "CREATE TABLE bar(y int);\n" + + "CREATE TABLE foo(x int);\n", Migrations: MigrationList{migration1, migration2}, }), WithVacuumState(2), - WithIgnoreSchemaDrift(), ) require.NoError(t, err) + execSQL(t, db, "select * from foo, bar", -1) require.NoError(t, db.Close()) - // we run pragma wal_checkpoint(TRUNCATE) after vacuum, which drops the wal file + // The wal file should be absent after the database is re-created + // with VACUUM INTO _, err = os.Open(dbFile + "-wal") require.ErrorIs(t, err, os.ErrNotExist) } +func TestDatabaseVacuumStateError(t *testing.T) { + dir := t.TempDir() + logger := zaptest.NewLogger(t) + + ctrl := gomock.NewController(t) + + migration1 := &sqlMigration{ + order: 1, + name: "0001_initial.sql", + content: "create table foo(x int)", + } + + fail := true + migration2 := NewMockMigration(ctrl) + migration2.EXPECT().Name().Return("0002_test.sql").AnyTimes() + migration2.EXPECT().Order().Return(2).AnyTimes() + migration2.EXPECT().Apply(gomock.Any()).DoAndReturn(func(db Executor) error { + if fail { + return errors.New("migration failed") + } + execSQL(t, db, "create table bar(y int)", -1) + return nil + }).Times(2) + + dbFile := filepath.Join(dir, "test.sql") + db, err := Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 1;\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1}, + }), + ) + require.NoError(t, err) + execSQL(t, db, "select * from foo", -1) // ensure table exists + require.NoError(t, db.Close()) + + schema := &Schema{ + Script: "PRAGMA user_version = 2;\n" + + "CREATE TABLE bar(y int);\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1, migration2}, + } + db, err = Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + require.Error(t, err) + + // All temporary files need to be deleted upon migration failure. + tmpDBFiles, err := filepath.Glob(filepath.Join(dir, "*_migrate*")) + require.NoError(t, err) + require.Empty(t, tmpDBFiles) + + // Make sure the initial DB is intact after failed migration, + // and the 2nd migration is applied on the second attempt. + fail = false + db, err = Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + require.NoError(t, err) + execSQL(t, db, "select * from foo, bar", -1) + require.NoError(t, db.Close()) +} + +// faultyMigration is a migration that can be configured to panic during Apply. +// We don't use mock for this as it's not entirely clear what happens if a mocked method +// panics. +type faultyMigration struct { + panic, interceptVacuumInto bool + *sqlMigration +} + +func (m *faultyMigration) Apply(db Executor) error { + if m.interceptVacuumInto { + db.(Database).Intercept("crashOnVacuum", func(query string) error { + if strings.Contains(strings.ToLower(query), "vacuum into") { + panic("simulated crash") + } + return nil + }) + } + if m.panic { + panic("simulated crash") + } + return m.sqlMigration.Apply(db) +} + +func TestDropIncompleteMigration(t *testing.T) { + dir := t.TempDir() + logger := zaptest.NewLogger(t) + migration1 := &sqlMigration{ + order: 1, + name: "0001_initial.sql", + content: "create table foo(x int)", + } + migration2 := &faultyMigration{ + panic: true, + sqlMigration: &sqlMigration{ + order: 2, + name: "0002_test.sql", + content: "create table bar(y int)", + }, + } + + dbFile := filepath.Join(dir, "test.sql") + db, err := Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 1;\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + ) + require.NoError(t, err) + require.NoError(t, db.Close()) + + schema := &Schema{ + Script: "PRAGMA user_version = 2;\n" + + "CREATE TABLE bar(y int);\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1, migration2}, + } + + func() { + defer func() { + require.NotNil(t, recover()) + }() + Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + }() + + // Check that temporary database exists after the simulated crash. + // Note that we're checking "*_migrate" not "*_migrate*" to avoid matching + // any erroneously created successful migration markers. + tmpDBFiles, err := filepath.Glob(filepath.Join(dir, "*_migrate")) + require.NoError(t, err) + require.NotEmpty(t, tmpDBFiles) + + // Retry migration. The incompletely migrated temporary database should be dropped. + migration2.panic = false + db, err = Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + require.NoError(t, err) + execSQL(t, db, "select * from foo, bar", -1) + require.NoError(t, db.Close()) +} + +func TestResumeCopyMigration(t *testing.T) { + dir := t.TempDir() + logger := zaptest.NewLogger(t) + migration1 := &sqlMigration{ + order: 1, + name: "0001_initial.sql", + content: "create table foo(x int)", + } + // This migration will panic when VACUUM INTO is attempted to copy + // the migrated database to the source database location. + migration2 := &faultyMigration{ + interceptVacuumInto: true, + sqlMigration: &sqlMigration{ + order: 2, + name: "0002_test.sql", + content: "create table bar(y int)", + }, + } + + dbFile := filepath.Join(dir, "test.sql") + db, err := Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(&Schema{ + Script: "PRAGMA user_version = 1;\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1}, + }), + WithForceMigrations(true), + ) + require.NoError(t, err) + require.NoError(t, db.Close()) + + schema := &Schema{ + Script: "PRAGMA user_version = 2;\n" + + "CREATE TABLE bar(y int);\n" + + "CREATE TABLE foo(x int);\n", + Migrations: MigrationList{migration1, migration2}, + } + + func() { + defer func() { + require.NotNil(t, recover()) + }() + Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + }() + + // Check that temporary database exists after the simulated crash. + tmpDBFiles, err := filepath.Glob(filepath.Join(dir, "*")) + t.Logf("tmpDBFiles: %v", tmpDBFiles) + require.NoError(t, err) + require.NotEmpty(t, tmpDBFiles) + + // Retry migration. The migrated database should be copied + // to the source database location without invoking any further + // migrations. As the migration with fault injection is not called, + // the final VACUUM INTO must succeed. + db, err = Open("file:"+dbFile, + WithLogger(logger), + WithDatabaseSchema(schema), + WithVacuumState(2), + ) + require.NoError(t, err) + execSQL(t, db, "select * from foo, bar", -1) + require.NoError(t, db.Close()) +} + +func TestDBClosed(t *testing.T) { + db := InMemory(WithLogger(zaptest.NewLogger(t)), WithIgnoreSchemaDrift()) + require.NoError(t, db.Close()) + _, err := db.Exec("select 1", nil, nil) + require.ErrorIs(t, err, ErrClosed) + err = db.WithTx(context.Background(), func(tx Transaction) error { return nil }) + require.ErrorIs(t, err, ErrClosed) +} + func TestQueryCount(t *testing.T) { db := InMemory(WithLogger(zaptest.NewLogger(t)), WithIgnoreSchemaDrift()) require.Equal(t, 0, db.QueryCount()) @@ -302,3 +578,6 @@ func TestSchemaDrift(t *testing.T) { require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, observedLogs.All()[0].ContextMap()["diff"]) } + +// TBD: test WAL modes for temp DB +// TBD: remove SQLITE_OPEN_WAL from open flags and check journal mode diff --git a/sql/localsql/localsql_test.go b/sql/localsql/localsql_test.go index 320702d7eb..08af28c66d 100644 --- a/sql/localsql/localsql_test.go +++ b/sql/localsql/localsql_test.go @@ -39,7 +39,7 @@ func TestIdempotentMigration(t *testing.T) { require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") l := observedLogs.All()[0] - require.Equal(t, "running migrations", l.Message) + require.Equal(t, "running migrations in-place", l.Message) require.Equal(t, int64(0), l.ContextMap()["current version"]) require.Equal(t, int64(versionA), l.ContextMap()["target version"]) diff --git a/sql/migrations.go b/sql/migrations.go index 9152fcf189..c75acf21df 100644 --- a/sql/migrations.go +++ b/sql/migrations.go @@ -64,10 +64,6 @@ func (m *sqlMigration) Apply(db Executor) error { } } } - // binding values in pragma statement is not allowed - if _, err := db.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.order), nil, nil); err != nil { - return fmt.Errorf("update user_version to %d: %w", m.order, err) - } return nil } diff --git a/sql/schema.go b/sql/schema.go index 04c0d51fbd..b520a1c8c5 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -109,21 +109,26 @@ func (s *Schema) CheckDBVersion(logger *zap.Logger, db Database) (before, after if len(s.Migrations) > 0 { after = s.Migrations.Version() } - if before > after { - logger.Error("database version is newer than expected - downgrade is not supported", - zap.Int("current version", before), - zap.Int("target version", after), - ) - return before, after, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) - } return before, after, nil } +func (s *Schema) setVersion(db Executor, version int) error { + // binding values in pragma statement is not allowed + if _, err := db.Exec(fmt.Sprintf("PRAGMA user_version = %d;", version), nil, nil); err != nil { + return fmt.Errorf("update user_version to %d: %w", version, err) + } + return nil +} + // Migrate performs database migration. In case if migrations are disabled, the database // version is checked but no migrations are run, and if the database is too old and // migrations are disabled, an error is returned. func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState int) error { + if logger.Core().Enabled(zap.DebugLevel) { + db.Intercept("logQueries", logQueryInterceptor(logger)) + defer db.RemoveInterceptor("logQueries") + } for i, m := range s.Migrations { if m.Order() <= before { continue @@ -141,20 +146,17 @@ func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState in return fmt.Errorf("apply %s: %w", m.Name(), err) } } - // version is set intentionally even if actual migration was skipped - if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d;", m.Order()), nil, nil); err != nil { - return fmt.Errorf("update user_version to %d: %w", m.Order(), err) + if err := s.setVersion(tx, m.Order()); err != nil { + return err } return nil }); err != nil { - err = errors.Join(err, db.Close()) - return err + return errors.Join(err, db.Close()) } if vacuumState != 0 && before <= vacuumState { if err := Vacuum(db); err != nil { - err = errors.Join(err, db.Close()) - return err + return errors.Join(err, db.Close()) } } before = m.Order() @@ -162,6 +164,53 @@ func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState in return nil } +// MigrateTempDB performs database migration on the temporary database. +// It doesn't use transactions and the temporary database should be considered +// invalid and discarded if it fails. +// The database is switched into synchronous mode with WAL journal enabled and +// synced after the migrations are completed before setting the database version, +// which triggers file sync. +func (s *Schema) MigrateTempDB(logger *zap.Logger, db Database, before int) error { + if logger.Core().Enabled(zap.DebugLevel) { + db.Intercept("logQueries", logQueryInterceptor(logger)) + defer db.RemoveInterceptor("logQueries") + } + v := before + for _, m := range s.Migrations { + if m.Order() <= v { + continue + } + + if _, ok := s.skipMigration[m.Order()]; !ok { + if err := m.Apply(db); err != nil { + return errors.Join(fmt.Errorf("apply %s: %w", m.Name(), err), db.Close()) + } + } + + // We don't set the version here as if any migration fails, + // the temporary database is considered invalid and should be discarded. + v = m.Order() + } + + logger.Info("syncing temporary database") + + // Enable synchronous mode and WAL journal to ensure the database is synced + if _, err := db.Exec("PRAGMA journal_mode=WAL", nil, nil); err != nil { + return fmt.Errorf("PRAGMA journal_mode=WAL: %w", err) + } + + if _, err := db.Exec("PRAGMA synchronous=FULL", nil, nil); err != nil { + return fmt.Errorf("PRAGMA journal_mode=OFF: %w", err) + } + + // This should trigger file sync + if err := s.setVersion(db, v); err != nil { + return err + } + + return nil +} + // SchemaGenOpt represents a schema generator option. type SchemaGenOpt func(g *SchemaGen) @@ -216,3 +265,14 @@ func (g *SchemaGen) Generate(outputFile string) error { } return nil } + +func logQueryInterceptor(logger *zap.Logger) Interceptor { + return func(query string) error { + query = strings.TrimSpace(query) + if p := strings.Index(query, "\n"); p >= 0 { + query = query[:p] + } + logger.Debug("executing query", zap.String("query", query)) + return nil + } +} diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go index e3814ba5e1..c7a2c3b6d0 100644 --- a/sql/statesql/statesql_test.go +++ b/sql/statesql/statesql_test.go @@ -39,7 +39,7 @@ func TestIdempotentMigration(t *testing.T) { require.Equal(t, 1, observedLogs.Len(), "expected 1 log messages") l := observedLogs.All()[0] - require.Equal(t, "running migrations", l.Message) + require.Equal(t, "running migrations in-place", l.Message) require.Equal(t, int64(0), l.ContextMap()["current version"]) require.Equal(t, int64(versionA), l.ContextMap()["target version"]) From 1df9c30ee550ab7bf58134a9f3b5209e9302b09e Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 28 Jun 2024 09:27:29 +0400 Subject: [PATCH 40/62] sql: fix issues found by linter --- sql/database.go | 2 +- sql/database_test.go | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/database.go b/sql/database.go index 964d07176c..92eebea8f0 100644 --- a/sql/database.go +++ b/sql/database.go @@ -707,7 +707,7 @@ func (db *sqliteDatabase) RemoveInterceptor(key string) { } // vacuumInto runs VACUUM INTO on the database and saves the vacuumed -// database at toPath +// database at toPath. func (db *sqliteDatabase) vacuumInto(toPath string) error { if _, err := db.Exec("VACUUM INTO ?1", func(stmt *Statement) { stmt.BindText(1, toPath) diff --git a/sql/database_test.go b/sql/database_test.go index 27aeb390ec..de3efd2f2a 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -192,8 +192,8 @@ func TestDatabaseVacuumState(t *testing.T) { migration1.EXPECT().Order().Return(1).AnyTimes() migration1.EXPECT().Apply(gomock.Any()).DoAndReturn(func(db Executor) error { require.NotContains(t, execSQL(t, db, "PRAGMA database_list", 2), "_migrate") - require.Equal(t, execSQL(t, db, "PRAGMA journal_mode", 0), "wal") - require.Equal(t, execSQL(t, db, "PRAGMA synchronous", 0), "1") // NORMAL + require.Equal(t, "wal", execSQL(t, db, "PRAGMA journal_mode", 0)) + require.Equal(t, "1", execSQL(t, db, "PRAGMA synchronous", 0)) // NORMAL execSQL(t, db, "create table foo(x int)", -1) return nil }).Times(1) @@ -205,10 +205,10 @@ func TestDatabaseVacuumState(t *testing.T) { require.Contains(t, execSQL(t, db, "PRAGMA database_list", 2), "_migrate") // Journaling is off for the temp database as it is deleted in case // of migration failure. - require.Equal(t, execSQL(t, db, "PRAGMA journal_mode", 0), "off") + require.Equal(t, "off", execSQL(t, db, "PRAGMA journal_mode", 0)) // Synchronous is off for the temp database as it is deleted in case // of migration failure. - require.Equal(t, execSQL(t, db, "PRAGMA synchronous", 0), "0") // OFF + require.Equal(t, "0", execSQL(t, db, "PRAGMA synchronous", 0)) // OFF execSQL(t, db, "create table bar(y int)", -1) return nil }).Times(1) @@ -290,7 +290,7 @@ func TestDatabaseVacuumStateError(t *testing.T) { "CREATE TABLE foo(x int);\n", Migrations: MigrationList{migration1, migration2}, } - db, err = Open("file:"+dbFile, + _, err = Open("file:"+dbFile, WithLogger(logger), WithDatabaseSchema(schema), WithVacuumState(2), From 2c291035d0d4dbb93b3e339c7fb4b22ec578e65c Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 29 Jun 2024 05:03:31 +0400 Subject: [PATCH 41/62] sql: fix review comments --- sql/database.go | 53 +++++++++++++++++++++++--------------------- sql/database_test.go | 14 ++++-------- 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/sql/database.go b/sql/database.go index 92eebea8f0..c982e6bfaa 100644 --- a/sql/database.go +++ b/sql/database.go @@ -294,6 +294,10 @@ func openDB(config *conf) (*sqliteDatabase, error) { db.Close()) } } else if db, err = ensureDBSchemaUpToDate(logger, db, config); err != nil { + // ensureDBSchemaUpToDate may replace the original database and open the new one, + // in which case the original db is already closed but we must close the new one. + // If there are migrations to be done in place without vacuuming, + // the original db is returned and we must close it if there's an error. if db != nil { err = errors.Join(err, db.Close()) } @@ -412,13 +416,13 @@ func deleteDB(path string) error { // replaces the database at toPath with the vacuumed one. The database // at fromPath is deleted after the operation. func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { - config.logger.Warn("finalizing migration by moving the migrated DB to the original path", + config.logger.Warn("finalizing migration by moving the temporary DB to the original path", zap.String("fromPath", fromPath), zap.String("toPath", toPath)) - // Try to open the migrated DB before deleting the original one. - // If the migrated DB is being copied to the original path by another + // Try to open the temporary migrated DB before deleting the original one. + // If the temporary DB is being copied to the original path by another // process, this will fail and the original database will not be deleted. - // We don't use the proper database schema here because the migrated DB + // We don't use the proper database schema here because the temporary DB // may have been created with a different set of migrations. db, err := Open("file:"+fromPath, WithLogger(config.logger), @@ -426,7 +430,7 @@ func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { WithTemp(), WithIgnoreSchemaDrift()) if err != nil { - return fmt.Errorf("open migrated DB %s: %w", fromPath, err) + return fmt.Errorf("open temporary DB %s: %w", fromPath, err) } if err := deleteDB(toPath); err != nil { return err @@ -434,9 +438,9 @@ func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { if err := db.vacuumInto(toPath); err != nil { return errors.Join(err, db.Close()) } - // Open the vacuumed DB to avoid race condition when another process - // also tries to vacuum the migrated DB into the original path after - // we close the migrated DB. + // Open the freshly vacuumed DB to avoid race condition when another process + // also tries to vacuum the temporary DB into the original path after + // we close the temporary DB. origDB, err := Open("file:"+toPath, WithLogger(config.logger), WithConnections(1), @@ -447,17 +451,16 @@ func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { return fmt.Errorf("open vacuumed DB %s: %w", toPath, err) } defer func() { - err = errors.Join(err, origDB.Close()) + if closeErr := origDB.Close(); closeErr != nil { + err = errors.Join(err, fmt.Errorf("close DB %s after migration: %w", toPath, closeErr)) + } }() if err := db.Close(); err != nil { - return fmt.Errorf("close migrated DB %s: %w", fromPath, err) + return fmt.Errorf("close temporary DB %s: %w", fromPath, err) } if err := deleteDB(fromPath); err != nil { return err } - if err := origDB.Close(); err != nil { - return fmt.Errorf("close vacuumed DB %s: %w", toPath, err) - } return nil } @@ -502,9 +505,9 @@ func handleIncompleteCopyMigration(config *conf) error { } if _, err := os.Stat(migratedPath + "_done"); err != nil { if errors.Is(err, os.ErrNotExist) { - // incomplete migration, delete the migrated DB to start over + // incomplete migration, delete the temporary DB to start over // after that - config.logger.Warn("incomplete migration detected, deleting temporary migrated DB", + config.logger.Warn("incomplete migration detected, deleting the temporary DB", zap.String("path", migratedPath)) return deleteDB(migratedPath) } @@ -745,8 +748,8 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, errors.Join(err, deleteDB(migratedPath)) } - // Opening the migrated DB runs the actual migrations on it. - // We disable vacuuming here because we're going to vacuum the migrated DB + // Opening the temporary migrated DB runs the actual migrations on it. + // We disable vacuuming here because we're going to vacuum the temporary DB // into the original one. opts := []Opt{ WithLogger(config.logger), @@ -760,7 +763,7 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, migratedDB, err := Open("file:"+migratedPath, opts...) if err != nil { return nil, errors.Join( - fmt.Errorf("process migrated DB %s: %w", migratedPath, err), + fmt.Errorf("process temporary DB %s: %w", migratedPath, err), deleteDB(migratedPath)) } tempDBReady := false @@ -771,10 +774,10 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, } }() - // Make sure the migrated DB is fully synced to the disk before creating the marker file. - // We don't need wal_checkpoint(TRUNCATE) here as we're going to delete the migrated DB. + // Make sure the temporary DB is fully synced to the disk before creating the marker file. + // We don't need wal_checkpoint(TRUNCATE) here as we're going to delete the temporary DB. if _, err := migratedDB.Exec("PRAGMA wal_checkpoint(FULL)", nil, nil); err != nil { - return nil, fmt.Errorf("checkpoint migrated DB %s: %w", migratedPath, err) + return nil, fmt.Errorf("checkpoint temporary DB %s: %w", migratedPath, err) } // Create the marker file to indicate that the migration is complete. @@ -814,9 +817,9 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, } // Overwrite the original database with the migrated one. - // The lock is held on the migrated DB during this, preventing concurrent + // The lock is held on the temporary DB during this, preventing concurrent // go-spacemesh instances to attempt the same operation. - config.logger.Info("moving migrated DB to original location", zap.String("path", dbPath)) + config.logger.Info("moving the temporary DB to original location", zap.String("path", dbPath)) if err := migratedDB.vacuumInto(dbPath); err != nil { return nil, err } @@ -831,10 +834,10 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, if err := migratedDB.Close(); err != nil { finalDB.Close() - return nil, fmt.Errorf("close migrated DB %s: %w", migratedPath, err) + return nil, fmt.Errorf("close temporary DB %s: %w", migratedPath, err) } - // Now we can delete the migrated DB and the marker file. + // Now we can delete the temporary DB and the marker file. if err := deleteDB(migratedPath); err != nil { finalDB.Close() return nil, err diff --git a/sql/database_test.go b/sql/database_test.go index de3efd2f2a..ac9c8899ef 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -375,16 +375,13 @@ func TestDropIncompleteMigration(t *testing.T) { Migrations: MigrationList{migration1, migration2}, } - func() { - defer func() { - require.NotNil(t, recover()) - }() + require.Panics(t, func() { Open("file:"+dbFile, WithLogger(logger), WithDatabaseSchema(schema), WithVacuumState(2), ) - }() + }) // Check that temporary database exists after the simulated crash. // Note that we're checking "*_migrate" not "*_migrate*" to avoid matching @@ -444,16 +441,13 @@ func TestResumeCopyMigration(t *testing.T) { Migrations: MigrationList{migration1, migration2}, } - func() { - defer func() { - require.NotNil(t, recover()) - }() + require.Panics(t, func() { Open("file:"+dbFile, WithLogger(logger), WithDatabaseSchema(schema), WithVacuumState(2), ) - }() + }) // Check that temporary database exists after the simulated crash. tmpDBFiles, err := filepath.Glob(filepath.Join(dir, "*")) From 5b47e76a6e5b9723f118a456e5d24a2daac367ba Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Mon, 8 Jul 2024 18:05:22 +0400 Subject: [PATCH 42/62] sql: fix failing tests --- sql/atxs/atxs_test.go | 6 +++--- sql/identities/identities_test.go | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 762aa5d794..ebb82b29b8 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1128,13 +1128,13 @@ func TestUnits(t *testing.T) { t.Parallel() t.Run("ATX not found", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() _, err := atxs.Units(db, types.RandomATXID(), types.RandomNodeID()) require.ErrorIs(t, err, sql.ErrNotFound) }) t.Run("smesher has no units in ATX", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atxID := types.RandomATXID() require.NoError(t, atxs.SetUnits(db, atxID, types.RandomNodeID(), 10)) _, err := atxs.Units(db, atxID, types.RandomNodeID()) @@ -1142,7 +1142,7 @@ func TestUnits(t *testing.T) { }) t.Run("returns units for given smesher in given ATX", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() atxID := types.RandomATXID() units := map[types.NodeID]uint32{ {1, 2, 3}: 10, diff --git a/sql/identities/identities_test.go b/sql/identities/identities_test.go index 880a09fdbc..b1b6b30ee9 100644 --- a/sql/identities/identities_test.go +++ b/sql/identities/identities_test.go @@ -161,7 +161,7 @@ func TestMarriageATX(t *testing.T) { t.Parallel() t.Run("not married", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() _, err := MarriageATX(db, id) @@ -169,7 +169,7 @@ func TestMarriageATX(t *testing.T) { }) t.Run("married", func(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() marriage := MarriageData{ @@ -188,7 +188,7 @@ func TestMarriageATX(t *testing.T) { func TestMarriage(t *testing.T) { t.Parallel() - db := sql.InMemory() + db := statesql.InMemory() id := types.RandomNodeID() marriage := MarriageData{ @@ -318,7 +318,7 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { t.Parallel() t.Run("married IDs", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() ids := []types.NodeID{ types.RandomNodeID(), types.RandomNodeID(), @@ -334,7 +334,7 @@ func TestEquivocationSetByMarriageATX(t *testing.T) { require.Equal(t, ids, set) }) t.Run("empty set", func(t *testing.T) { - db := sql.InMemory() + db := statesql.InMemory() set, err := EquivocationSetByMarriageATX(db, types.RandomATXID()) require.NoError(t, err) require.Empty(t, set) From dcab0566d3c61d0c2e50ca12a4ee784ae3488bb0 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 13 Aug 2024 14:51:18 +0400 Subject: [PATCH 43/62] statesql: update schema --- sql/statesql/schema/schema.sql | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/statesql/schema/schema.sql b/sql/statesql/schema/schema.sql index b20bc7b9b0..3eee310600 100755 --- a/sql/statesql/schema/schema.sql +++ b/sql/statesql/schema/schema.sql @@ -25,7 +25,6 @@ CREATE UNIQUE INDEX atx_blobs_id ON atx_blobs (id); CREATE TABLE atxs ( id CHAR(32), - prev_id CHAR(32), epoch INT NOT NULL, effective_num_units INT NOT NULL, commitment_atx CHAR(32), @@ -37,7 +36,7 @@ CREATE TABLE atxs coinbase CHAR(24), received INT NOT NULL, validity INTEGER DEFAULT false -, weight INTEGER); +, marriage_atx CHAR(32), weight INTEGER); CREATE INDEX atxs_by_coinbase ON atxs (coinbase); CREATE INDEX atxs_by_epoch_by_pubkey ON atxs (epoch, pubkey); CREATE INDEX atxs_by_epoch_by_pubkey_nonce ON atxs (pubkey, epoch desc, nonce) WHERE nonce IS NOT NULL; @@ -108,10 +107,12 @@ CREATE INDEX poets_by_service_id_by_round_id ON poets (service_id, round_id); CREATE TABLE posts ( atxid CHAR(32) NOT NULL, pubkey CHAR(32) NOT NULL, + prev_atxid CHAR(32), + prev_atx_index INT, units INT NOT NULL, UNIQUE (atxid, pubkey) ); -CREATE INDEX posts_by_atxid_by_pubkey ON posts (atxid, pubkey); +CREATE INDEX posts_by_atxid_by_pubkey ON posts (atxid, pubkey, prev_atxid); CREATE TABLE proposal_transactions ( tid CHAR(32), From 1d8d3103606c40416720c52e3b525c4b50eb5db9 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 13 Aug 2024 14:53:28 +0400 Subject: [PATCH 44/62] tmp: fix lint errors (will need to revert) --- activation/wire/wire_v2_test.go | 82 ++++++++++++++++----------------- 1 file changed, 40 insertions(+), 42 deletions(-) diff --git a/activation/wire/wire_v2_test.go b/activation/wire/wire_v2_test.go index e0303affb0..9c3cd5f3f6 100644 --- a/activation/wire/wire_v2_test.go +++ b/activation/wire/wire_v2_test.go @@ -1,58 +1,56 @@ package wire import ( - "math/rand/v2" "testing" fuzz "github.com/google/gofuzz" "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" - "github.com/spacemeshos/go-spacemesh/signing" ) -type testAtxV2Opt func(*ActivationTxV2) +// type testAtxV2Opt func(*ActivationTxV2) -func withMarriageCertificate(sig *signing.EdSigner, refAtx types.ATXID, atxPublisher types.NodeID) testAtxV2Opt { - return func(atx *ActivationTxV2) { - certificate := MarriageCertificate{ - ReferenceAtx: refAtx, - Signature: sig.Sign(signing.MARRIAGE, atxPublisher.Bytes()), - } - atx.Marriages = append(atx.Marriages, certificate) - } -} +// func withMarriageCertificate(sig *signing.EdSigner, refAtx types.ATXID, atxPublisher types.NodeID) testAtxV2Opt { +// return func(atx *ActivationTxV2) { +// certificate := MarriageCertificate{ +// ReferenceAtx: refAtx, +// Signature: sig.Sign(signing.MARRIAGE, atxPublisher.Bytes()), +// } +// atx.Marriages = append(atx.Marriages, certificate) +// } +// } -func newActivationTxV2(opts ...testAtxV2Opt) *ActivationTxV2 { - atx := &ActivationTxV2{ - PublishEpoch: rand.N(types.EpochID(255)), - PositioningATX: types.RandomATXID(), - PreviousATXs: make([]types.ATXID, 1+rand.IntN(255)), - NiPosts: []NiPostsV2{ - { - Membership: MerkleProofV2{ - Nodes: make([]types.Hash32, 32), - }, - Challenge: types.RandomHash(), - Posts: []SubPostV2{ - { - MarriageIndex: rand.Uint32N(256), - PrevATXIndex: 0, - Post: PostV1{ - Nonce: 0, - Indices: make([]byte, 800), - Pow: 0, - }, - }, - }, - }, - }, - } - for _, opt := range opts { - opt(atx) - } - return atx -} +// func newActivationTxV2(opts ...testAtxV2Opt) *ActivationTxV2 { +// atx := &ActivationTxV2{ +// PublishEpoch: rand.N(types.EpochID(255)), +// PositioningATX: types.RandomATXID(), +// PreviousATXs: make([]types.ATXID, 1+rand.IntN(255)), +// NiPosts: []NiPostsV2{ +// { +// Membership: MerkleProofV2{ +// Nodes: make([]types.Hash32, 32), +// }, +// Challenge: types.RandomHash(), +// Posts: []SubPostV2{ +// { +// MarriageIndex: rand.Uint32N(256), +// PrevATXIndex: 0, +// Post: PostV1{ +// Nonce: 0, +// Indices: make([]byte, 800), +// Pow: 0, +// }, +// }, +// }, +// }, +// }, +// } +// for _, opt := range opts { +// opt(atx) +// } +// return atx +// } func Benchmark_ATXv2ID(b *testing.B) { f := fuzz.New() From a942cccca5e14730ecb55aa0f0749de6d81c671d Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 15 Aug 2024 22:11:19 +0400 Subject: [PATCH 45/62] Revert "tmp: fix lint errors (will need to revert)" This reverts commit 1d8d3103606c40416720c52e3b525c4b50eb5db9. --- activation/wire/wire_v2_test.go | 82 +++++++++++++++++---------------- 1 file changed, 42 insertions(+), 40 deletions(-) diff --git a/activation/wire/wire_v2_test.go b/activation/wire/wire_v2_test.go index 9c3cd5f3f6..e0303affb0 100644 --- a/activation/wire/wire_v2_test.go +++ b/activation/wire/wire_v2_test.go @@ -1,56 +1,58 @@ package wire import ( + "math/rand/v2" "testing" fuzz "github.com/google/gofuzz" "github.com/stretchr/testify/require" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/signing" ) -// type testAtxV2Opt func(*ActivationTxV2) +type testAtxV2Opt func(*ActivationTxV2) -// func withMarriageCertificate(sig *signing.EdSigner, refAtx types.ATXID, atxPublisher types.NodeID) testAtxV2Opt { -// return func(atx *ActivationTxV2) { -// certificate := MarriageCertificate{ -// ReferenceAtx: refAtx, -// Signature: sig.Sign(signing.MARRIAGE, atxPublisher.Bytes()), -// } -// atx.Marriages = append(atx.Marriages, certificate) -// } -// } +func withMarriageCertificate(sig *signing.EdSigner, refAtx types.ATXID, atxPublisher types.NodeID) testAtxV2Opt { + return func(atx *ActivationTxV2) { + certificate := MarriageCertificate{ + ReferenceAtx: refAtx, + Signature: sig.Sign(signing.MARRIAGE, atxPublisher.Bytes()), + } + atx.Marriages = append(atx.Marriages, certificate) + } +} -// func newActivationTxV2(opts ...testAtxV2Opt) *ActivationTxV2 { -// atx := &ActivationTxV2{ -// PublishEpoch: rand.N(types.EpochID(255)), -// PositioningATX: types.RandomATXID(), -// PreviousATXs: make([]types.ATXID, 1+rand.IntN(255)), -// NiPosts: []NiPostsV2{ -// { -// Membership: MerkleProofV2{ -// Nodes: make([]types.Hash32, 32), -// }, -// Challenge: types.RandomHash(), -// Posts: []SubPostV2{ -// { -// MarriageIndex: rand.Uint32N(256), -// PrevATXIndex: 0, -// Post: PostV1{ -// Nonce: 0, -// Indices: make([]byte, 800), -// Pow: 0, -// }, -// }, -// }, -// }, -// }, -// } -// for _, opt := range opts { -// opt(atx) -// } -// return atx -// } +func newActivationTxV2(opts ...testAtxV2Opt) *ActivationTxV2 { + atx := &ActivationTxV2{ + PublishEpoch: rand.N(types.EpochID(255)), + PositioningATX: types.RandomATXID(), + PreviousATXs: make([]types.ATXID, 1+rand.IntN(255)), + NiPosts: []NiPostsV2{ + { + Membership: MerkleProofV2{ + Nodes: make([]types.Hash32, 32), + }, + Challenge: types.RandomHash(), + Posts: []SubPostV2{ + { + MarriageIndex: rand.Uint32N(256), + PrevATXIndex: 0, + Post: PostV1{ + Nonce: 0, + Indices: make([]byte, 800), + Pow: 0, + }, + }, + }, + }, + }, + } + for _, opt := range opts { + opt(atx) + } + return atx +} func Benchmark_ATXv2ID(b *testing.B) { f := fuzz.New() From fe95d18b3a195c4c8c1833c37239b5450700e4cc Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 15 Aug 2024 23:32:40 +0400 Subject: [PATCH 46/62] Move statesql migrations to a separate package to avoid cyclic deps --- .../wire/malfeasance_double_marry_test.go | 4 ---- checkpoint/recovery.go | 20 +++++++++++++++-- cmd/activeset/activeset.go | 3 ++- cmd/merge-nodes/internal/merge_action.go | 1 + node/node.go | 9 +++++++- sql/localsql/migrations/schema.go | 12 ++++++++++ sql/schemagen/main.go | 8 +++---- sql/statesql/migrations/schema.go | 10 +++++++++ sql/statesql/migrations/schema_test.go | 22 +++++++++++++++++++ .../{ => migrations}/state_0021_migration.go | 2 +- .../state_0021_migration_test.go | 8 +++---- .../schema/migrations/0021_atx_posts.sql | 14 ++++++------ sql/statesql/statesql.go | 6 +++-- sql/statesql/statesql_test.go | 6 ++--- tortoise/sim/utils.go | 2 +- 15 files changed, 96 insertions(+), 31 deletions(-) create mode 100644 sql/localsql/migrations/schema.go create mode 100644 sql/statesql/migrations/schema.go create mode 100644 sql/statesql/migrations/schema_test.go rename sql/statesql/{ => migrations}/state_0021_migration.go (99%) rename sql/statesql/{ => migrations}/state_0021_migration_test.go (93%) diff --git a/activation/wire/malfeasance_double_marry_test.go b/activation/wire/malfeasance_double_marry_test.go index 90bfd49de9..8caf899b9c 100644 --- a/activation/wire/malfeasance_double_marry_test.go +++ b/activation/wire/malfeasance_double_marry_test.go @@ -1,7 +1,3 @@ -//go:build exclude - -// FIXME: tmp circular dep fix - package wire import ( diff --git a/checkpoint/recovery.go b/checkpoint/recovery.go index 652eda2cd4..261ede31e2 100644 --- a/checkpoint/recovery.go +++ b/checkpoint/recovery.go @@ -25,11 +25,13 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxsync" "github.com/spacemeshos/go-spacemesh/sql/identities" "github.com/spacemeshos/go-spacemesh/sql/localsql" + localmigrations "github.com/spacemeshos/go-spacemesh/sql/localsql/migrations" "github.com/spacemeshos/go-spacemesh/sql/localsql/nipost" "github.com/spacemeshos/go-spacemesh/sql/malsync" "github.com/spacemeshos/go-spacemesh/sql/poets" "github.com/spacemeshos/go-spacemesh/sql/recovery" "github.com/spacemeshos/go-spacemesh/sql/statesql" + statemigrations "github.com/spacemeshos/go-spacemesh/sql/statesql/migrations" ) const recoveryDir = "recovery" @@ -121,12 +123,26 @@ func Recover( return nil, errors.New("restore layer not set") } logger.Info("recovering from checkpoint", zap.String("url", cfg.Uri), zap.Stringer("restore", cfg.Restore)) - db, err := statesql.Open("file:" + cfg.DbPath()) + schema, err := statemigrations.SchemaWithInCodeMigrations() + if err != nil { + return nil, fmt.Errorf("error loading db schema: %w", err) + } + db, err := statesql.Open( + "file:"+cfg.DbPath(), + sql.WithDatabaseSchema(schema), + ) if err != nil { return nil, fmt.Errorf("open old database: %w", err) } defer db.Close() - localDB, err := localsql.Open("file:" + filepath.Join(cfg.DataDir, cfg.LocalDbFile)) + lSchema, err := localmigrations.SchemaWithInCodeMigrations() + if err != nil { + return nil, fmt.Errorf("get schema with in-code migrations: %w", err) + } + localDB, err := localsql.Open( + "file:"+filepath.Join(cfg.DataDir, cfg.LocalDbFile), + sql.WithDatabaseSchema(lSchema), + ) if err != nil { return nil, fmt.Errorf("open old local database: %w", err) } diff --git a/cmd/activeset/activeset.go b/cmd/activeset/activeset.go index 1415972e6b..73ed2dbda8 100644 --- a/cmd/activeset/activeset.go +++ b/cmd/activeset/activeset.go @@ -9,6 +9,7 @@ import ( "strconv" "github.com/spacemeshos/go-spacemesh/common/types" + "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/statesql" ) @@ -30,7 +31,7 @@ Example: if len(dbpath) == 0 { must(errors.New("dbpath is empty"), "dbpath is empty\n") } - db, err := statesql.Open("file:" + dbpath) + db, err := statesql.Open("file:"+dbpath, sql.WithMigrationsDisabled()) must(err, "can't open db at dbpath=%v. err=%s\n", dbpath, err) ids, err := atxs.GetIDsByEpoch(context.Background(), db, types.EpochID(publish)) diff --git a/cmd/merge-nodes/internal/merge_action.go b/cmd/merge-nodes/internal/merge_action.go index 426ad82f2c..caa78b830b 100644 --- a/cmd/merge-nodes/internal/merge_action.go +++ b/cmd/merge-nodes/internal/merge_action.go @@ -42,6 +42,7 @@ func MergeDBs(ctx context.Context, dbLog *zap.Logger, from, to string) error { dstDB, err = localsql.Open("file:"+filepath.Join(to, localDbFile), sql.WithLogger(dbLog), + sql.WithMigrationsDisabled(), ) if err != nil { return err diff --git a/node/node.go b/node/node.go index 5643d6a6d2..8d6483e798 100644 --- a/node/node.go +++ b/node/node.go @@ -76,8 +76,10 @@ import ( "github.com/spacemeshos/go-spacemesh/sql/atxs" "github.com/spacemeshos/go-spacemesh/sql/layers" "github.com/spacemeshos/go-spacemesh/sql/localsql" + localmigrations "github.com/spacemeshos/go-spacemesh/sql/localsql/migrations" dbmetrics "github.com/spacemeshos/go-spacemesh/sql/metrics" "github.com/spacemeshos/go-spacemesh/sql/statesql" + statemigrations "github.com/spacemeshos/go-spacemesh/sql/statesql/migrations" "github.com/spacemeshos/go-spacemesh/syncer" "github.com/spacemeshos/go-spacemesh/syncer/atxsync" "github.com/spacemeshos/go-spacemesh/syncer/blockssync" @@ -1950,7 +1952,7 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { return fmt.Errorf("failed to create %s: %w", dbPath, err) } dbLog := app.addLogger(StateDbLogger, lg).Zap() - schema, err := statesql.Schema() + schema, err := statemigrations.SchemaWithInCodeMigrations() if err != nil { return fmt.Errorf("error loading db schema: %w", err) } @@ -2008,8 +2010,13 @@ func (app *App) setupDBs(ctx context.Context, lg log.Log) error { datastore.WithConsensusCache(app.atxsdata), ) + lSchema, err := localmigrations.SchemaWithInCodeMigrations() + if err != nil { + return fmt.Errorf("error loading db schema: %w", err) + } localDB, err := localsql.Open("file:"+filepath.Join(dbPath, localDbFile), sql.WithLogger(dbLog), + sql.WithDatabaseSchema(lSchema), sql.WithConnections(app.Config.DatabaseConnections), sql.WithAllowSchemaDrift(app.Config.DatabaseSchemaAllowDrift), ) diff --git a/sql/localsql/migrations/schema.go b/sql/localsql/migrations/schema.go new file mode 100644 index 0000000000..46f2bc5f15 --- /dev/null +++ b/sql/localsql/migrations/schema.go @@ -0,0 +1,12 @@ +package migrations + +import ( + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/localsql" +) + +func SchemaWithInCodeMigrations() (*sql.Schema, error) { + return localsql.Schema( + // add coded migrations here + ) +} diff --git a/sql/schemagen/main.go b/sql/schemagen/main.go index c55d2b3ad4..bfdbc3b040 100644 --- a/sql/schemagen/main.go +++ b/sql/schemagen/main.go @@ -8,8 +8,8 @@ import ( "go.uber.org/zap/zapcore" "github.com/spacemeshos/go-spacemesh/sql" - "github.com/spacemeshos/go-spacemesh/sql/localsql" - "github.com/spacemeshos/go-spacemesh/sql/statesql" + localmigrations "github.com/spacemeshos/go-spacemesh/sql/statesql/migrations" + statemigrations "github.com/spacemeshos/go-spacemesh/sql/statesql/migrations" ) var ( @@ -32,9 +32,9 @@ func main() { logger := zap.New(core).With(zap.String("dbType", *dbType)) switch *dbType { case "state": - schema, err = statesql.Schema() + schema, err = statemigrations.SchemaWithInCodeMigrations() case "local": - schema, err = localsql.Schema() + schema, err = localmigrations.SchemaWithInCodeMigrations() default: logger.Fatal("unknown database type, must be state or local") } diff --git a/sql/statesql/migrations/schema.go b/sql/statesql/migrations/schema.go new file mode 100644 index 0000000000..ad4c9cbc01 --- /dev/null +++ b/sql/statesql/migrations/schema.go @@ -0,0 +1,10 @@ +package migrations + +import ( + "github.com/spacemeshos/go-spacemesh/sql" + "github.com/spacemeshos/go-spacemesh/sql/statesql" +) + +func SchemaWithInCodeMigrations() (*sql.Schema, error) { + return statesql.Schema(New0021Migration(1_000_000)) +} diff --git a/sql/statesql/migrations/schema_test.go b/sql/statesql/migrations/schema_test.go new file mode 100644 index 0000000000..c94f2b549c --- /dev/null +++ b/sql/statesql/migrations/schema_test.go @@ -0,0 +1,22 @@ +package migrations + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + + "github.com/spacemeshos/go-spacemesh/sql" +) + +func TestCodedMigrations(t *testing.T) { + schema, err := SchemaWithInCodeMigrations() + require.NoError(t, err) + + db := sql.InMemory( + sql.WithDatabaseSchema(schema), + sql.WithLogger(zaptest.NewLogger(t)), + sql.WithForceMigrations(true), + ) + require.NotNil(t, db) +} diff --git a/sql/statesql/state_0021_migration.go b/sql/statesql/migrations/state_0021_migration.go similarity index 99% rename from sql/statesql/state_0021_migration.go rename to sql/statesql/migrations/state_0021_migration.go index ec5d5ae9fe..5bb4ca2393 100644 --- a/sql/statesql/state_0021_migration.go +++ b/sql/statesql/migrations/state_0021_migration.go @@ -1,4 +1,4 @@ -package statesql +package migrations import ( "errors" diff --git a/sql/statesql/state_0021_migration_test.go b/sql/statesql/migrations/state_0021_migration_test.go similarity index 93% rename from sql/statesql/state_0021_migration_test.go rename to sql/statesql/migrations/state_0021_migration_test.go index 65edcbd3c6..c298a762e7 100644 --- a/sql/statesql/state_0021_migration_test.go +++ b/sql/statesql/migrations/state_0021_migration_test.go @@ -1,4 +1,4 @@ -package statesql +package migrations import ( "slices" @@ -13,15 +13,13 @@ import ( "github.com/spacemeshos/go-spacemesh/signing" "github.com/spacemeshos/go-spacemesh/sql" "github.com/spacemeshos/go-spacemesh/sql/atxs" + "github.com/spacemeshos/go-spacemesh/sql/statesql" ) func Test0021Migration(t *testing.T) { - schema, err := Schema() + schema, err := statesql.Schema() require.NoError(t, err) schema.Migrations = slices.DeleteFunc(schema.Migrations, func(m sql.Migration) bool { - if m.Order() != 21 { - t.Logf("QQQQQ: include migration %d -- %s", m.Order(), m.Name()) - } return m.Order() == 21 }) diff --git a/sql/statesql/schema/migrations/0021_atx_posts.sql b/sql/statesql/schema/migrations/0021_atx_posts.sql index a009bd0655..ac2f91efcd 100644 --- a/sql/statesql/schema/migrations/0021_atx_posts.sql +++ b/sql/statesql/schema/migrations/0021_atx_posts.sql @@ -1,13 +1,13 @@ -- Table showing the PoST commitment by a smesher in given ATX. -- It shows the exact number of space units committed and the previous ATX id. CREATE TABLE posts ( - atxid CHAR(32) NOT NULL, - pubkey CHAR(32) NOT NULL, - prev_atxid CHAR(32), - prev_atx_index INT, - units INT NOT NULL, - UNIQUE (atxid, pubkey) -); + atxid CHAR(32) NOT NULL, + pubkey CHAR(32) NOT NULL, + prev_atxid CHAR(32), + prev_atx_index INT, + units INT NOT NULL, + UNIQUE (atxid, pubkey) + ); CREATE INDEX posts_by_atxid_by_pubkey ON posts (atxid, pubkey, prev_atxid); diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index b755237fb6..6ebc93f6e7 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -25,12 +25,14 @@ var _ sql.StateDatabase = &database{} func (db *database) IsStateDatabase() {} // Schema returns the schema for the state database. -func Schema() (*sql.Schema, error) { +func Schema(inCodeMigrations ...sql.Migration) (*sql.Schema, error) { sqlMigrations, err := sql.LoadSQLMigrations(migrations) if err != nil { return nil, err } - sqlMigrations = sqlMigrations.AddMigration(New0021Migration(1_000_000)) + for _, m := range inCodeMigrations { + sqlMigrations = sqlMigrations.AddMigration(m) + } // NOTE: coded state migrations can be added here // They can be a part of this localsql package return &sql.Schema{ diff --git a/sql/statesql/statesql_test.go b/sql/statesql/statesql_test.go index 207d7f2a5a..a59ce12079 100644 --- a/sql/statesql/statesql_test.go +++ b/sql/statesql/statesql_test.go @@ -37,8 +37,8 @@ func TestIdempotentMigration(t *testing.T) { require.NoError(t, err) require.NoError(t, db.Close()) - // "running migrations" + "applying migration 21" + "processed ATXs" - require.Equal(t, 3, observedLogs.Len(), "expected count of log messages") + // "running migrations" + require.Equal(t, 1, observedLogs.Len(), "expected count of log messages") l := observedLogs.All()[0] require.Equal(t, "running migrations", l.Message) require.Equal(t, int64(0), l.ContextMap()["current version"]) @@ -70,5 +70,5 @@ func TestIdempotentMigration(t *testing.T) { require.NoError(t, db.Close()) // make sure there's no schema drift warnings in the logs - require.Equal(t, 3, observedLogs.Len(), "expected 1 log message") + require.Equal(t, 1, observedLogs.Len(), "expected 1 log message") } diff --git a/tortoise/sim/utils.go b/tortoise/sim/utils.go index 86337c4551..1fd705cc49 100644 --- a/tortoise/sim/utils.go +++ b/tortoise/sim/utils.go @@ -23,7 +23,7 @@ func newCacheDB(logger *zap.Logger, conf config) *datastore.CachedDB { if len(conf.Path) == 0 { db = statesql.InMemory() } else { - db, err = statesql.Open(filepath.Join(conf.Path, atxpath)) + db, err = statesql.Open(filepath.Join(conf.Path, atxpath), sql.WithMigrationsDisabled()) if err != nil { panic(err) } From 22847dfdce86067a5854d241052fe0fdcd9d1099 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 16 Aug 2024 00:52:13 +0400 Subject: [PATCH 47/62] sql: fix schemagen --- sql/schemagen/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/schemagen/main.go b/sql/schemagen/main.go index bfdbc3b040..6838f79046 100644 --- a/sql/schemagen/main.go +++ b/sql/schemagen/main.go @@ -8,7 +8,7 @@ import ( "go.uber.org/zap/zapcore" "github.com/spacemeshos/go-spacemesh/sql" - localmigrations "github.com/spacemeshos/go-spacemesh/sql/statesql/migrations" + localmigrations "github.com/spacemeshos/go-spacemesh/sql/localsql/migrations" statemigrations "github.com/spacemeshos/go-spacemesh/sql/statesql/migrations" ) From 25185be728d405c3f379494d33d04fb57f7d3d80 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 16 Aug 2024 03:27:58 +0400 Subject: [PATCH 48/62] Address comments --- sql/database.go | 74 ++++++++++++++++++++++++++----------------------- 1 file changed, 40 insertions(+), 34 deletions(-) diff --git a/sql/database.go b/sql/database.go index 899d79f854..b57125e5b6 100644 --- a/sql/database.go +++ b/sql/database.go @@ -276,22 +276,18 @@ func openDB(config *conf) (*sqliteDatabase, error) { // fail, so we make it faster by disabling journaling and synchronous // writes. if _, err := db.Exec("PRAGMA journal_mode=OFF", nil, nil); err != nil { - return nil, errors.Join( - fmt.Errorf("PRAGMA journal_mode=OFF: %w", err), - db.Close()) + db.Close() + return nil, fmt.Errorf("PRAGMA journal_mode=OFF: %w", err) } if _, err := db.Exec("PRAGMA synchronous=OFF", nil, nil); err != nil { - return nil, errors.Join( - fmt.Errorf("PRAGMA journal_mode=OFF: %w", err), - db.Close()) + db.Close() + return nil, fmt.Errorf("PRAGMA journal_mode=OFF: %w", err) } } if freshDB && !config.forceMigrations { if err := config.schema.Apply(db); err != nil { - return nil, errors.Join( - fmt.Errorf("error running schema script: %w", err), - db.Close()) + return nil, fmt.Errorf("error running schema script: %w", err) } } else if db, err = ensureDBSchemaUpToDate(logger, db, config); err != nil { // ensureDBSchemaUpToDate may replace the original database and open the new one, @@ -299,7 +295,7 @@ func openDB(config *conf) (*sqliteDatabase, error) { // If there are migrations to be done in place without vacuuming, // the original db is returned and we must close it if there's an error. if db != nil { - err = errors.Join(err, db.Close()) + db.Close() } return nil, err } @@ -307,9 +303,8 @@ func openDB(config *conf) (*sqliteDatabase, error) { if !config.ignoreSchemaDrift { loaded, err := LoadDBSchemaScript(db) if err != nil { - return nil, errors.Join( - fmt.Errorf("error loading database schema: %w", err), - db.Close()) + db.Close() + return nil, fmt.Errorf("error loading database schema: %w", err) } diff := config.schema.Diff(loaded) switch { @@ -320,9 +315,8 @@ func openDB(config *conf) (*sqliteDatabase, error) { zap.String("diff", diff), ) default: - return nil, errors.Join( - fmt.Errorf("schema drift detected (uri %s):\n%s", config.uri, diff), - db.Close()) + db.Close() + return nil, fmt.Errorf("schema drift detected (uri %s):\n%s", config.uri, diff) } } @@ -338,7 +332,7 @@ func ensureDBSchemaUpToDate(logger *zap.Logger, db *sqliteDatabase, config *conf before, after, err := config.schema.CheckDBVersion(logger, db) switch { case err != nil: - return db, err + return db, fmt.Errorf("check db version: %w", err) case before == after: return db, nil case before > after: @@ -392,7 +386,14 @@ func Version(uri string) (int, error) { } db := &sqliteDatabase{pool: pool} v, err := version(db) - return v, errors.Join(err, db.Close()) + if err != nil { + db.Close() + return 0, err + } + if err := db.Close(); err != nil { + return 0, fmt.Errorf("close db %s: %w", uri, err) + } + return v, nil } // deleteDB deletes the database at the specified path by removing /path/to/DB* files. @@ -436,7 +437,8 @@ func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { return err } if err := db.vacuumInto(toPath); err != nil { - return errors.Join(err, db.Close()) + db.Close() + return err } // Open the freshly vacuumed DB to avoid race condition when another process // also tries to vacuum the temporary DB into the original path after @@ -450,17 +452,17 @@ func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { if err != nil { return fmt.Errorf("open vacuumed DB %s: %w", toPath, err) } - defer func() { - if closeErr := origDB.Close(); closeErr != nil { - err = errors.Join(err, fmt.Errorf("close DB %s after migration: %w", toPath, closeErr)) - } - }() if err := db.Close(); err != nil { + origDB.Close() return fmt.Errorf("close temporary DB %s: %w", fromPath, err) } if err := deleteDB(fromPath); err != nil { + origDB.Close() return err } + if err := origDB.Close(); err != nil { + return fmt.Errorf("close DB %s after migration: %w", toPath, err) + } return nil } @@ -589,7 +591,9 @@ func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func return err } defer func() { - err = errors.Join(err, tx.Release()) + if rErr := tx.Release(); err != nil { + err = errors.Join(err, fmt.Errorf("release tx: %w", rErr)) + } }() if err := exec(tx); err != nil { tx.queryCache.ClearCache() @@ -724,9 +728,7 @@ func (db *sqliteDatabase) vacuumInto(toPath string) error { // The source database is always closed by this function. // Upon success, the migrated database is opened. func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, err error) { - defer func() { - err = errors.Join(err, db.Close()) - }() + defer db.Close() dbPath, migratedPath, err := dbMigrationPaths(config.uri) if err != nil { @@ -791,8 +793,9 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, fmt.Errorf("create marker file %s_done: %w", migratedPath, err) } else { if err := f.Sync(); err != nil { - return nil, fmt.Errorf("sync/close marker file %s_done: %w", migratedPath, - errors.Join(err, f.Close())) + f.Close() + os.Remove(markerPath) + return nil, fmt.Errorf("sync/close marker file %s_done: %w", migratedPath, err) } if err := f.Close(); err != nil { return nil, fmt.Errorf("close marker file %s: %w", markerPath, err) @@ -843,7 +846,12 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, err } - return finalDB, err + if err := db.Close(); err != nil { + finalDB.Close() + return nil, fmt.Errorf("close original DB %s: %w", dbPath, err) + } + + return finalDB, nil } // QueryCount returns the number of queries executed, including failed @@ -865,9 +873,7 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (nR if encoder != nil { encoder(stmt) } - defer func() { - err = errors.Join(err, stmt.ClearBindings()) - }() + defer stmt.ClearBindings() rows := 0 for { From 9dcdd5923f9f103615f814ffe12f1bbadda5dcaa Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 16 Aug 2024 03:53:39 +0400 Subject: [PATCH 49/62] fetch: fix tests --- fetch/handler_test.go | 2 +- fetch/p2p_test.go | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/fetch/handler_test.go b/fetch/handler_test.go index bca3c7fac2..a1ff0f781f 100644 --- a/fetch/handler_test.go +++ b/fetch/handler_test.go @@ -330,7 +330,7 @@ func TestHandleEpochInfoReq(t *testing.T) { var resp server.Response require.NoError(t, codec.Decode(b.Bytes(), &resp)) require.Empty(t, resp.Data) - require.Contains(t, resp.Error, "exec epoch 11: database: no free connection") + require.Contains(t, resp.Error, "exec epoch 11: database closed") }) }) } diff --git a/fetch/p2p_test.go b/fetch/p2p_test.go index e92d5ea53b..1e56fb703f 100644 --- a/fetch/p2p_test.go +++ b/fetch/p2p_test.go @@ -270,7 +270,7 @@ func forStreamingCachedUncached( func TestP2PPeerEpochInfo(t *testing.T) { forStreamingCachedUncached( - t, "peer error: getting ATX IDs: exec epoch 11: database: no free connection", + t, "peer error: getting ATX IDs: exec epoch 11: database closed", func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { epoch := types.EpochID(11) atxIDs := tpf.createATXs(epoch) @@ -291,7 +291,7 @@ func TestP2PPeerEpochInfo(t *testing.T) { func TestP2PPeerMeshHashes(t *testing.T) { forStreaming( - t, "peer error: get aggHashes from 7 to 23 by 5: database: no free connection", false, + t, "peer error: get aggHashes from 7 to 23 by 5: database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { req := &MeshHashRequest{ From: 7, @@ -324,7 +324,7 @@ func TestP2PPeerMeshHashes(t *testing.T) { func TestP2PMaliciousIDs(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { var bad []types.NodeID for i := 0; i < 11; i++ { @@ -349,7 +349,7 @@ func TestP2PMaliciousIDs(t *testing.T) { func TestP2PGetATXs(t *testing.T) { forStreamingCachedUncached( - t, "database: no free connection", + t, "database closed", func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { epoch := types.EpochID(11) atx := newAtx(tpf.t, epoch) @@ -365,7 +365,7 @@ func TestP2PGetATXs(t *testing.T) { func TestP2PGetPoet(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { ref := types.PoetProofRef{0x42, 0x43} require.NoError(t, poets.Add(tpf.serverCDB, ref, []byte("proof1"), []byte("sid1"), "rid1")) @@ -380,7 +380,7 @@ func TestP2PGetPoet(t *testing.T) { func TestP2PGetBallot(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { signer, err := signing.NewEdSigner() require.NoError(t, err) @@ -402,7 +402,7 @@ func TestP2PGetBallot(t *testing.T) { func TestP2PGetActiveSet(t *testing.T) { forStreamingCachedUncached( - t, "database: no free connection", + t, "database closed", func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { id := types.RandomHash() set := &types.EpochActiveSet{ @@ -421,7 +421,7 @@ func TestP2PGetActiveSet(t *testing.T) { func TestP2PGetBlock(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { lid := types.LayerID(111) bk := types.NewExistingBlock(types.RandomBlockID(), types.InnerBlock{LayerIndex: lid}) @@ -472,7 +472,7 @@ func TestP2PGetProp(t *testing.T) { func TestP2PGetBlockTransactions(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { signer, err := signing.NewEdSigner() require.NoError(t, err) @@ -488,7 +488,7 @@ func TestP2PGetBlockTransactions(t *testing.T) { func TestP2PGetProposalTransactions(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { signer, err := signing.NewEdSigner() require.NoError(t, err) @@ -506,7 +506,7 @@ func TestP2PGetProposalTransactions(t *testing.T) { func TestP2PGetMalfeasanceProofs(t *testing.T) { forStreaming( - t, "database: no free connection", false, + t, "database closed", false, func(t *testing.T, ctx context.Context, tpf *testP2PFetch, errStr string) { nid := types.RandomNodeID() proof := types.RandomBytes(11) From 8bf64c7d562086c045ef74dbc79272480fc5a532 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Fri, 16 Aug 2024 04:01:32 +0400 Subject: [PATCH 50/62] sql: fix error handling --- sql/database.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/database.go b/sql/database.go index b57125e5b6..77f80a5789 100644 --- a/sql/database.go +++ b/sql/database.go @@ -591,7 +591,7 @@ func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func return err } defer func() { - if rErr := tx.Release(); err != nil { + if rErr := tx.Release(); rErr != nil { err = errors.Join(err, fmt.Errorf("release tx: %w", rErr)) } }() From e7b563beafc0563de698db4b43d5b41bd49c5fee Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 17 Aug 2024 21:39:41 +0400 Subject: [PATCH 51/62] statesql, localsql: fix InMemoryTest --- sql/localsql/localsql.go | 6 +----- sql/statesql/statesql.go | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go index 2dc32f18ad..70d966cd3c 100644 --- a/sql/localsql/localsql.go +++ b/sql/localsql/localsql.go @@ -73,11 +73,7 @@ func InMemory(opts ...sql.Opt) *database { // InMemoryTest returns an in-mem database for testing and ensures database is closed during `tb.Cleanup`. func InMemoryTest(tb testing.TB, opts ...sql.Opt) sql.LocalDatabase { - opts = append(opts, sql.WithConnections(1)) - db, err := Open("file::memory:?mode=memory", opts...) - if err != nil { - panic(err) - } + db := InMemory(opts...) tb.Cleanup(func() { db.Close() }) return db } diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 6ebc93f6e7..f9fcf816f9 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -71,11 +71,7 @@ func InMemory(opts ...sql.Opt) sql.StateDatabase { // InMemoryTest returns an in-mem database for testing and ensures database is closed during `tb.Cleanup`. func InMemoryTest(tb testing.TB, opts ...sql.Opt) sql.StateDatabase { - opts = append(opts, sql.WithConnections(1)) - db, err := Open("file::memory:?mode=memory", opts...) - if err != nil { - panic(err) - } + db := InMemory(opts...) tb.Cleanup(func() { db.Close() }) return db } From 5247124f3569388c66758ad49f461c072ae4f786 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Sat, 17 Aug 2024 22:24:04 +0400 Subject: [PATCH 52/62] sql: fix in-memory db handling Don't check for incomplete migrations in case of an in-memory DB --- sql/database.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/database.go b/sql/database.go index 77f80a5789..0e3ed34ab3 100644 --- a/sql/database.go +++ b/sql/database.go @@ -229,7 +229,7 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { for _, opt := range opts { opt(config) } - if !config.temp && config.handleIncompleteMigrations { + if !config.temp && config.handleIncompleteMigrations && !config.forceFresh { if err := handleIncompleteCopyMigration(config); err != nil { return nil, err } From 47fbc8476c5fc00c238ca98ea52c149a7dffc2b5 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 20 Aug 2024 15:45:41 +0400 Subject: [PATCH 53/62] Addressed comments --- cmd/merge-nodes/internal/merge_action_test.go | 4 +- node/node_version_check_test.go | 2 +- sql/database.go | 51 ++++++++++--------- sql/database_test.go | 18 +++---- sql/localsql/localsql.go | 6 +-- sql/schema.go | 2 +- .../migrations/state_0021_migration_test.go | 2 +- sql/statesql/statesql.go | 6 +-- sql/vacuum_test.go | 2 +- syncer/malsync/syncer.go | 3 +- 10 files changed, 45 insertions(+), 51 deletions(-) diff --git a/cmd/merge-nodes/internal/merge_action_test.go b/cmd/merge-nodes/internal/merge_action_test.go index 7fae559a48..5657a1baed 100644 --- a/cmd/merge-nodes/internal/merge_action_test.go +++ b/cmd/merge-nodes/internal/merge_action_test.go @@ -37,7 +37,7 @@ func Test_MergeDBs_InvalidTargetSchema(t *testing.T) { db, err := localsql.Open("file:"+filepath.Join(tmpDst, localDbFile), sql.WithDatabaseSchema(oldSchema(t)), sql.WithForceMigrations(true), - sql.WithIgnoreSchemaDrift(), + sql.WithNoCheckSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -96,7 +96,7 @@ func Test_MergeDBs_InvalidSourceSchema(t *testing.T) { db, err = localsql.Open("file:"+filepath.Join(tmpSrc, localDbFile), sql.WithDatabaseSchema(oldSchema(t)), sql.WithForceMigrations(true), - sql.WithIgnoreSchemaDrift(), + sql.WithNoCheckSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) diff --git a/node/node_version_check_test.go b/node/node_version_check_test.go index affc503f86..2f1b2afe94 100644 --- a/node/node_version_check_test.go +++ b/node/node_version_check_test.go @@ -46,7 +46,7 @@ func TestUpgradeToV15(t *testing.T) { db, err := statesql.Open(uri, sql.WithDatabaseSchema(schema), sql.WithForceMigrations(true), - sql.WithIgnoreSchemaDrift()) + sql.WithNoCheckSchemaDrift()) require.NoError(t, err) require.NoError(t, db.Close()) diff --git a/sql/database.go b/sql/database.go index 6e89049ca6..7ae299b43e 100644 --- a/sql/database.go +++ b/sql/database.go @@ -67,22 +67,23 @@ func defaultConf() *conf { connections: 16, logger: zap.NewNop(), schema: &Schema{}, + checkSchemaDrift: true, } } type conf struct { - enableMigrations bool - forceFresh bool - forceMigrations bool - connections int - vacuumState int - enableLatency bool - cache bool - cacheSizes map[QueryCacheKind]int - logger *zap.Logger - schema *Schema - allowSchemaDrift bool - ignoreSchemaDrift bool + enableMigrations bool + forceFresh bool + forceMigrations bool + connections int + vacuumState int + enableLatency bool + cache bool + cacheSizes map[QueryCacheKind]int + logger *zap.Logger + schema *Schema + allowSchemaDrift bool + checkSchemaDrift bool } // WithConnections overwrites number of pooled connections. @@ -156,17 +157,18 @@ func WithDatabaseSchema(schema *Schema) Opt { } } -// WithAllowSchemaDrift prevents Open from failing upon schema -// drift. A warning is printed instead. +// WithAllowSchemaDrift prevents Open from failing upon schema drift when schema drift +// checks are enabled. A warning is printed instead. func WithAllowSchemaDrift(allow bool) Opt { return func(c *conf) { c.allowSchemaDrift = allow } } -func WithIgnoreSchemaDrift() Opt { +// WithNoCheckSchemaDrift disables schema drift checks. +func WithNoCheckSchemaDrift() Opt { return func(c *conf) { - c.ignoreSchemaDrift = true + c.checkSchemaDrift = false } } @@ -262,7 +264,7 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { } } - if !config.ignoreSchemaDrift { + if config.checkSchemaDrift { loaded, err := LoadDBSchemaScript(db) if err != nil { return nil, errors.Join( @@ -539,17 +541,16 @@ func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, e } func mapSqliteError(err error) error { - code := sqlite.ErrCode(err) - if code == sqlite.SQLITE_CONSTRAINT_PRIMARYKEY || code == sqlite.SQLITE_CONSTRAINT_UNIQUE { + switch sqlite.ErrCode(err) { + case sqlite.SQLITE_CONSTRAINT_PRIMARYKEY, sqlite.SQLITE_CONSTRAINT_UNIQUE: return ErrObjectExists - } - if code == sqlite.SQLITE_INTERRUPT { - // TODO: we probably should check if there was indeed a context - // that was canceled. But we're likely to replace crawshaw library - // in future so this part should be rewritten anyway + case sqlite.SQLITE_INTERRUPT: + // TODO: we probably should check if there was indeed a context that was + // canceled return context.Canceled + default: + return err } - return err } // Blob represents a binary blob data. It can be reused efficiently diff --git a/sql/database_test.go b/sql/database_test.go index e4f222143e..088ef5e24b 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -26,7 +26,7 @@ func Test_Transaction_Isolation(t *testing.T) { field int );`, }), - WithIgnoreSchemaDrift(), + WithNoCheckSchemaDrift(), ) tx, err := db.Tx(context.Background()) require.NoError(t, err) @@ -93,7 +93,7 @@ func Test_Migration_Rollback_Only_NewMigrations(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), - WithIgnoreSchemaDrift(), + WithNoCheckSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -126,7 +126,7 @@ func Test_Migration_Disabled(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), - WithIgnoreSchemaDrift(), + WithNoCheckSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -161,7 +161,7 @@ func TestDatabaseSkipMigrations(t *testing.T) { db, err := Open("file:"+dbFile, WithDatabaseSchema(schema), WithForceMigrations(true), - WithIgnoreSchemaDrift(), + WithNoCheckSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -187,7 +187,7 @@ func TestDatabaseVacuumState(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), - WithIgnoreSchemaDrift(), + WithNoCheckSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -198,7 +198,7 @@ func TestDatabaseVacuumState(t *testing.T) { Migrations: MigrationList{migration1, migration2}, }), WithVacuumState(2), - WithIgnoreSchemaDrift(), + WithNoCheckSchemaDrift(), ) require.NoError(t, err) require.NoError(t, db.Close()) @@ -209,7 +209,7 @@ func TestDatabaseVacuumState(t *testing.T) { } func TestQueryCount(t *testing.T) { - db := InMemory(WithLogger(zaptest.NewLogger(t)), WithIgnoreSchemaDrift()) + db := InMemory(WithLogger(zaptest.NewLogger(t)), WithNoCheckSchemaDrift()) require.Equal(t, 0, db.QueryCount()) n, err := db.Exec("select 1", nil, nil) @@ -237,7 +237,7 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { db, err := Open("file:"+dbFile, WithLogger(logger), WithForceMigrations(true), - WithIgnoreSchemaDrift()) + WithNoCheckSchemaDrift()) require.NoError(t, err) _, err = db.Exec("PRAGMA user_version = 3", nil, nil) require.NoError(t, err) @@ -248,7 +248,7 @@ func Test_Migration_FailsIfDatabaseTooNew(t *testing.T) { WithDatabaseSchema(&Schema{ Migrations: MigrationList{migration1, migration2}, }), - WithIgnoreSchemaDrift(), + WithNoCheckSchemaDrift(), ) require.ErrorIs(t, err, ErrTooNew) } diff --git a/sql/localsql/localsql.go b/sql/localsql/localsql.go index 2dc32f18ad..70d966cd3c 100644 --- a/sql/localsql/localsql.go +++ b/sql/localsql/localsql.go @@ -73,11 +73,7 @@ func InMemory(opts ...sql.Opt) *database { // InMemoryTest returns an in-mem database for testing and ensures database is closed during `tb.Cleanup`. func InMemoryTest(tb testing.TB, opts ...sql.Opt) sql.LocalDatabase { - opts = append(opts, sql.WithConnections(1)) - db, err := Open("file::memory:?mode=memory", opts...) - if err != nil { - panic(err) - } + db := InMemory(opts...) tb.Cleanup(func() { db.Close() }) return db } diff --git a/sql/schema.go b/sql/schema.go index cf125f162e..7ec1c8000e 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -191,7 +191,7 @@ func (g *SchemaGen) Generate(outputFile string) error { WithLogger(g.logger), WithDatabaseSchema(g.schema), WithForceMigrations(true), - WithIgnoreSchemaDrift()) + WithNoCheckSchemaDrift()) if err != nil { return fmt.Errorf("error opening in-memory db: %w", err) } diff --git a/sql/statesql/migrations/state_0021_migration_test.go b/sql/statesql/migrations/state_0021_migration_test.go index c298a762e7..8b1c354b96 100644 --- a/sql/statesql/migrations/state_0021_migration_test.go +++ b/sql/statesql/migrations/state_0021_migration_test.go @@ -26,7 +26,7 @@ func Test0021Migration(t *testing.T) { db := sql.InMemory( sql.WithLogger(zaptest.NewLogger(t)), sql.WithDatabaseSchema(schema), - sql.WithIgnoreSchemaDrift(), + sql.WithNoCheckSchemaDrift(), sql.WithForceMigrations(true), ) diff --git a/sql/statesql/statesql.go b/sql/statesql/statesql.go index 6ebc93f6e7..f9fcf816f9 100644 --- a/sql/statesql/statesql.go +++ b/sql/statesql/statesql.go @@ -71,11 +71,7 @@ func InMemory(opts ...sql.Opt) sql.StateDatabase { // InMemoryTest returns an in-mem database for testing and ensures database is closed during `tb.Cleanup`. func InMemoryTest(tb testing.TB, opts ...sql.Opt) sql.StateDatabase { - opts = append(opts, sql.WithConnections(1)) - db, err := Open("file::memory:?mode=memory", opts...) - if err != nil { - panic(err) - } + db := InMemory(opts...) tb.Cleanup(func() { db.Close() }) return db } diff --git a/sql/vacuum_test.go b/sql/vacuum_test.go index 1a89158c64..f3f887fafa 100644 --- a/sql/vacuum_test.go +++ b/sql/vacuum_test.go @@ -7,6 +7,6 @@ import ( ) func TestVacuumDB(t *testing.T) { - db := InMemory(WithIgnoreSchemaDrift()) + db := InMemory(WithNoCheckSchemaDrift()) require.NoError(t, Vacuum(db)) } diff --git a/syncer/malsync/syncer.go b/syncer/malsync/syncer.go index 173756002f..633e4a284a 100644 --- a/syncer/malsync/syncer.go +++ b/syncer/malsync/syncer.go @@ -346,7 +346,8 @@ func (s *Syncer) updateState(ctx context.Context) error { }); err != nil { if ctx.Err() != nil { // FIXME: with crawshaw, canceling the context which has been used to get - // a connection from the pool may cause "database: no free connection" errors + // a connection from the pool may cause "database: no free connection" errors. + // Related: #6273 err = ctx.Err() } return fmt.Errorf("error updating malsync state: %w", err) From 2e7e150d469981e907fc6f2f6df330f9b28ea211 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 20 Aug 2024 16:37:55 +0400 Subject: [PATCH 54/62] sql: ignore whitespace during schema drift checks --- sql/schema.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/schema.go b/sql/schema.go index 7ec1c8000e..0ef49bff45 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -56,7 +56,10 @@ type Schema struct { // Diff diffs the database schema against the actual schema. // If there's no differences, it returns an empty string. func (s *Schema) Diff(actualScript string) string { - return cmp.Diff(s.Script, actualScript) + opt := cmp.Comparer(func(x, y string) bool { + return strings.Join(strings.Fields(x), "") == strings.Join(strings.Fields(y), "") + }) + return cmp.Diff(s.Script, actualScript, opt) } // WriteToFile writes the schema to the corresponding updated schema file. From 670df1b248831bbfba56d40bf00ec88dce90a66b Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 20 Aug 2024 21:07:29 +0400 Subject: [PATCH 55/62] sql: fix failing migration tests on Windows --- sql/database.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/database.go b/sql/database.go index e6fb535050..f86263143e 100644 --- a/sql/database.go +++ b/sql/database.go @@ -239,7 +239,7 @@ func Open(uri string, opts ...Opt) (*sqliteDatabase, error) { return openDB(config) } -func openDB(config *conf) (*sqliteDatabase, error) { +func openDB(config *conf) (db *sqliteDatabase, err error) { logger := config.logger.With(zap.String("uri", config.uri)) var flags sqlite.OpenFlags if !config.forceFresh { @@ -268,7 +268,15 @@ func openDB(config *conf) (*sqliteDatabase, error) { return nil, fmt.Errorf("create db %s: %w", config.uri, err) } } - db := &sqliteDatabase{pool: pool} + db = &sqliteDatabase{pool: pool} + success := false + defer func() { + // Close the database even in case of a panic. This is important for tests + // that verify incomplete migration. + if !success && db != nil { + db.Close() + } + }() if config.enableLatency { db.latency = newQueryLatency() } @@ -278,11 +286,9 @@ func openDB(config *conf) (*sqliteDatabase, error) { // fail, so we make it faster by disabling journaling and synchronous // writes. if _, err := db.Exec("PRAGMA journal_mode=OFF", nil, nil); err != nil { - db.Close() return nil, fmt.Errorf("PRAGMA journal_mode=OFF: %w", err) } if _, err := db.Exec("PRAGMA synchronous=OFF", nil, nil); err != nil { - db.Close() return nil, fmt.Errorf("PRAGMA journal_mode=OFF: %w", err) } } @@ -327,6 +333,7 @@ func openDB(config *conf) (*sqliteDatabase, error) { db.queryCache = &queryCache{cacheSizesByKind: config.cacheSizes} } db.queryCount.Store(0) + success = true // do not close the db in the deferred func return db, nil } @@ -770,6 +777,7 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, fmt.Errorf("process temporary DB %s: %w", migratedPath, err), deleteDB(migratedPath)) } + defer migratedDB.Close() tempDBReady := false defer func() { err = errors.Join(err, migratedDB.Close()) @@ -836,6 +844,14 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, if err != nil { return nil, fmt.Errorf("open final DB %s: %w", config.uri, err) } + success := false + defer func() { + // Close the database even in case of a panic. This is important for tests + // that verify incomplete migration. + if !success { + finalDB.Close() + } + }() if err := migratedDB.Close(); err != nil { finalDB.Close() @@ -853,6 +869,7 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, fmt.Errorf("close original DB %s: %w", dbPath, err) } + success = true // do not close the db in the deferred func return finalDB, nil } From d69eab427e325435db5ac059c8bbecc889c77faf Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Tue, 20 Aug 2024 22:38:38 +0400 Subject: [PATCH 56/62] Address comments --- sql/database.go | 61 +++++++++++++++++-------------------------------- sql/schema.go | 5 +--- 2 files changed, 22 insertions(+), 44 deletions(-) diff --git a/sql/database.go b/sql/database.go index f86263143e..992a81634b 100644 --- a/sql/database.go +++ b/sql/database.go @@ -345,18 +345,8 @@ func ensureDBSchemaUpToDate(logger *zap.Logger, db *sqliteDatabase, config *conf case before == after: return db, nil case before > after: - // TODO: this should be logged by the caller - logger.Error("database version is newer than expected - downgrade is not supported", - zap.Int("current version", before), - zap.Int("target version", after), - ) return db, fmt.Errorf("%w: %d > %d", ErrTooNew, before, after) case !config.enableMigrations: - // TODO: this should be logged by the caller - logger.Error("database version is too old", - zap.Int("current version", before), - zap.Int("target version", after), - ) return db, fmt.Errorf("%w: %d < %d", ErrOldSchema, before, after) case config.temp: // Temporary database, do migrations without transactions @@ -370,22 +360,13 @@ func ensureDBSchemaUpToDate(logger *zap.Logger, db *sqliteDatabase, config *conf zap.Int("target version", after), ) return db.copyMigrateDB(config) - default: - // Do not produce extra "running migrations" log message for the - // temporary DB, as it was already logged - if !config.temp { - logger.Info("running migrations in-place", - zap.Int("current version", before), - zap.Int("target version", after), - ) - } else { - logger.Info("applying migrations to temporary DB", - zap.Int("current version", before), - zap.Int("target version", after), - ) - } - return db, config.schema.Migrate(logger, db, before, config.vacuumState) } + + logger.Info("running migrations in-place", + zap.Int("current version", before), + zap.Int("target version", after), + ) + return db, config.schema.Migrate(logger, db, before, config.vacuumState) } func Version(uri string) (int, error) { @@ -600,8 +581,8 @@ func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func return err } defer func() { - if rErr := tx.Release(); rErr != nil { - err = errors.Join(err, fmt.Errorf("release tx: %w", rErr)) + if rErr := tx.Release(); rErr != nil && err == nil { + err = fmt.Errorf("release tx: %w", rErr) } }() if err := exec(tx); err != nil { @@ -799,21 +780,21 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, // and PRAGMA synchronous=OFF, it may become corrupt in case of a crash or power // outage, so we avoid trying to open it. markerPath := migratedPath + "_done" - if f, err := os.Create(markerPath); err != nil { + f, err := os.Create(markerPath) + if err != nil { return nil, fmt.Errorf("create marker file %s_done: %w", migratedPath, err) - } else { - if err := f.Sync(); err != nil { - f.Close() - os.Remove(markerPath) - return nil, fmt.Errorf("sync/close marker file %s_done: %w", migratedPath, err) - } - if err := f.Close(); err != nil { - return nil, fmt.Errorf("close marker file %s: %w", markerPath, err) - } - // The temporary database is complete and should not be deleted - // until we copy it to the original database location. - tempDBReady = true } + if err := f.Sync(); err != nil { + f.Close() + os.Remove(markerPath) + return nil, fmt.Errorf("sync/close marker file %s_done: %w", migratedPath, err) + } + if err := f.Close(); err != nil { + return nil, fmt.Errorf("close marker file %s: %w", markerPath, err) + } + // The temporary database is complete and should not be deleted + // until we copy it to the original database location. + tempDBReady = true // We only close the source database at the end of the migration process // so that the lock is held. There's a possibility that right after we diff --git a/sql/schema.go b/sql/schema.go index 64f180967b..7dde8ef3f2 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -108,10 +108,7 @@ func (s *Schema) CheckDBVersion(logger *zap.Logger, db Database) (before, after if err != nil { return 0, 0, err } - after = 0 - if len(s.Migrations) > 0 { - after = s.Migrations.Version() - } + after = s.Migrations.Version() if before > after { logger.Error("database version is newer than expected - downgrade is not supported", zap.Int("current version", before), From 26aa21cc2c3fe08d0bb8d0de23f77c4e348240cf Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Wed, 21 Aug 2024 23:22:42 +0400 Subject: [PATCH 57/62] Address more comments --- sql/database_test.go | 3 --- sql/schema.go | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sql/database_test.go b/sql/database_test.go index 229bae4874..f277cebffa 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -577,6 +577,3 @@ func TestSchemaDrift(t *testing.T) { require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, observedLogs.All()[0].ContextMap()["diff"]) } - -// TBD: test WAL modes for temp DB -// TBD: remove SQLITE_OPEN_WAL from open flags and check journal mode diff --git a/sql/schema.go b/sql/schema.go index 7dde8ef3f2..390a580c06 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -158,12 +158,12 @@ func (s *Schema) Migrate(logger *zap.Logger, db Database, before, vacuumState in } return nil }); err != nil { - return errors.Join(err, db.Close()) + return err } if vacuumState != 0 && before <= vacuumState { if err := Vacuum(db); err != nil { - return errors.Join(err, db.Close()) + return err } } before = m.Order() @@ -190,7 +190,7 @@ func (s *Schema) MigrateTempDB(logger *zap.Logger, db Database, before int) erro if _, ok := s.skipMigration[m.Order()]; !ok { if err := m.Apply(db, logger); err != nil { - return errors.Join(fmt.Errorf("apply %s: %w", m.Name(), err), db.Close()) + return fmt.Errorf("apply %s: %w", m.Name(), err) } } From 3a3bd94e30a580369df97ac430d65cebdfd1eeb4 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 22 Aug 2024 02:39:01 +0400 Subject: [PATCH 58/62] sql: improve error handling during migrations --- sql/database.go | 87 ++++++++++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 37 deletions(-) diff --git a/sql/database.go b/sql/database.go index 992a81634b..65a64d4143 100644 --- a/sql/database.go +++ b/sql/database.go @@ -269,12 +269,12 @@ func openDB(config *conf) (db *sqliteDatabase, err error) { } } db = &sqliteDatabase{pool: pool} - success := false defer func() { // Close the database even in case of a panic. This is important for tests // that verify incomplete migration. - if !success && db != nil { + if r := recover(); r != nil { db.Close() + panic(r) } }() if config.enableLatency { @@ -333,7 +333,6 @@ func openDB(config *conf) (db *sqliteDatabase, err error) { db.queryCache = &queryCache{cacheSizesByKind: config.cacheSizes} } db.queryCount.Store(0) - success = true // do not close the db in the deferred func return db, nil } @@ -737,7 +736,14 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, zap.String("path", dbPath), zap.String("target", migratedPath)) if err := db.vacuumInto(migratedPath); err != nil { - return nil, errors.Join(err, deleteDB(migratedPath)) + if err := deleteDB(migratedPath); err != nil { + config.logger.Error( + "incomplete temporary copy of the database couldn't be deleted", + zap.String("path", migratedPath), + zap.Error(err), + ) + } + return nil, errors.Join(err) } // Opening the temporary migrated DB runs the actual migrations on it. @@ -754,16 +760,27 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, } migratedDB, err := Open("file:"+migratedPath, opts...) if err != nil { - return nil, errors.Join( - fmt.Errorf("process temporary DB %s: %w", migratedPath, err), - deleteDB(migratedPath)) + if err := deleteDB(migratedPath); err != nil { + config.logger.Error( + "incomplete temporary copy of the database couldn't be deleted", + zap.String("path", migratedPath), + zap.Error(err), + ) + } + return nil, fmt.Errorf("process temporary DB %s: %w", migratedPath, err) } defer migratedDB.Close() tempDBReady := false defer func() { - err = errors.Join(err, migratedDB.Close()) + migratedDB.Close() if !tempDBReady { - err = errors.Join(err, deleteDB(migratedPath)) + if err := deleteDB(migratedPath); err != nil { + config.logger.Error( + "incomplete temporary copy of the database couldn't be deleted", + zap.String("path", migratedPath), + zap.Error(err), + ) + } } }() @@ -773,25 +790,18 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, fmt.Errorf("checkpoint temporary DB %s: %w", migratedPath, err) } - // Create the marker file to indicate that the migration is complete. - // Make sure the file is written to the disk before closing the database. + // Create the marker file to indicate that the migration is complete and make sure + // the file is written to the disk before closing the database. // We could create a table in the temporary database instead of the marker file, // but as the temporary database is opened without PRAGMA journal_mode=OFF // and PRAGMA synchronous=OFF, it may become corrupt in case of a crash or power // outage, so we avoid trying to open it. - markerPath := migratedPath + "_done" - f, err := os.Create(markerPath) - if err != nil { - return nil, fmt.Errorf("create marker file %s_done: %w", migratedPath, err) - } - if err := f.Sync(); err != nil { - f.Close() - os.Remove(markerPath) - return nil, fmt.Errorf("sync/close marker file %s_done: %w", migratedPath, err) - } - if err := f.Close(); err != nil { - return nil, fmt.Errorf("close marker file %s: %w", markerPath, err) + if err := createMarkerFile(migratedPath); err != nil { + // The errors returned by createMarkerFile are already descriptive enough + // so no need to augment them + return nil, err } + // The temporary database is complete and should not be deleted // until we copy it to the original database location. tempDBReady = true @@ -825,14 +835,6 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, if err != nil { return nil, fmt.Errorf("open final DB %s: %w", config.uri, err) } - success := false - defer func() { - // Close the database even in case of a panic. This is important for tests - // that verify incomplete migration. - if !success { - finalDB.Close() - } - }() if err := migratedDB.Close(); err != nil { finalDB.Close() @@ -845,15 +847,26 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, err } - if err := db.Close(); err != nil { - finalDB.Close() - return nil, fmt.Errorf("close original DB %s: %w", dbPath, err) - } - - success = true // do not close the db in the deferred func return finalDB, nil } +func createMarkerFile(basePath string) error { + markerPath := basePath + "_done" + f, err := os.Create(markerPath) + if err != nil { + return fmt.Errorf("create marker file %s: %w", markerPath, err) + } + if err := f.Sync(); err != nil { + f.Close() + os.Remove(markerPath) + return fmt.Errorf("sync/close marker file %s: %w", markerPath, err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("close marker file %s: %w", markerPath, err) + } + return nil +} + // QueryCount returns the number of queries executed, including failed // queries, but not counting transaction start / commit / rollback. func (db *sqliteDatabase) QueryCount() int { From f4624337cacaf0a8514eaa4ee7c565e3ddc76bf7 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 22 Aug 2024 04:35:31 +0400 Subject: [PATCH 59/62] sql: hold exclusive database locks during copy-based migration --- sql/database.go | 140 +++++++++++++++++++++++++++++++++++++------ sql/database_test.go | 36 +++++++++++ 2 files changed, 159 insertions(+), 17 deletions(-) diff --git a/sql/database.go b/sql/database.go index 65a64d4143..0ccd9733e8 100644 --- a/sql/database.go +++ b/sql/database.go @@ -92,6 +92,7 @@ type conf struct { checkSchemaDrift bool temp bool handleIncompleteMigrations bool + exclusive bool } // WithConnections overwrites number of pooled connections. @@ -201,6 +202,18 @@ func withDisableIncompleteMigrationHandling() Opt { } } +// WithExclusive specifies that the database is to be open in exclusive mode. +// This means that no other processes can open the database at the same time. +// If the database is already open by any process, this Open will fail. +// Any subsequent attempts by other processes to open the database will fail until this db +// handle is closed. +// In Exclusive mode, the database supports just one concurrent connection. +func WithExclusive() Opt { + return func(c *conf) { + c.exclusive = true + } +} + // Opt for configuring database. type Opt func(c *conf) @@ -256,6 +269,9 @@ func openDB(config *conf) (db *sqliteDatabase, err error) { } } freshDB := config.forceFresh + if config.exclusive { + config.connections = 1 + } pool, err := sqlitex.Open(config.uri, flags, config.connections) if err != nil { if config.forceFresh || sqlite.ErrCode(err) != sqlite.SQLITE_CANTOPEN { @@ -270,13 +286,26 @@ func openDB(config *conf) (db *sqliteDatabase, err error) { } db = &sqliteDatabase{pool: pool} defer func() { - // Close the database even in case of a panic. This is important for tests - // that verify incomplete migration. + // If something goes wrong, close the database even in case of a + // panic. This is important for tests that verify incomplete migration. if r := recover(); r != nil { db.Close() panic(r) } }() + // In case of VACUUM INTO based migration, prepareDB may close this database and + // open another one. + actualDB, err := prepareDB(logger, db, config, freshDB) + if err != nil { + db.Close() + return nil, err + } + return actualDB, nil +} + +func prepareDB(logger *zap.Logger, db *sqliteDatabase, config *conf, freshDB bool) (*sqliteDatabase, error) { + var err error + if config.enableLatency { db.latency = newQueryLatency() } @@ -293,6 +322,12 @@ func openDB(config *conf) (db *sqliteDatabase, err error) { } } + if config.exclusive { + if err := db.startExclusive(); err != nil { + return nil, fmt.Errorf("error switching to exclusive mode: %w", err) + } + } + if freshDB && !config.forceMigrations { if err := config.schema.Apply(db); err != nil { return nil, fmt.Errorf("error running schema script: %w", err) @@ -311,7 +346,6 @@ func openDB(config *conf) (db *sqliteDatabase, err error) { if config.checkSchemaDrift { loaded, err := LoadDBSchemaScript(db) if err != nil { - db.Close() return nil, fmt.Errorf("error loading database schema: %w", err) } diff := config.schema.Diff(loaded) @@ -323,7 +357,6 @@ func openDB(config *conf) (db *sqliteDatabase, err error) { zap.String("diff", diff), ) default: - db.Close() return nil, fmt.Errorf("schema drift detected (uri %s):\n%s", config.uri, diff) } } @@ -409,7 +442,8 @@ func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { config.logger.Warn("finalizing migration by moving the temporary DB to the original path", zap.String("fromPath", fromPath), zap.String("toPath", toPath)) - // Try to open the temporary migrated DB before deleting the original one. + // Try to open the temporary migrated DB in exclusive mode before deleting the + // original one. // If the temporary DB is being copied to the original path by another // process, this will fail and the original database will not be deleted. // We don't use the proper database schema here because the temporary DB @@ -418,7 +452,9 @@ func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { WithLogger(config.logger), WithConnections(1), WithTemp(), - WithNoCheckSchemaDrift()) + WithNoCheckSchemaDrift(), + WithExclusive(), + ) if err != nil { return fmt.Errorf("open temporary DB %s: %w", fromPath, err) } @@ -429,15 +465,17 @@ func moveMigratedDB(config *conf, fromPath, toPath string) (err error) { db.Close() return err } - // Open the freshly vacuumed DB to avoid race condition when another process - // also tries to vacuum the temporary DB into the original path after - // we close the temporary DB. + // Open the freshly vacuumed DB in exclusive mode to avoid race condition when + // another process also tries to vacuum the temporary DB into the original path + // after we close the temporary DB. origDB, err := Open("file:"+toPath, WithLogger(config.logger), WithConnections(1), WithMigrationsDisabled(), WithNoCheckSchemaDrift(), - withDisableIncompleteMigrationHandling()) + withDisableIncompleteMigrationHandling(), + WithExclusive(), + ) if err != nil { return fmt.Errorf("open vacuumed DB %s: %w", toPath, err) } @@ -591,6 +629,42 @@ func (db *sqliteDatabase) withTx(ctx context.Context, initstmt string, exec func return tx.Commit() } +func (db *sqliteDatabase) startExclusive() error { + conn := db.getConn(context.Background()) + if conn == nil { + return ErrNoConnection + } + // We don't need to wait for long if the database is busy + conn.SetBusyTimeout(1 * time.Millisecond) + // From SQLite docs: + // When the locking-mode is set to EXCLUSIVE, the database connection + // never releases file-locks. The first time the database is read in + // EXCLUSIVE mode, a shared lock is obtained and held. The first time the + // database is written, an exclusive lock is obtained and held. + if _, err := exec(conn, "PRAGMA locking_mode=EXCLUSIVE", nil, nil); err != nil { + db.pool.Put(conn) + return fmt.Errorf("PRAGMA locking_mode=EXCLUSIVE: %w", err) + } + // We need to perform a transaction to have the database actually locked. + // From SQLite docs, regarding BEGIN EXCLUSIVE / BEGIN IMMEDIATE: + // EXCLUSIVE is similar to IMMEDIATE in that a write transaction is + // started immediately. EXCLUSIVE and IMMEDIATE are the same in WAL mode, + // but in other journaling modes, EXCLUSIVE prevents other database + // connections from reading the database while the transaction is + // underway. + _, err := exec(conn, "BEGIN EXCLUSIVE", nil, nil) + if err != nil { + db.pool.Put(conn) + return fmt.Errorf("error starting the EXCLUSIVE transaction: %w", err) + } + if _, err := exec(conn, "COMMIT", nil, nil); err != nil { + db.pool.Put(conn) + return fmt.Errorf("error committing the EXCLUSIVE transaction: %w", err) + } + db.pool.Put(conn) + return nil +} + // Tx creates deferred sqlite transaction. // // Deferred transactions are not started until the first statement. @@ -676,7 +750,7 @@ func (db *sqliteDatabase) Close() error { return nil } if err := db.pool.Close(); err != nil { - return fmt.Errorf("close pool %w", err) + return fmt.Errorf("close pool: %w", err) } db.closed = true return nil @@ -717,8 +791,6 @@ func (db *sqliteDatabase) vacuumInto(toPath string) error { // The source database is always closed by this function. // Upon success, the migrated database is opened. func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, err error) { - defer db.Close() - dbPath, migratedPath, err := dbMigrationPaths(config.uri) if err != nil { return nil, err @@ -727,7 +799,25 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, fmt.Errorf("cannot migrate database, only file DBs are supported: %s", config.uri) } - // Instead of just copying the source database to the temporary migration DB, use VACUUM INTO. + // Before we start the migration, re-open the database in exclusive mode + // so that no other connections will be able to use it. + // This will fail if another process is already using this database. + if err := db.Close(); err != nil { + return nil, fmt.Errorf("error closing DB: %w", err) + } + + excDB, err := Open("file:"+dbPath, + WithLogger(config.logger), + WithConnections(1), + WithNoCheckSchemaDrift(), + WithExclusive(), + ) + if err != nil { + return nil, fmt.Errorf("error opening the database in exclusive mode: %v", err) + } + defer excDB.Close() + + // instead of just copying the source database to the temporary migration DB, use VACUUM INTO. // This is somewhat slower but achieves two goals: // 1. The lock is held on the source database while it's being copied // 2. If the source database has a lot of free pages for whatever reason, those @@ -735,7 +825,7 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, config.logger.Info("making a temporary copy of the database", zap.String("path", dbPath), zap.String("target", migratedPath)) - if err := db.vacuumInto(migratedPath); err != nil { + if err := excDB.vacuumInto(migratedPath); err != nil { if err := deleteDB(migratedPath); err != nil { config.logger.Error( "incomplete temporary copy of the database couldn't be deleted", @@ -754,6 +844,7 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, WithConnections(1), WithTemp(), WithDatabaseSchema(config.schema), + WithExclusive(), } if !config.checkSchemaDrift { opts = append(opts, WithNoCheckSchemaDrift()) @@ -810,7 +901,7 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, // so that the lock is held. There's a possibility that right after we // close the source database, another process will see the migrated database // and the marker file and will try to open the migrated database. If the - if err := db.Close(); err != nil { + if err := excDB.Close(); err != nil { return nil, fmt.Errorf("close db: %w", err) } @@ -828,9 +919,11 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, err } - // Open the final DB before deleting the source DB, so one of the locks + // Open the final DB in the exclusive mode before deleting the source DB, so one of the locks // is always held. The migrations are already run, so we're disabling them. + origExclusive := config.exclusive config.enableMigrations = false + config.exclusive = true finalDB, err = openDB(config) if err != nil { return nil, fmt.Errorf("open final DB %s: %w", config.uri, err) @@ -847,6 +940,19 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, err } + // If we were not intending to open the database in exclusive mode, + // reopen it in the normal mode + if !origExclusive { + if err := finalDB.Close(); err != nil { + return nil, fmt.Errorf("close final DB: %w", err) + } + config.exclusive = false + finalDB, err = openDB(config) + if err != nil { + return nil, fmt.Errorf("open final DB %s: %w", config.uri, err) + } + } + return finalDB, nil } diff --git a/sql/database_test.go b/sql/database_test.go index f277cebffa..e3d138b3f4 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -224,6 +224,7 @@ func TestDatabaseVacuumState(t *testing.T) { Migrations: MigrationList{migration1}, }), WithForceMigrations(true), + WithConnections(10), ) require.NoError(t, err) execSQL(t, db, "select * from foo", -1) // ensure table exists @@ -577,3 +578,38 @@ func TestSchemaDrift(t *testing.T) { require.Regexp(t, `.*\n.*\+.*CREATE TABLE newtbl \(id int\);`, observedLogs.All()[0].ContextMap()["diff"]) } + +func TestExclusive(t *testing.T) { + for _, tc := range []struct { + name string + optsA []Opt + optsB []Opt + }{ + { + name: "exclusive succeeds, non-exclusive fails", + optsA: []Opt{WithExclusive()}, + }, + { + name: "exclusive succeeds, non-exclusive fails", + optsB: []Opt{WithExclusive()}, + }, + { + name: "first exclusive succeeds, second exclusive fails", + optsA: []Opt{WithExclusive()}, + optsB: []Opt{WithExclusive()}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + dir := t.TempDir() + dbPath := filepath.Join(dir, "test.db") + db, err := Open(dbPath, append([]Opt{WithNoCheckSchemaDrift()}, tc.optsA...)...) + require.NoError(t, err) + _, err = Open(dbPath, append([]Opt{WithNoCheckSchemaDrift()}, tc.optsB...)...) + require.ErrorContains(t, err, "SQLITE_BUSY: database is locked") + _, err = db.Exec("select count(*) from sqlite_master", nil, nil) + require.NoError(t, err) + require.NoError(t, db.Close()) + }) + } +} From 3119a63f5a29f3624be2ac4f597f3c9bf66c5873 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 22 Aug 2024 14:50:17 +0400 Subject: [PATCH 60/62] sql: further simplify error handling --- sql/database.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/sql/database.go b/sql/database.go index 0ccd9733e8..ffb4bf548b 100644 --- a/sql/database.go +++ b/sql/database.go @@ -861,23 +861,17 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, return nil, fmt.Errorf("process temporary DB %s: %w", migratedPath, err) } defer migratedDB.Close() - tempDBReady := false - defer func() { - migratedDB.Close() - if !tempDBReady { - if err := deleteDB(migratedPath); err != nil { - config.logger.Error( - "incomplete temporary copy of the database couldn't be deleted", - zap.String("path", migratedPath), - zap.Error(err), - ) - } - } - }() // Make sure the temporary DB is fully synced to the disk before creating the marker file. // We don't need wal_checkpoint(TRUNCATE) here as we're going to delete the temporary DB. if _, err := migratedDB.Exec("PRAGMA wal_checkpoint(FULL)", nil, nil); err != nil { + if err := deleteDB(migratedPath); err != nil { + config.logger.Error( + "incomplete temporary copy of the database couldn't be deleted", + zap.String("path", migratedPath), + zap.Error(err), + ) + } return nil, fmt.Errorf("checkpoint temporary DB %s: %w", migratedPath, err) } @@ -888,14 +882,20 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, // and PRAGMA synchronous=OFF, it may become corrupt in case of a crash or power // outage, so we avoid trying to open it. if err := createMarkerFile(migratedPath); err != nil { + if err := deleteDB(migratedPath); err != nil { + config.logger.Error( + "incomplete temporary copy of the database couldn't be deleted", + zap.String("path", migratedPath), + zap.Error(err), + ) + } // The errors returned by createMarkerFile are already descriptive enough // so no need to augment them return nil, err } - // The temporary database is complete and should not be deleted + // At this point, the temporary database is complete and should not be deleted // until we copy it to the original database location. - tempDBReady = true // We only close the source database at the end of the migration process // so that the lock is held. There's a possibility that right after we From 11527ab0a9a176aa5d2b38dcdb3a161e32ba9941 Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 22 Aug 2024 16:37:49 +0400 Subject: [PATCH 61/62] Remove unneeded errors.Join --- sql/database.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/database.go b/sql/database.go index ffb4bf548b..f08a3d6a49 100644 --- a/sql/database.go +++ b/sql/database.go @@ -833,7 +833,7 @@ func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, zap.Error(err), ) } - return nil, errors.Join(err) + return nil, err } // Opening the temporary migrated DB runs the actual migrations on it. From 64f4d3886e3d14170c3a05204f4fc96eff9d5aef Mon Sep 17 00:00:00 2001 From: Ivan Shvedunov Date: Thu, 22 Aug 2024 18:06:20 +0400 Subject: [PATCH 62/62] Address comments --- sql/database.go | 15 ++++++--------- sql/schema.go | 8 ++++---- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/sql/database.go b/sql/database.go index f08a3d6a49..a647086f4a 100644 --- a/sql/database.go +++ b/sql/database.go @@ -318,13 +318,13 @@ func prepareDB(logger *zap.Logger, db *sqliteDatabase, config *conf, freshDB boo return nil, fmt.Errorf("PRAGMA journal_mode=OFF: %w", err) } if _, err := db.Exec("PRAGMA synchronous=OFF", nil, nil); err != nil { - return nil, fmt.Errorf("PRAGMA journal_mode=OFF: %w", err) + return nil, fmt.Errorf("PRAGMA synchronous=OFF: %w", err) } } if config.exclusive { if err := db.startExclusive(); err != nil { - return nil, fmt.Errorf("error switching to exclusive mode: %w", err) + return nil, fmt.Errorf("start exclusive: %w", err) } } @@ -520,7 +520,7 @@ func dbMigrationPaths(uri string) (dbPath, migratedPath string, err error) { func handleIncompleteCopyMigration(config *conf) error { dbPath, migratedPath, err := dbMigrationPaths(config.uri) if err != nil { - return err + return fmt.Errorf("getting DB migration paths: %w", err) } if migratedPath == "" { return nil @@ -634,6 +634,7 @@ func (db *sqliteDatabase) startExclusive() error { if conn == nil { return ErrNoConnection } + defer db.pool.Put(conn) // We don't need to wait for long if the database is busy conn.SetBusyTimeout(1 * time.Millisecond) // From SQLite docs: @@ -642,7 +643,6 @@ func (db *sqliteDatabase) startExclusive() error { // EXCLUSIVE mode, a shared lock is obtained and held. The first time the // database is written, an exclusive lock is obtained and held. if _, err := exec(conn, "PRAGMA locking_mode=EXCLUSIVE", nil, nil); err != nil { - db.pool.Put(conn) return fmt.Errorf("PRAGMA locking_mode=EXCLUSIVE: %w", err) } // We need to perform a transaction to have the database actually locked. @@ -654,14 +654,11 @@ func (db *sqliteDatabase) startExclusive() error { // underway. _, err := exec(conn, "BEGIN EXCLUSIVE", nil, nil) if err != nil { - db.pool.Put(conn) return fmt.Errorf("error starting the EXCLUSIVE transaction: %w", err) } if _, err := exec(conn, "COMMIT", nil, nil); err != nil { - db.pool.Put(conn) return fmt.Errorf("error committing the EXCLUSIVE transaction: %w", err) } - db.pool.Put(conn) return nil } @@ -793,7 +790,7 @@ func (db *sqliteDatabase) vacuumInto(toPath string) error { func (db *sqliteDatabase) copyMigrateDB(config *conf) (finalDB *sqliteDatabase, err error) { dbPath, migratedPath, err := dbMigrationPaths(config.uri) if err != nil { - return nil, err + return nil, fmt.Errorf("getting DB migration paths: %w", err) } if migratedPath == "" { return nil, fmt.Errorf("cannot migrate database, only file DBs are supported: %s", config.uri) @@ -1060,7 +1057,7 @@ func (tx *sqliteTx) Release() error { // Exec query. func (tx *sqliteTx) Exec(query string, encoder Encoder, decoder Decoder) (int, error) { if err := tx.db.runInterceptors(query); err != nil { - return 0, err + return 0, fmt.Errorf("running query interceptors: %w", err) } tx.db.queryCount.Add(1) diff --git a/sql/schema.go b/sql/schema.go index 390a580c06..4dc667e78b 100644 --- a/sql/schema.go +++ b/sql/schema.go @@ -201,18 +201,18 @@ func (s *Schema) MigrateTempDB(logger *zap.Logger, db Database, before int) erro logger.Info("syncing temporary database") - // Enable synchronous mode and WAL journal to ensure the database is synced + // Enable WAL journal and synchronous mode to ensure the database is synced if _, err := db.Exec("PRAGMA journal_mode=WAL", nil, nil); err != nil { - return fmt.Errorf("PRAGMA journal_mode=WAL: %w", err) + return fmt.Errorf("setting WAL journal mode: %w", err) } if _, err := db.Exec("PRAGMA synchronous=FULL", nil, nil); err != nil { - return fmt.Errorf("PRAGMA journal_mode=OFF: %w", err) + return fmt.Errorf("setting synchronous mode: %w", err) } // This should trigger file sync if err := s.setVersion(db, v); err != nil { - return err + return fmt.Errorf("setting DB schema version: %w", err) } return nil