From adec895f32227e0b47510947ba72e53315e1c49c Mon Sep 17 00:00:00 2001 From: tanut32039 Date: Tue, 21 Jan 2025 12:42:38 +0700 Subject: [PATCH] [Feature] Unit test (#15) * add-test-1 * add test keys * add tests * add test * fix fail test * fix function * add tunnel relayer test * refac keys.go * fix conflict * update go mod version * fix band types and add testcases * add provider test * fix multiply gas logic * refactor app_test and add chaincmdtest * merge main into unit-test * declare a new viper variable * fix band testcase * fix config test * fix provider test * fix unit-test keys * refactor tunnel_relayer_test.go * fix app test * fix test * fix tunnel relayer test * change func name * remove t.parallel and new viper object --------- Co-authored-by: Tanut Lertwarachai Co-authored-by: nkitlabs --- cmd/chains_test.go | 128 +++ cmd/config_test.go | 18 +- internal/relayertest/constants.go | 3 + .../relayertest/mocks/chain_evm_client.go | 236 +++++ internal/relayertest/mocks/chain_provider.go | 10 +- internal/relayertest/system.go | 1 + .../chain_config_invalid_chain_type.toml | 15 + relayer/app.go | 100 +-- relayer/app_test.go | 850 ++++++++++++++---- relayer/band/client_test.go | 279 +++++- relayer/band/types/signing.go | 4 +- relayer/chains/evm/keys.go | 59 +- relayer/chains/evm/keys_test.go | 358 ++++++++ relayer/chains/evm/provider.go | 28 +- relayer/chains/evm/provider_eip1559_test.go | 333 +++++++ relayer/chains/evm/provider_legacy_test.go | 242 +++++ relayer/chains/evm/provider_test.go | 301 +++++++ relayer/chains/evm/sender.go | 2 +- relayer/chains/evm/sender_test.go | 158 ++++ relayer/chains/evm/types.go | 3 - relayer/chains/evm/utils.go | 16 +- relayer/chains/evm/utils_test.go | 77 ++ relayer/chains/provider.go | 12 +- relayer/config_test.go | 177 ++-- relayer/tunnel_relayer_test.go | 258 ++++++ scripts/mockgen.sh | 1 + 26 files changed, 3265 insertions(+), 404 deletions(-) create mode 100644 cmd/chains_test.go create mode 100644 internal/relayertest/mocks/chain_evm_client.go create mode 100644 internal/relayertest/testdata/chain_config_invalid_chain_type.toml create mode 100644 relayer/chains/evm/keys_test.go create mode 100644 relayer/chains/evm/provider_eip1559_test.go create mode 100644 relayer/chains/evm/provider_legacy_test.go create mode 100644 relayer/chains/evm/provider_test.go create mode 100644 relayer/chains/evm/sender_test.go create mode 100644 relayer/tunnel_relayer_test.go diff --git a/cmd/chains_test.go b/cmd/chains_test.go new file mode 100644 index 0000000..2f65126 --- /dev/null +++ b/cmd/chains_test.go @@ -0,0 +1,128 @@ +package cmd_test + +import ( + "os" + "path" + "regexp" + "testing" + + "github.com/pelletier/go-toml/v2" + "github.com/stretchr/testify/require" + + "github.com/bandprotocol/falcon/internal/relayertest" +) + +func TestChainsListEmpty(t *testing.T) { + sys := relayertest.NewSystem(t) + + res := sys.RunWithInput(t, "config", "init") + require.NoError(t, res.Err) + + res = sys.RunWithInput(t, "chains", "list") + require.Empty(t, res.Stdout.String()) +} + +func TestChainsAdd(t *testing.T) { + sys := relayertest.NewSystem(t) + + res := sys.RunWithInput(t, "config", "init") + require.NoError(t, res.Err) + + chainCfgPath := path.Join(sys.HomeDir, "chain_config.toml") + err := os.WriteFile(chainCfgPath, []byte(relayertest.ChainCfgText), 0o600) + require.NoError(t, err) + + require.FileExists(t, chainCfgPath) + + // Add chain + res = sys.RunWithInput(t, "chains", "add", "testnet", chainCfgPath) + require.Empty(t, res.Stdout.String()) + require.Empty(t, res.Stderr.String()) + + // Add another chain + res = sys.RunWithInput(t, "ch", "a", "testnet2", chainCfgPath) + require.Empty(t, res.Stdout.String()) + require.Empty(t, res.Stderr.String()) + + // Add existing chain + res = sys.RunWithInput(t, "ch", "a", "testnet", chainCfgPath) + require.Empty(t, res.Stdout.String()) + require.Error(t, res.Err, "existing chain name") + + // List chains to check + res = sys.RunWithInput(t, "chains", "list") + require.Regexp(t, regexp.MustCompile(`\d+: ([\w-]+) -> type\((\w+)\)`), res.Stdout.String()) + require.Empty(t, res.Stderr.String()) +} + +func TestChainsDelete(t *testing.T) { + sys := relayertest.NewSystem(t) + + res := sys.RunWithInput(t, "config", "init") + require.NoError(t, res.Err) + + chainCfgPath := path.Join(sys.HomeDir, "chain_config.toml") + err := os.WriteFile(chainCfgPath, []byte(relayertest.ChainCfgText), 0o600) + require.NoError(t, err) + + require.FileExists(t, chainCfgPath) + + // Add chain + res = sys.RunWithInput(t, "chains", "add", "testnet", chainCfgPath) + require.Empty(t, res.Stdout.String()) + require.Empty(t, res.Stderr.String()) + + // Add another chain + res = sys.RunWithInput(t, "chains", "add", "testnet2", chainCfgPath) + require.Empty(t, res.Stdout.String()) + require.Empty(t, res.Stderr.String()) + + // List chains + res = sys.RunWithInput(t, "chains", "list") + require.Regexp(t, regexp.MustCompile(`\d+: ([\w-]+) -> type\((\w+)\)`), res.Stdout.String()) + require.Empty(t, res.Stderr.String()) + + // Delete chain + res = sys.RunWithInput(t, "chains", "delete", "testnet") + require.Empty(t, res.Stdout.String()) + require.Empty(t, res.Stderr.String()) + + res = sys.RunWithInput(t, "ch", "d", "testnet2") + require.Empty(t, res.Stdout.String()) + require.Empty(t, res.Stderr.String()) + + // List chain with shorthand command + res = sys.RunWithInput(t, "ch", "l") + require.Empty(t, res.Stdout.String()) +} + +func TestChainsShow(t *testing.T) { + sys := relayertest.NewSystem(t) + + res := sys.RunWithInput(t, "config", "init") + require.NoError(t, res.Err) + + chainCfgPath := path.Join(sys.HomeDir, "chain_config.toml") + err := os.WriteFile(chainCfgPath, []byte(relayertest.ChainCfgText), 0o600) + require.NoError(t, err) + + require.FileExists(t, chainCfgPath) + + // Add chain + res = sys.RunWithInput(t, "chains", "add", "testnet", chainCfgPath) + require.Empty(t, res.Stdout.String()) + require.Empty(t, res.Stderr.String()) + + // Show chain configuration + res = sys.RunWithInput(t, "chains", "show", "testnet") + + var expectedChainCfg map[string]interface{} + err = toml.Unmarshal(res.Stdout.Bytes(), &expectedChainCfg) + require.NoError(t, err) + + var actualChainCfg map[string]interface{} + err = toml.Unmarshal([]byte(relayertest.ChainCfgText), &actualChainCfg) + require.NoError(t, err) + + require.Equal(t, expectedChainCfg, actualChainCfg) +} diff --git a/cmd/config_test.go b/cmd/config_test.go index d2db276..4e303be 100644 --- a/cmd/config_test.go +++ b/cmd/config_test.go @@ -10,7 +10,7 @@ import ( "github.com/bandprotocol/falcon/internal/relayertest" ) -func TestShowConfigCmd(t *testing.T) { +func TestConfigShow(t *testing.T) { sys := relayertest.NewSystem(t) res := sys.RunWithInput(t, "config", "init") @@ -23,14 +23,14 @@ func TestShowConfigCmd(t *testing.T) { require.Equal(t, relayertest.DefaultCfgText+"\n", actual) } -func TestShowConfigCmdNotInit(t *testing.T) { +func TestConfigShowNotInit(t *testing.T) { sys := relayertest.NewSystem(t) res := sys.RunWithInput(t, "config", "show") require.ErrorContains(t, res.Err, "config does not exist:") } -func TestInitCmdDefault(t *testing.T) { +func TestConfigInitDefault(t *testing.T) { sys := relayertest.NewSystem(t) res := sys.RunWithInput(t, "config", "init") @@ -46,7 +46,7 @@ func TestInitCmdDefault(t *testing.T) { require.Equal(t, relayertest.DefaultCfgText, string(actualBytes)) } -func TestInitCmdWithFileShortFlag(t *testing.T) { +func TestConfigInitWithFileShortFlag(t *testing.T) { sys := relayertest.NewSystem(t) customCfgPath := path.Join(sys.HomeDir, "custom.toml") @@ -66,7 +66,7 @@ func TestInitCmdWithFileShortFlag(t *testing.T) { require.Equal(t, relayertest.CustomCfgText, string(actualBytes)) } -func TestInitCmdWithFileLongFlag(t *testing.T) { +func TestConfigInitWithFileLongFlag(t *testing.T) { sys := relayertest.NewSystem(t) customCfgPath := path.Join(sys.HomeDir, "custom.toml") @@ -85,7 +85,7 @@ func TestInitCmdWithFileLongFlag(t *testing.T) { require.Equal(t, relayertest.CustomCfgText, string(actualBytes)) } -func TestInitCmdWithFileTimeString(t *testing.T) { +func TestConfigInitWithFileTimeString(t *testing.T) { sys := relayertest.NewSystem(t) customCfgPath := path.Join(sys.HomeDir, "custom.toml") @@ -104,7 +104,7 @@ func TestInitCmdWithFileTimeString(t *testing.T) { require.Equal(t, relayertest.CustomCfgText, string(actualBytes)) } -func TestInitCmdInvalidFile(t *testing.T) { +func TestConfigInitInvalidFile(t *testing.T) { sys := relayertest.NewSystem(t) customCfgPath := path.Join(sys.HomeDir, "custom.toml") @@ -115,7 +115,7 @@ func TestInitCmdInvalidFile(t *testing.T) { require.ErrorContains(t, res.Err, "error toml: expected newline") } -func TestInitCmdNoCustomFile(t *testing.T) { +func TestConfigInitNoCustomFile(t *testing.T) { sys := relayertest.NewSystem(t) customCfgPath := path.Join(sys.HomeDir, "custom.toml") @@ -123,7 +123,7 @@ func TestInitCmdNoCustomFile(t *testing.T) { require.ErrorContains(t, res.Err, "no such file or directory") } -func TestInitCmdAlreadyExist(t *testing.T) { +func TestConfigInitAlreadyExist(t *testing.T) { sys := relayertest.NewSystem(t) res := sys.RunWithInput(t, "config", "init") diff --git a/internal/relayertest/constants.go b/internal/relayertest/constants.go index 5c5c368..e89dcb0 100644 --- a/internal/relayertest/constants.go +++ b/internal/relayertest/constants.go @@ -58,3 +58,6 @@ var ChainCfgText string //go:embed testdata/default_with_chain_config.toml var DefaultCfgTextWithChainCfg string + +//go:embed testdata/chain_config_invalid_chain_type.toml +var ChainCfgInvalidChainTypeText string diff --git a/internal/relayertest/mocks/chain_evm_client.go b/internal/relayertest/mocks/chain_evm_client.go new file mode 100644 index 0000000..da26632 --- /dev/null +++ b/internal/relayertest/mocks/chain_evm_client.go @@ -0,0 +1,236 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: relayer/chains/evm/client.go +// +// Generated by this command: +// +// mockgen -source=relayer/chains/evm/client.go -mock_names Client=MockEVMClient -package mocks -destination internal/relayertest/mocks/chain_evm_client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + big "math/big" + reflect "reflect" + time "time" + + ethereum "github.com/ethereum/go-ethereum" + common "github.com/ethereum/go-ethereum/common" + types "github.com/ethereum/go-ethereum/core/types" + gomock "go.uber.org/mock/gomock" +) + +// MockEVMClient is a mock of Client interface. +type MockEVMClient struct { + ctrl *gomock.Controller + recorder *MockEVMClientMockRecorder + isgomock struct{} +} + +// MockEVMClientMockRecorder is the mock recorder for MockEVMClient. +type MockEVMClientMockRecorder struct { + mock *MockEVMClient +} + +// NewMockEVMClient creates a new mock instance. +func NewMockEVMClient(ctrl *gomock.Controller) *MockEVMClient { + mock := &MockEVMClient{ctrl: ctrl} + mock.recorder = &MockEVMClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEVMClient) EXPECT() *MockEVMClientMockRecorder { + return m.recorder +} + +// BroadcastTx mocks base method. +func (m *MockEVMClient) BroadcastTx(ctx context.Context, tx *types.Transaction) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BroadcastTx", ctx, tx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BroadcastTx indicates an expected call of BroadcastTx. +func (mr *MockEVMClientMockRecorder) BroadcastTx(ctx, tx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BroadcastTx", reflect.TypeOf((*MockEVMClient)(nil).BroadcastTx), ctx, tx) +} + +// CheckAndConnect mocks base method. +func (m *MockEVMClient) CheckAndConnect(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CheckAndConnect", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// CheckAndConnect indicates an expected call of CheckAndConnect. +func (mr *MockEVMClientMockRecorder) CheckAndConnect(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckAndConnect", reflect.TypeOf((*MockEVMClient)(nil).CheckAndConnect), ctx) +} + +// Connect mocks base method. +func (m *MockEVMClient) Connect(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connect", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Connect indicates an expected call of Connect. +func (mr *MockEVMClientMockRecorder) Connect(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockEVMClient)(nil).Connect), ctx) +} + +// EstimateBaseFee mocks base method. +func (m *MockEVMClient) EstimateBaseFee(ctx context.Context) (*big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EstimateBaseFee", ctx) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EstimateBaseFee indicates an expected call of EstimateBaseFee. +func (mr *MockEVMClientMockRecorder) EstimateBaseFee(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EstimateBaseFee", reflect.TypeOf((*MockEVMClient)(nil).EstimateBaseFee), ctx) +} + +// EstimateGas mocks base method. +func (m *MockEVMClient) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EstimateGas", ctx, msg) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EstimateGas indicates an expected call of EstimateGas. +func (mr *MockEVMClientMockRecorder) EstimateGas(ctx, msg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EstimateGas", reflect.TypeOf((*MockEVMClient)(nil).EstimateGas), ctx, msg) +} + +// EstimateGasPrice mocks base method. +func (m *MockEVMClient) EstimateGasPrice(ctx context.Context) (*big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EstimateGasPrice", ctx) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EstimateGasPrice indicates an expected call of EstimateGasPrice. +func (mr *MockEVMClientMockRecorder) EstimateGasPrice(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EstimateGasPrice", reflect.TypeOf((*MockEVMClient)(nil).EstimateGasPrice), ctx) +} + +// EstimateGasTipCap mocks base method. +func (m *MockEVMClient) EstimateGasTipCap(ctx context.Context) (*big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "EstimateGasTipCap", ctx) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EstimateGasTipCap indicates an expected call of EstimateGasTipCap. +func (mr *MockEVMClientMockRecorder) EstimateGasTipCap(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EstimateGasTipCap", reflect.TypeOf((*MockEVMClient)(nil).EstimateGasTipCap), ctx) +} + +// GetBalance mocks base method. +func (m *MockEVMClient) GetBalance(ctx context.Context, gethAddr common.Address) (*big.Int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBalance", ctx, gethAddr) + ret0, _ := ret[0].(*big.Int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBalance indicates an expected call of GetBalance. +func (mr *MockEVMClientMockRecorder) GetBalance(ctx, gethAddr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBalance", reflect.TypeOf((*MockEVMClient)(nil).GetBalance), ctx, gethAddr) +} + +// GetBlockHeight mocks base method. +func (m *MockEVMClient) GetBlockHeight(ctx context.Context) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBlockHeight", ctx) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetBlockHeight indicates an expected call of GetBlockHeight. +func (mr *MockEVMClientMockRecorder) GetBlockHeight(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBlockHeight", reflect.TypeOf((*MockEVMClient)(nil).GetBlockHeight), ctx) +} + +// GetTxReceipt mocks base method. +func (m *MockEVMClient) GetTxReceipt(ctx context.Context, txHash string) (*types.Receipt, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTxReceipt", ctx, txHash) + ret0, _ := ret[0].(*types.Receipt) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTxReceipt indicates an expected call of GetTxReceipt. +func (mr *MockEVMClientMockRecorder) GetTxReceipt(ctx, txHash any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTxReceipt", reflect.TypeOf((*MockEVMClient)(nil).GetTxReceipt), ctx, txHash) +} + +// PendingNonceAt mocks base method. +func (m *MockEVMClient) PendingNonceAt(ctx context.Context, address common.Address) (uint64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PendingNonceAt", ctx, address) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PendingNonceAt indicates an expected call of PendingNonceAt. +func (mr *MockEVMClientMockRecorder) PendingNonceAt(ctx, address any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PendingNonceAt", reflect.TypeOf((*MockEVMClient)(nil).PendingNonceAt), ctx, address) +} + +// Query mocks base method. +func (m *MockEVMClient) Query(ctx context.Context, gethAddr common.Address, data []byte) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Query", ctx, gethAddr, data) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockEVMClientMockRecorder) Query(ctx, gethAddr, data any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockEVMClient)(nil).Query), ctx, gethAddr, data) +} + +// StartLivelinessCheck mocks base method. +func (m *MockEVMClient) StartLivelinessCheck(ctx context.Context, interval time.Duration) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartLivelinessCheck", ctx, interval) +} + +// StartLivelinessCheck indicates an expected call of StartLivelinessCheck. +func (mr *MockEVMClientMockRecorder) StartLivelinessCheck(ctx, interval any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartLivelinessCheck", reflect.TypeOf((*MockEVMClient)(nil).StartLivelinessCheck), ctx, interval) +} diff --git a/internal/relayertest/mocks/chain_provider.go b/internal/relayertest/mocks/chain_provider.go index 621b8f2..595e3e2 100644 --- a/internal/relayertest/mocks/chain_provider.go +++ b/internal/relayertest/mocks/chain_provider.go @@ -188,11 +188,12 @@ func (mr *MockChainProviderMockRecorder) RelayPacket(ctx, packet any) *gomock.Ca } // ShowKey mocks base method. -func (m *MockChainProvider) ShowKey(keyName string) string { +func (m *MockChainProvider) ShowKey(keyName string) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ShowKey", keyName) ret0, _ := ret[0].(string) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // ShowKey indicates an expected call of ShowKey. @@ -312,11 +313,12 @@ func (mr *MockKeyProviderMockRecorder) LoadFreeSenders(homePath, passphrase any) } // ShowKey mocks base method. -func (m *MockKeyProvider) ShowKey(keyName string) string { +func (m *MockKeyProvider) ShowKey(keyName string) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ShowKey", keyName) ret0, _ := ret[0].(string) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // ShowKey indicates an expected call of ShowKey. diff --git a/internal/relayertest/system.go b/internal/relayertest/system.go index 02fe682..b75c939 100644 --- a/internal/relayertest/system.go +++ b/internal/relayertest/system.go @@ -55,5 +55,6 @@ func (s *System) RunWithInput(t *testing.T, args ...string) RunResult { rootCmd.SetArgs(args) res.Err = rootCmd.ExecuteContext(ctx) + return res } diff --git a/internal/relayertest/testdata/chain_config_invalid_chain_type.toml b/internal/relayertest/testdata/chain_config_invalid_chain_type.toml new file mode 100644 index 0000000..604fd6f --- /dev/null +++ b/internal/relayertest/testdata/chain_config_invalid_chain_type.toml @@ -0,0 +1,15 @@ +endpoints = ['http://localhost:8545'] +chain_type = 'evm2' +max_retry = 3 +query_timeout = 3000000000 +chain_id = 31337 +tunnel_router_address = '0xDc64a140Aa3E981100a9becA4E685f962f0cF6C9' +private_key = '' +block_confirmation = 5 +waiting_tx_duration = 3000000000 +checking_tx_interval = 1000000000 +gas_type = 'eip1559' +gas_multiplier = 1.1 +execute_timeout = 3000000000 +liveliness_checking_interval = 900000000000 + diff --git a/relayer/app.go b/relayer/app.go index fb1f6a5..1e4e441 100644 --- a/relayer/app.go +++ b/relayer/app.go @@ -22,10 +22,10 @@ import ( ) const ( - configFolderName = "config" - configFileName = "config.toml" - passphraseFileName = "passphrase.hash" - passphraseEnvKey = "PASSPHRASE" + ConfigFolderName = "config" + ConfigFileName = "config.toml" + PassphraseFileName = "passphrase.hash" + PassphraseEnvKey = "PASSPHRASE" ) // App is the main application struct. @@ -35,7 +35,7 @@ type App struct { Debug bool Config *Config - targetChains chains.ChainProviders + TargetChains chains.ChainProviders BandClient band.Client EnvPassphrase string } @@ -127,7 +127,7 @@ func (a *App) initLogger(logLevel, logFormat string) error { // initTargetChains initializes the target chains. func (a *App) initTargetChains() error { - a.targetChains = make(chains.ChainProviders) + a.TargetChains = make(chains.ChainProviders) for chainName, chainConfig := range a.Config.TargetChains { cp, err := chainConfig.NewChainProvider(chainName, a.Log, a.HomePath, a.Debug) @@ -139,7 +139,7 @@ func (a *App) initTargetChains() error { return err } - a.targetChains[chainName] = cp + a.TargetChains[chainName] = cp } return nil @@ -147,7 +147,7 @@ func (a *App) initTargetChains() error { // LoadConfigFile reads config file into a.Config if file is present. func (a *App) LoadConfigFile() error { - cfgPath := path.Join(a.HomePath, configFolderName, configFileName) + cfgPath := path.Join(a.HomePath, ConfigFolderName, ConfigFileName) // check if file doesn't exist, exit the function as the config may not be initialized. if _, err := os.Stat(cfgPath); os.IsNotExist(err) { @@ -170,8 +170,8 @@ func (a *App) LoadConfigFile() error { // InitConfigFile initializes the configuration to the given path. func (a *App) InitConfigFile(homePath string, customFilePath string) error { - cfgDir := path.Join(homePath, configFolderName) - cfgPath := path.Join(cfgDir, configFileName) + cfgDir := path.Join(homePath, ConfigFolderName) + cfgPath := path.Join(cfgDir, ConfigFileName) // check if the config file already exists // https://stackoverflow.com/questions/12518876/how-to-check-if-a-file-exists-in-go @@ -231,8 +231,8 @@ func (a *App) InitPassphrase() error { h.Write([]byte(a.EnvPassphrase)) b := h.Sum(nil) - cfgDir := path.Join(a.HomePath, configFolderName) - passphrasePath := path.Join(cfgDir, passphraseFileName) + cfgDir := path.Join(a.HomePath, ConfigFolderName) + passphrasePath := path.Join(cfgDir, PassphraseFileName) // Create the file and write the hashed passphrase to the given location. f, err := os.Create(passphrasePath) @@ -267,7 +267,7 @@ func (a *App) QueryTunnelInfo(ctx context.Context, tunnelID uint64) (*types.Tunn tunnel.IsActive, ) - cp, ok := a.targetChains[bandChainInfo.TargetChainID] + cp, ok := a.TargetChains[bandChainInfo.TargetChainID] if !ok { a.Log.Debug("Target chain provider not found", zap.String("chain_id", bandChainInfo.TargetChainID)) return types.NewTunnel(bandChainInfo, nil), nil @@ -310,8 +310,8 @@ func (a *App) AddChainConfig(chainName string, filePath string) error { a.Config.TargetChains[chainName] = chainProviderConfig - cfgDir := path.Join(a.HomePath, configFolderName) - cfgPath := path.Join(cfgDir, configFileName) + cfgDir := path.Join(a.HomePath, ConfigFolderName) + cfgPath := path.Join(cfgDir, ConfigFileName) // Marshal config object into bytes b, err := toml.Marshal(a.Config) @@ -334,8 +334,8 @@ func (a *App) DeleteChainConfig(chainName string) error { delete(a.Config.TargetChains, chainName) - cfgDir := path.Join(a.HomePath, configFolderName) - cfgPath := path.Join(cfgDir, configFileName) + cfgDir := path.Join(a.HomePath, ConfigFolderName) + cfgPath := path.Join(cfgDir, ConfigFileName) // Marshal config object into bytes b, err := toml.Marshal(a.Config) @@ -375,20 +375,15 @@ func (a *App) AddKey( return nil, fmt.Errorf("config does not exist: %s", a.HomePath) } - if err := a.validatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { return nil, err } - cp, exist := a.targetChains[chainName] - + cp, exist := a.TargetChains[chainName] if !exist { return nil, fmt.Errorf("chain name does not exist: %s", chainName) } - if cp.IsKeyNameExist(keyName) { - return nil, fmt.Errorf("key name already exists: %s", keyName) - } - keyOutput, err := cp.AddKey(keyName, mnemonic, privateKey, a.HomePath, coinType, account, index, a.EnvPassphrase) if err != nil { return nil, err @@ -403,20 +398,15 @@ func (a *App) DeleteKey(chainName string, keyName string) error { return fmt.Errorf("config does not exist: %s", a.HomePath) } - if err := a.validatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { return err } - cp, exist := a.targetChains[chainName] - + cp, exist := a.TargetChains[chainName] if !exist { return fmt.Errorf("chain name does not exist: %s", chainName) } - if !cp.IsKeyNameExist(keyName) { - return fmt.Errorf("key name does not exist: %s", keyName) - } - return cp.DeleteKey(a.HomePath, keyName, a.EnvPassphrase) } @@ -426,20 +416,15 @@ func (a *App) ExportKey(chainName string, keyName string) (string, error) { return "", fmt.Errorf("config does not exist: %s", a.HomePath) } - if err := a.validatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { return "", err } - cp, exist := a.targetChains[chainName] - + cp, exist := a.TargetChains[chainName] if !exist { return "", fmt.Errorf("chain name does not exist: %s", chainName) } - if !cp.IsKeyNameExist(keyName) { - return "", fmt.Errorf("key name does not exist: %s", chainName) - } - privateKey, err := cp.ExportPrivateKey(keyName, a.EnvPassphrase) if err != nil { return "", err @@ -451,13 +436,12 @@ func (a *App) ExportKey(chainName string, keyName string) (string, error) { // ListKeys retrieves the list of keys from the chain provider. func (a *App) ListKeys(chainName string) ([]*chainstypes.Key, error) { if a.Config == nil { - return make([]*chainstypes.Key, 0), fmt.Errorf("config does not exist: %s", a.HomePath) + return nil, fmt.Errorf("config does not exist: %s", a.HomePath) } - cp, exist := a.targetChains[chainName] - + cp, exist := a.TargetChains[chainName] if !exist { - return make([]*chainstypes.Key, 0), fmt.Errorf("chain name does not exist: %s", chainName) + return nil, fmt.Errorf("chain name does not exist: %s", chainName) } return cp.ListKeys(), nil @@ -469,16 +453,12 @@ func (a *App) ShowKey(chainName string, keyName string) (string, error) { return "", fmt.Errorf("config does not exist: %s", a.HomePath) } - cp, exist := a.targetChains[chainName] + cp, exist := a.TargetChains[chainName] if !exist { return "", fmt.Errorf("chain name does not exist: %s", chainName) } - if !cp.IsKeyNameExist(keyName) { - return "", fmt.Errorf("key name does not exist: %s", keyName) - } - - return cp.ShowKey(keyName), nil + return cp.ShowKey(keyName) } // QueryBalance retrieves the balance of the key from the chain provider. @@ -487,7 +467,7 @@ func (a *App) QueryBalance(ctx context.Context, chainName string, keyName string return nil, fmt.Errorf("config does not exist: %s", a.HomePath) } - cp, exist := a.targetChains[chainName] + cp, exist := a.TargetChains[chainName] if !exist { return nil, fmt.Errorf("chain name does not exist: %s", chainName) @@ -513,20 +493,20 @@ func (a *App) loadEnvPassphrase() string { } else { a.Log.Debug("Loaded .env file successfully, attempting to use variable from .env file") } - return os.Getenv(passphraseEnvKey) + return os.Getenv(PassphraseEnvKey) } -// validatePassphrase checks if the provided passphrase (from the environment) +// ValidatePassphrase checks if the provided passphrase (from the environment) // matches the hashed passphrase stored on disk. -func (a *App) validatePassphrase(envPassphrase string) error { +func (a *App) ValidatePassphrase(envPassphrase string) error { // prepare bytes slices of hashed env passphrase h := sha256.New() h.Write([]byte(envPassphrase)) envb := h.Sum(nil) // load passphrase from local disk - cfgDir := path.Join(a.HomePath, configFolderName) - passphrasePath := path.Join(cfgDir, passphraseFileName) + cfgDir := path.Join(a.HomePath, ConfigFolderName) + passphrasePath := path.Join(cfgDir, PassphraseFileName) b, err := os.ReadFile(passphrasePath) if err != nil { @@ -551,12 +531,12 @@ func (a *App) Start(ctx context.Context, tunnelIDs []uint64) error { } // validate passphrase - if err := a.validatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { return err } // initialize target chain providers - for chainName, chainProvider := range a.targetChains { + for chainName, chainProvider := range a.TargetChains { if err := chainProvider.LoadFreeSenders(a.HomePath, a.EnvPassphrase); err != nil { a.Log.Error("Cannot load keys in target chain", zap.Error(err), @@ -577,7 +557,7 @@ func (a *App) Start(ctx context.Context, tunnelIDs []uint64) error { // initialize the tunnel relayer tunnelRelayers := []*TunnelRelayer{} for _, tunnel := range tunnels { - chainProvider, ok := a.targetChains[tunnel.TargetChainID] + chainProvider, ok := a.TargetChains[tunnel.TargetChainID] if !ok { return fmt.Errorf("target chain provider not found: %s", tunnel.TargetChainID) } @@ -604,7 +584,7 @@ func (a *App) Start(ctx context.Context, tunnelIDs []uint64) error { a.Config.Global.PenaltyExponentialFactor, isSyncTunnelsAllowed, a.BandClient, - a.targetChains, + a.TargetChains, ) return scheduler.Start(ctx) @@ -618,11 +598,11 @@ func (a *App) Relay(ctx context.Context, tunnelID uint64) error { return err } - if err := a.validatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { return err } - chainProvider, ok := a.targetChains[tunnel.TargetChainID] + chainProvider, ok := a.TargetChains[tunnel.TargetChainID] if !ok { return fmt.Errorf("target chain provider not found: %s", tunnel.TargetChainID) } diff --git a/relayer/app_test.go b/relayer/app_test.go index 837dd4e..33d2f3b 100644 --- a/relayer/app_test.go +++ b/relayer/app_test.go @@ -2,6 +2,8 @@ package relayer_test import ( "context" + "crypto/sha256" + "fmt" "os" "path" "testing" @@ -19,6 +21,7 @@ import ( "github.com/bandprotocol/falcon/relayer/band" bandtypes "github.com/bandprotocol/falcon/relayer/band/types" "github.com/bandprotocol/falcon/relayer/chains" + "github.com/bandprotocol/falcon/relayer/chains/evm" chainstypes "github.com/bandprotocol/falcon/relayer/chains/types" "github.com/bandprotocol/falcon/relayer/types" ) @@ -27,7 +30,6 @@ type AppTestSuite struct { suite.Suite app *relayer.App - ctx context.Context chainProviderConfig *mocks.MockChainProviderConfig chainProvider *mocks.MockChainProvider client *mocks.MockClient @@ -37,22 +39,13 @@ type AppTestSuite struct { func (s *AppTestSuite) SetupTest() { tmpDir := s.T().TempDir() ctrl := gomock.NewController(s.T()) - - log, err := zap.NewDevelopment() - s.Require().NoError(err) + log := zap.NewNop() // mock objects. s.chainProviderConfig = mocks.NewMockChainProviderConfig(ctrl) s.chainProvider = mocks.NewMockChainProvider(ctrl) s.client = mocks.NewMockClient(ctrl) - s.chainProviderConfig.EXPECT(). - NewChainProvider("testnet_evm", log, tmpDir, false). - Return(s.chainProvider, nil). - AnyTimes() - - s.chainProvider.EXPECT().Init(gomock.Any()).Return(nil).AnyTimes() - cfg := relayer.Config{ BandChain: band.Config{ RpcEndpoints: []string{"http://localhost:26659"}, @@ -63,12 +56,24 @@ func (s *AppTestSuite) SetupTest() { }, Global: relayer.GlobalConfig{}, } - s.ctx = context.Background() - s.app = relayer.NewApp(log, tmpDir, false, &cfg) + cfgFolder := path.Join(tmpDir, relayer.ConfigFolderName) + err := os.Mkdir(cfgFolder, os.ModePerm) + s.Require().NoError(err) + + s.app = &relayer.App{ + Log: log, + HomePath: tmpDir, + Config: &cfg, + TargetChains: map[string]chains.ChainProvider{ + "testnet_evm": s.chainProvider, + }, + BandClient: s.client, + EnvPassphrase: "secret", + } - err = s.app.Init(s.ctx, "", "") - s.app.BandClient = s.client + // Call InitPassphrase + err = s.app.InitPassphrase() s.Require().NoError(err) } @@ -77,129 +82,348 @@ func TestAppTestSuite(t *testing.T) { } func (s *AppTestSuite) TestInitConfig() { - s.app.Config = nil - customCfgPath := "" - - err := s.app.InitConfigFile(s.app.HomePath, customCfgPath) - s.Require().NoError(err) + testcases := []struct { + name string + preprocess func() + in string + out *relayer.Config + err error + }{ + { + name: "success - default", + in: "", + out: relayer.DefaultConfig(), + }, + { + name: "config already exists", + preprocess: func() { + err := s.app.InitConfigFile(s.app.HomePath, "") + s.Require().NoError(err) + }, + in: "", + err: fmt.Errorf("config already exists:"), + }, + { + name: "init config from specific file", + preprocess: func() { + customCfgPath := path.Join(s.app.HomePath, "custom.toml") + cfg := ` + [target_chains] + + [global] + checking_packet_interval = 60000000000 + + [bandchain] + rpc_endpoints = ['http://localhost:26659'] + timeout = 50 + ` + + err := os.WriteFile(customCfgPath, []byte(cfg), 0o600) + s.Require().NoError(err) + }, + in: path.Join(s.app.HomePath, "custom.toml"), + out: &relayer.Config{ + BandChain: band.Config{ + RpcEndpoints: []string{"http://localhost:26659"}, + Timeout: 50, + }, + TargetChains: map[string]chains.ChainProviderConfig{}, + Global: relayer.GlobalConfig{ + CheckingPacketInterval: time.Minute, + }, + }, + }, + } - cfgPath := path.Join(s.app.HomePath, "config", "config.toml") - s.Require().FileExists(cfgPath) + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preprocess != nil { + tc.preprocess() + } + + err := s.app.InitConfigFile(s.app.HomePath, tc.in) + cfgFolder := path.Join(s.app.HomePath, relayer.ConfigFolderName) + cfgPath := path.Join(cfgFolder, relayer.ConfigFileName) + + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + actualByte, err := os.ReadFile(cfgPath) + s.Require().NoError(err) + + // marshal default config + expect := tc.out + expectBytes, err := toml.Marshal(expect) + s.Require().NoError(err) + + s.Require().Equal(string(expectBytes), string(actualByte)) + } + + // clear config folder + err = os.RemoveAll(cfgFolder) + s.Require().NoError(err) + }) + } +} - // read the file - actualByte, err := os.ReadFile(cfgPath) +func (s *AppTestSuite) TestAddChainConfig() { + newHomePath := path.Join(s.app.HomePath, "new_folder") + err := os.Mkdir(newHomePath, os.ModePerm) s.Require().NoError(err) - // marshal default config - expect := relayer.DefaultConfig() - expectBytes, err := toml.Marshal(expect) - s.Require().NoError(err) + type Input struct { + chainName string + cfgPath string + existingCfg *relayer.Config + } + testcases := []struct { + name string + preprocess func() + in Input + err error + out string + }{ + { + name: "success", + in: Input{ + chainName: "testnet", + cfgPath: path.Join(newHomePath, "chain_config.toml"), + }, + preprocess: func() { + chainCfgPath := path.Join(newHomePath, "chain_config.toml") + err := os.WriteFile(chainCfgPath, []byte(relayertest.ChainCfgText), 0o600) + s.Require().NoError(err) + }, + out: relayertest.DefaultCfgTextWithChainCfg, + }, + { + name: "invalid chain type", + in: Input{ + chainName: "testnet", + cfgPath: path.Join(newHomePath, "chain_config.toml"), + }, + preprocess: func() { + chainCfgPath := path.Join(newHomePath, "chain_config.toml") + err := os.WriteFile(chainCfgPath, []byte(relayertest.ChainCfgInvalidChainTypeText), 0o600) + s.Require().NoError(err) + }, + err: fmt.Errorf("unsupported chain type"), + }, + { + name: "existing chain name", + in: Input{ + chainName: "testnet", + cfgPath: path.Join(newHomePath, "chain_config.toml"), + existingCfg: &relayer.Config{ + TargetChains: map[string]chains.ChainProviderConfig{ + "testnet": &evm.EVMChainProviderConfig{}, + }, + }, + }, + preprocess: func() { + chainCfgPath := path.Join(newHomePath, "chain_config.toml") + err := os.WriteFile(chainCfgPath, []byte(relayertest.ChainCfgText), 0o600) + s.Require().NoError(err) + }, + err: fmt.Errorf("existing chain name :"), + }, + } - s.Require().Equal(string(expectBytes), string(actualByte)) + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preprocess != nil { + tc.preprocess() + } + + // init app + app := relayer.NewApp(nil, newHomePath, false, tc.in.existingCfg) + if app.Config == nil { + err := app.InitConfigFile(newHomePath, "") + s.Require().NoError(err) + s.Require().FileExists(path.Join(newHomePath, "config", "config.toml")) + + err = app.LoadConfigFile() + s.Require().NoError(err) + s.Require().NotNil(app.Config) + } + + err = app.AddChainConfig(tc.in.chainName, tc.in.cfgPath) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + + actualBytes, err := os.ReadFile(path.Join(newHomePath, "config", "config.toml")) + + s.Require().NoError(err) + s.Require().Equal(tc.out, string(actualBytes)) + } + + // clear config folder + cfgFolder := path.Join(newHomePath, relayer.ConfigFolderName) + err = os.RemoveAll(cfgFolder) + s.Require().NoError(err) + }) + } } -func (s *AppTestSuite) TestInitExistingConfig() { - s.app.Config = nil - customCfgPath := "" - - err := s.app.InitConfigFile(s.app.HomePath, customCfgPath) +func (s *AppTestSuite) TestDeleteChainConfig() { + newHomePath := path.Join(s.app.HomePath, "new_folder") + err := os.Mkdir(newHomePath, os.ModePerm) s.Require().NoError(err) - // second time should fail - err = s.app.InitConfigFile(s.app.HomePath, customCfgPath) - s.Require().ErrorContains(err, "config already exists:") -} - -func (s *AppTestSuite) TestInitCustomConfig() { - s.app.Config = nil - customCfgPath := path.Join(s.app.HomePath, "custom.toml") - - // Create custom config file - cfg := ` - [target_chains] - - [global] - checking_packet_interval = 60000000000 - - [bandchain] - rpc_endpoints = ['http://localhost:26659'] - timeout = 50 - ` // write file - err := os.WriteFile(customCfgPath, []byte(cfg), 0o600) - s.Require().NoError(err) - - err = s.app.InitConfigFile(s.app.HomePath, customCfgPath) - s.Require().NoError(err) - - s.Require().FileExists(path.Join(s.app.HomePath, "config", "config.toml")) - - // read the file - b, err := os.ReadFile(path.Join(s.app.HomePath, "config", "config.toml")) - s.Require().NoError(err) - - // unmarshal data - actual := relayer.Config{} - err = toml.Unmarshal(b, &actual) + customCfgPath := path.Join(s.app.HomePath, "custom.toml") + err = os.WriteFile(customCfgPath, []byte(relayertest.DefaultCfgTextWithChainCfg), 0o600) s.Require().NoError(err) - expect := relayer.Config{ - BandChain: band.Config{ - RpcEndpoints: []string{"http://localhost:26659"}, - Timeout: 50, + testcases := []struct { + name string + in string + out string + err error + }{ + { + name: "success", + in: "testnet", + out: relayertest.DefaultCfgText, }, - TargetChains: nil, - Global: relayer.GlobalConfig{ - CheckingPacketInterval: time.Minute, + { + name: "not existing chain name", + in: "testnet2", + err: fmt.Errorf("not existing chain name"), }, } - s.Require().Equal(expect, actual) + for _, tc := range testcases { + s.Run(tc.name, func() { + app := relayer.NewApp(nil, newHomePath, false, nil) + err := app.InitConfigFile(newHomePath, customCfgPath) + s.Require().NoError(err) + + // load config file + err = app.LoadConfigFile() + s.Require().NoError(err) + + err = app.DeleteChainConfig(tc.in) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + + actualBytes, err := os.ReadFile(path.Join(newHomePath, "config", "config.toml")) + s.Require().NoError(err) + s.Require().Equal(tc.out, string(actualBytes)) + } + + // clear config folder + cfgFolder := path.Join(newHomePath, relayer.ConfigFolderName) + err = os.RemoveAll(cfgFolder) + s.Require().NoError(err) + }) + } } -func (s *AppTestSuite) TestQueryTunnelInfo() { - tunnelBandInfo := bandtypes.NewTunnel(1, 1, "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2", "testnet_evm", false) - tunnelChainInfo := chainstypes.NewTunnel(1, "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2", false) - - s.client.EXPECT(). - GetTunnel(s.ctx, uint64(1)). - Return(tunnelBandInfo, nil) - - s.chainProvider.EXPECT(). - QueryTunnelInfo(s.ctx, uint64(1), "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"). - Return(tunnelChainInfo, nil) - - tunnel, err := s.app.QueryTunnelInfo(s.ctx, 1) - bandChainInfo := bandtypes.NewTunnel(1, 1, "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2", "testnet_evm", false) +func (s *AppTestSuite) TestGetChainConfig() { + testcases := []struct { + name string + in string + err error + out chains.ChainProviderConfig + }{ + { + name: "success", + in: "testnet_evm", + out: s.chainProviderConfig, + }, + { + name: "not existing chain name", + in: "testnet_evm2", + err: fmt.Errorf("not existing chain name"), + }, + } - expected := types.NewTunnel( - bandChainInfo, - tunnelChainInfo, - ) - s.Require().NoError(err) - s.Require().Equal(expected, tunnel) + for _, tc := range testcases { + s.Run(tc.name, func() { + actual, err := s.app.GetChainConfig(tc.in) + + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(tc.out, actual) + } + }) + } } -func (s *AppTestSuite) TestQueryTunnelInfoNotSupportedChain() { - s.app.Config.TargetChains = nil - err := s.app.Init(s.ctx, "", "") - - s.Require().NoError(err) - - tunnelBandInfo := bandtypes.NewTunnel(1, 1, "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2", "testnet_evm", false) - s.client.EXPECT(). - GetTunnel(s.ctx, uint64(1)). - Return(tunnelBandInfo, nil) - s.app.BandClient = s.client - - tunnel, err := s.app.QueryTunnelInfo(s.ctx, 1) +func (s *AppTestSuite) TestQueryTunnelInfo() { + mockTunnelBandInfo := bandtypes.NewTunnel(1, 1, "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2", "testnet_evm", false) + mockTunnelBandInfoNoChain := bandtypes.NewTunnel(1, 1, "0xmock", "unknown_chain", false) + mockTunnelChainInfo := chainstypes.NewTunnel(1, "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2", false) + + testcases := []struct { + name string + preprocess func() + in uint64 + out *types.Tunnel + err error + }{ + { + name: "success", + preprocess: func() { + s.client.EXPECT(). + GetTunnel(gomock.Any(), uint64(1)). + Return(mockTunnelBandInfo, nil) + s.chainProvider.EXPECT(). + QueryTunnelInfo(gomock.Any(), uint64(1), "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"). + Return(mockTunnelChainInfo, nil) + }, + in: 1, + out: types.NewTunnel(mockTunnelBandInfo, mockTunnelChainInfo), + }, + { + name: "cannot query chain info", + preprocess: func() { + s.client.EXPECT(). + GetTunnel(gomock.Any(), uint64(1)). + Return(mockTunnelBandInfo, nil) + s.chainProvider.EXPECT(). + QueryTunnelInfo(gomock.Any(), uint64(1), "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"). + Return(nil, fmt.Errorf("cannot connect to chain")) + }, + in: 1, + err: fmt.Errorf("cannot connect to chain"), + }, + { + name: "no chain provider", + preprocess: func() { + s.client.EXPECT(). + GetTunnel(gomock.Any(), uint64(1)). + Return(mockTunnelBandInfoNoChain, nil) + }, + in: 1, + out: types.NewTunnel(mockTunnelBandInfoNoChain, nil), + }, + } - expected := types.NewTunnel( - tunnelBandInfo, - nil, - ) - s.Require().NoError(err) - s.Require().Equal(expected, tunnel) + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preprocess != nil { + tc.preprocess() + } + + tunnel, err := s.app.QueryTunnelInfo(context.Background(), tc.in) + + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(tc.out, tunnel) + } + }) + } } func (s *AppTestSuite) TestQueryTunnelPacketInfo() { @@ -233,11 +457,11 @@ func (s *AppTestSuite) TestQueryTunnelPacketInfo() { // Set up the mock expectation s.client.EXPECT(). - GetTunnelPacket(s.ctx, uint64(1), uint64(1)). + GetTunnelPacket(gomock.Any(), uint64(1), uint64(1)). Return(tunnelPacketBandInfo, nil) // Call the function under test - packet, err := s.app.QueryTunnelPacketInfo(s.ctx, 1, 1) + packet, err := s.app.QueryTunnelPacketInfo(context.Background(), 1, 1) // Create the expected packet structure for comparison expected := bandtypes.NewPacket(1, 1, signalPrices, signingInfo, nil) @@ -247,89 +471,357 @@ func (s *AppTestSuite) TestQueryTunnelPacketInfo() { s.Require().Equal(expected, packet) } -func (s *AppTestSuite) TestAddChainConfig() { - s.app.Config = nil - // create new chain config file - chainCfgPath := path.Join(s.app.HomePath, "chain_config.toml") - chainName := "testnet" - - // write chain config file - err := os.WriteFile(chainCfgPath, []byte(relayertest.ChainCfgText), 0o600) +func (s *AppTestSuite) TestInitPassphrase() { + // reset passphrase file. + err := os.Remove(path.Join(s.app.HomePath, "config", "passphrase.hash")) s.Require().NoError(err) - s.Require().FileExists(chainCfgPath) - - // init chain config file - customCfgPath := "" - err = s.app.InitConfigFile(s.app.HomePath, customCfgPath) - s.Require().NoError(err) - - s.Require().FileExists(path.Join(s.app.HomePath, "config", "config.toml")) - - // load config - err = s.app.LoadConfigFile() + // Call InitPassphrase + err = s.app.InitPassphrase() s.Require().NoError(err) - err = s.app.AddChainConfig(chainName, chainCfgPath) + // Verify the file exists + passphrasePath := path.Join(s.app.HomePath, "config", "passphrase.hash") + _, err = os.Stat(passphrasePath) s.Require().NoError(err) - expectedBytes := []byte(relayertest.DefaultCfgTextWithChainCfg) - actualBytes, err := os.ReadFile(path.Join(s.app.HomePath, "config", "config.toml")) + // Verify file content + hasher := sha256.New() + hasher.Write([]byte(s.app.EnvPassphrase)) + expectedHash := hasher.Sum(nil) + actualContent, err := os.ReadFile(passphrasePath) s.Require().NoError(err) - s.Require().Equal(relayertest.DefaultCfgTextWithChainCfg, string(actualBytes)) - - s.Require().Equal(expectedBytes, actualBytes) + s.Require().Equal(expectedHash, actualContent) } -func (s *AppTestSuite) TestDeleteChainConfig() { - s.app.Config = nil - customCfgPath := path.Join(s.app.HomePath, "custom.toml") +func (s *AppTestSuite) TestAddKey() { + testcases := []struct { + name string + chainName string + keyName string + mnemonic string + privateKey string + coinType uint32 + account uint + index uint + err error + out *chainstypes.Key + preprocess func() + }{ + { + name: "success - private key", + chainName: "testnet_evm", + keyName: "testkey", + privateKey: "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", // anvil + coinType: 60, + out: chainstypes.NewKey("", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266", ""), + preprocess: func() { + s.chainProvider.EXPECT(). + AddKey( + "testkey", + "", + "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", + s.app.HomePath, + uint32(60), + uint(0), + uint(0), + s.app.EnvPassphrase, + ). + Return(chainstypes.NewKey("", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266", ""), nil) + }, + }, + { + name: "error from AddKey", + chainName: "testnet_evm", + keyName: "testkey", + privateKey: "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", // anvil + coinType: 60, + preprocess: func() { + s.chainProvider.EXPECT(). + AddKey( + "testkey", + "", + "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", + s.app.HomePath, + uint32(60), + uint(0), + uint(0), + s.app.EnvPassphrase, + ). + Return(nil, fmt.Errorf("add key error")) + }, + err: fmt.Errorf("add key error"), + }, + { + name: "chain name does not exist", + chainName: "testnet_evm2", + keyName: "testkey", + privateKey: "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", // anvil + coinType: 60, + err: fmt.Errorf("chain name does not exist:"), + }, + } - // write file - err := os.WriteFile(customCfgPath, []byte(relayertest.DefaultCfgTextWithChainCfg), 0o600) - s.Require().NoError(err) + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preprocess != nil { + tc.preprocess() + } + + actual, err := s.app.AddKey( + tc.chainName, + tc.keyName, + tc.mnemonic, + tc.privateKey, + tc.coinType, + tc.account, + tc.index, + ) + + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(tc.out, actual) + } + }) + } +} - err = s.app.InitConfigFile(s.app.HomePath, customCfgPath) - s.Require().NoError(err) +func (s *AppTestSuite) TestDeleteKey() { + testcases := []struct { + name string + chainName string + keyName string + err error + preprocess func() + }{ + { + name: "success", + chainName: "testnet_evm", + keyName: "testkey", + preprocess: func() { + s.chainProvider.EXPECT(). + DeleteKey(s.app.HomePath, "testkey", s.app.EnvPassphrase). + Return(nil) + }, + }, + { + name: "error delete key", + chainName: "testnet_evm", + keyName: "testkey", + preprocess: func() { + s.chainProvider.EXPECT(). + DeleteKey(s.app.HomePath, "testkey", s.app.EnvPassphrase). + Return(fmt.Errorf("delete key error")) + }, + err: fmt.Errorf("delete key error"), + }, + { + name: "chain name does not exist", + chainName: "testnet_evm2", + keyName: "testkey", + err: fmt.Errorf("chain name does not exist:"), + }, + } - // load config file - err = s.app.LoadConfigFile() - s.Require().NoError(err) + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preprocess != nil { + tc.preprocess() + } - // delete chain config by given chain name - chainName := "testnet" - err = s.app.DeleteChainConfig(chainName) - s.Require().NoError(err) + err := s.app.DeleteKey(tc.chainName, tc.keyName) - expectedBytes := []byte(relayertest.DefaultCfgText) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + } + }) + } +} - actualBytes, err := os.ReadFile(path.Join(s.app.HomePath, "config", "config.toml")) - s.Require().NoError(err) +func (s *AppTestSuite) TestExportKey() { + testcases := []struct { + name string + chainName string + keyName string + out string + err error + preprocess func() + }{ + { + name: "success", + chainName: "testnet_evm", + keyName: "testkey", + preprocess: func() { + s.chainProvider.EXPECT(). + ExportPrivateKey("testkey", s.app.EnvPassphrase). + Return("0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", nil) + }, + out: "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", + }, + { + name: "error export private key", + chainName: "testnet_evm", + keyName: "testkey", + preprocess: func() { + s.chainProvider.EXPECT(). + ExportPrivateKey("testkey", s.app.EnvPassphrase). + Return("", fmt.Errorf("export key error")) + }, + err: fmt.Errorf("export key error"), + }, + { + name: "chain name does not exist", + chainName: "testnet_evm2", + keyName: "testkey", + err: fmt.Errorf("chain name does not exist:"), + }, + } - s.Require().Equal(expectedBytes, actualBytes) + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preprocess != nil { + tc.preprocess() + } + + actual, err := s.app.ExportKey(tc.chainName, tc.keyName) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(tc.out, actual) + } + }) + } } -func (s *AppTestSuite) TestGetChainConfig() { - s.app.Config = nil - customCfgPath := path.Join(s.app.HomePath, "custom.toml") - - // write file - err := os.WriteFile(customCfgPath, []byte(relayertest.DefaultCfgTextWithChainCfg), 0o600) - s.Require().NoError(err) +func (s *AppTestSuite) TestListKeys() { + testcases := []struct { + name string + in string + preprocess func() + err error + out []*chainstypes.Key + }{ + { + name: "success", + in: "testnet_evm", + preprocess: func() { + s.chainProvider.EXPECT(). + ListKeys(). + Return([]*chainstypes.Key{ + chainstypes.NewKey("", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266", "testkey1"), + chainstypes.NewKey("", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92267", "testkey2"), + }) + }, + out: []*chainstypes.Key{ + chainstypes.NewKey("", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266", "testkey1"), + chainstypes.NewKey("", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92267", "testkey2"), + }, + }, + { + name: "chain name does not exist", + in: "testnet_evm2", + err: fmt.Errorf("chain name does not exist:"), + }, + } - err = s.app.InitConfigFile(s.app.HomePath, customCfgPath) - s.Require().NoError(err) + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preprocess != nil { + tc.preprocess() + } + + actual, err := s.app.ListKeys(tc.in) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(actual, tc.out) + } + }) + } +} - // load config file - err = s.app.LoadConfigFile() - s.Require().NoError(err) +func (s *AppTestSuite) TestShowKey() { + testcases := []struct { + name string + chainName string + keyName string + preprocess func() + err error + out string + }{ + { + name: "success", + chainName: "testnet_evm", + keyName: "testkey", + preprocess: func() { + s.chainProvider.EXPECT(). + ShowKey("testkey"). + Return("0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92267", nil) + }, + out: "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92267", + }, + { + name: "show key error", + chainName: "testnet_evm", + keyName: "testkey", + preprocess: func() { + s.chainProvider.EXPECT(). + ShowKey("testkey"). + Return("", fmt.Errorf("show key error")) + }, + out: "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92267", + err: fmt.Errorf("show key error"), + }, + { + name: "chain name does not exist", + chainName: "testnet_evm2", + keyName: "testkey", + err: fmt.Errorf("chain name does not exist:"), + }, + } - chainName := "testnet" - actual, err := s.app.GetChainConfig(chainName) - s.Require().NoError(err) + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preprocess != nil { + tc.preprocess() + } + + actual, err := s.app.ShowKey(tc.chainName, tc.keyName) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(actual, tc.out) + } + }) + } +} - expect := relayertest.CustomCfg.TargetChains[chainName] +func (s *AppTestSuite) TestValidatePassphraseInvalidPassphrase() { + testcases := []struct { + name string + envPassphrase string + err error + }{ + {name: "valid", envPassphrase: "secret", err: nil}, + {name: "invalid", envPassphrase: "invalid", err: fmt.Errorf("invalid passphrase")}, + } - s.Require().Equal(expect, actual) + for _, tc := range testcases { + s.Run(tc.name, func() { + err := s.app.ValidatePassphrase(tc.envPassphrase) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + } + }) + } } diff --git a/relayer/band/client_test.go b/relayer/band/client_test.go index 1d9fd1f..3c0ccbe 100644 --- a/relayer/band/client_test.go +++ b/relayer/band/client_test.go @@ -2,6 +2,7 @@ package band_test import ( "context" + "fmt" "testing" "time" @@ -9,6 +10,7 @@ import ( cmbytes "github.com/cometbft/cometbft/libs/bytes" codectypes "github.com/cosmos/cosmos-sdk/codec/types" "github.com/cosmos/cosmos-sdk/types" + querytypes "github.com/cosmos/cosmos-sdk/types/query" "github.com/cosmos/gogoproto/proto" "github.com/stretchr/testify/suite" "go.uber.org/mock/gomock" @@ -24,7 +26,7 @@ import ( bandclienttypes "github.com/bandprotocol/falcon/relayer/band/types" ) -type AppTestSuite struct { +type ClientTestSuite struct { suite.Suite ctx context.Context @@ -34,19 +36,16 @@ type AppTestSuite struct { log *zap.Logger } -func TestAppTestSuite(t *testing.T) { - suite.Run(t, new(AppTestSuite)) +func TestClientTestSuite(t *testing.T) { + suite.Run(t, new(ClientTestSuite)) } // SetupTest sets up the test suite by creating a temporary directory and declare mock objects. -func (s *AppTestSuite) SetupTest() { +func (s *ClientTestSuite) SetupTest() { ctrl := gomock.NewController(s.T()) - log, err := zap.NewDevelopment() - s.Require().NoError(err) - // mock objects. - s.log = log + s.log = zap.NewNop() s.tunnelQueryClient = mocks.NewMockTunnelQueryClient(ctrl) s.bandtssQueryClient = mocks.NewMockBandtssQueryClient(ctrl) s.client = band.NewClient( @@ -57,50 +56,121 @@ func (s *AppTestSuite) SetupTest() { s.ctx = context.Background() } -func (s *AppTestSuite) TestGetTunnel() { - // mock route value - destinationChainID := "eth" - destinationContractAddress := "0xe00F1f85abDB2aF6760759547d450da68CE66Bb1" - r := &tunneltypes.TSSRoute{ - DestinationChainID: destinationChainID, - DestinationContractAddress: destinationContractAddress, - } - var routeI tunneltypes.RouteI = r +// GetMockIBCTunnel returns a mock IBC tunnel. +func (s *ClientTestSuite) GetMockIBCTunnel(tunnelID uint64) (tunneltypes.Tunnel, error) { + ibcRoute := tunneltypes.IBCRoute{ChannelID: "test"} + var routeI tunneltypes.RouteI = &ibcRoute + msg, ok := routeI.(proto.Message) - s.Require().Equal(true, ok) + if !ok { + return tunneltypes.Tunnel{}, fmt.Errorf("cannot convert route to proto.Message") + } - any, err := codectypes.NewAnyWithValue(msg) - s.Require().NoError(err) + routeAny, err := codectypes.NewAnyWithValue(msg) + if err != nil { + return tunneltypes.Tunnel{}, err + } - tunnel := tunneltypes.Tunnel{ - ID: uint64(1), + return tunneltypes.Tunnel{ + ID: tunnelID, Sequence: 100, - Route: any, + Route: routeAny, FeePayer: "cosmos1xyz...", SignalDeviations: []tunneltypes.SignalDeviation{}, Interval: 60, TotalDeposit: types.NewCoins(types.NewCoin("uband", math.NewInt(1000))), IsActive: false, - CreatedAt: time.Now().Unix(), + CreatedAt: 1736145613, Creator: "cosmos1abc...", + }, nil +} + +// GetMockTSSTunnel returns a mock TSS tunnel. +func (s *ClientTestSuite) GetMockTSSTunnel(tunnelID uint64) (tunneltypes.Tunnel, error) { + r := &tunneltypes.TSSRoute{ + DestinationChainID: "eth", + DestinationContractAddress: "0xe00F1f85abDB2aF6760759547d450da68CE66Bb1", } - queryResponse := &tunneltypes.QueryTunnelResponse{ - Tunnel: tunnel, + var routeI tunneltypes.RouteI = r + + msg, ok := routeI.(proto.Message) + if !ok { + return tunneltypes.Tunnel{}, fmt.Errorf("cannot convert route to proto.Message") } - // expect response from bandQueryClient - s.tunnelQueryClient.EXPECT().Tunnel(s.ctx, &tunneltypes.QueryTunnelRequest{ - TunnelId: uint64(1), - }).Return(queryResponse, nil) + routeAny, err := codectypes.NewAnyWithValue(msg) + if err != nil { + return tunneltypes.Tunnel{}, err + } + + return tunneltypes.Tunnel{ + ID: tunnelID, + Sequence: 100, + Route: routeAny, + FeePayer: "cosmos1xyz...", + SignalDeviations: []tunneltypes.SignalDeviation{}, + Interval: 60, + TotalDeposit: types.NewCoins(types.NewCoin("uband", math.NewInt(1000))), + IsActive: false, + CreatedAt: 1736145613, + Creator: "cosmos1abc...", + }, nil +} - expected := bandclienttypes.NewTunnel(1, 100, "0xe00F1f85abDB2aF6760759547d450da68CE66Bb1", "eth", false) +func (s *ClientTestSuite) TestGetTunnel() { + tssTunnel, err := s.GetMockTSSTunnel(1) + s.Require().NoError(err) - actual, err := s.client.GetTunnel(s.ctx, uint64(1)) + ibcTunnel, err := s.GetMockIBCTunnel(2) s.Require().NoError(err) - s.Require().Equal(expected, actual) + + testcases := []struct { + name string + in uint64 + preprocess func(c context.Context) + out *bandclienttypes.Tunnel + err error + }{ + { + name: "success", + in: 1, + out: bandclienttypes.NewTunnel(1, 100, "0xe00F1f85abDB2aF6760759547d450da68CE66Bb1", "eth", false), + preprocess: func(c context.Context) { + s.tunnelQueryClient.EXPECT().Tunnel(s.ctx, &tunneltypes.QueryTunnelRequest{ + TunnelId: uint64(1), + }).Return(&tunneltypes.QueryTunnelResponse{Tunnel: tssTunnel}, nil) + }, + }, + { + name: "unsupported route type", + in: 2, + err: fmt.Errorf("unsupported route type"), + preprocess: func(c context.Context) { + s.tunnelQueryClient.EXPECT().Tunnel(s.ctx, &tunneltypes.QueryTunnelRequest{ + TunnelId: uint64(2), + }).Return(&tunneltypes.QueryTunnelResponse{Tunnel: ibcTunnel}, nil) + }, + }, + } + + for _, tc := range testcases { + s.T().Run(tc.name, func(t *testing.T) { + if tc.preprocess != nil { + tc.preprocess(s.ctx) + } + + actual, err := s.client.GetTunnel(s.ctx, tc.in) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(tc.out, actual) + } + }) + } } -func (s *AppTestSuite) TestGetTunnelPacket() { +func (s *ClientTestSuite) TestGetTSSTunnelPacket() { // mock query response pc := &tunneltypes.TSSPacketReceipt{ SigningID: 2, @@ -142,6 +212,7 @@ func (s *AppTestSuite) TestGetTunnelPacket() { CurrentGroupSigningResult: signingResult, IncomingGroupSigningResult: nil, } + // expect response from bandQueryClient s.tunnelQueryClient.EXPECT().Packet(s.ctx, &tunneltypes.QueryPacketRequest{ TunnelId: uint64(1), @@ -181,3 +252,143 @@ func (s *AppTestSuite) TestGetTunnelPacket() { s.Require().NoError(err) s.Require().Equal(expected, actual) } + +func (s *ClientTestSuite) TestGetOtherTunnelPacket() { + // mock query response + pc := &tunneltypes.IBCPacketReceipt{ + Sequence: 2, + } + + var packetReceiptI tunneltypes.PacketReceiptI = pc + msg, ok := packetReceiptI.(proto.Message) + s.Require().Equal(true, ok) + + any, err := codectypes.NewAnyWithValue(msg) + s.Require().NoError(err) + + packet := tunneltypes.Packet{ + TunnelID: 1, + Sequence: 100, + Prices: []feedstypes.Price{ + {SignalID: "signal1", Price: 100}, + {SignalID: "signal2", Price: 200}, + }, + Receipt: any, + CreatedAt: time.Now().Unix(), + } + + queryPacketResponse := &tunneltypes.QueryPacketResponse{ + Packet: &packet, + } + + // expect response from bandQueryClient + s.tunnelQueryClient.EXPECT().Packet(s.ctx, &tunneltypes.QueryPacketRequest{ + TunnelId: uint64(1), + Sequence: uint64(100), + }).Return(queryPacketResponse, nil) + + // actual result + _, err = s.client.GetTunnelPacket(s.ctx, uint64(1), uint64(100)) + s.Require().ErrorContains(err, "unsupported packet content type") +} + +func (s *ClientTestSuite) TestGetTunnels() { + // mock tunnels result + tssTunnels := make([]*tunneltypes.Tunnel, 0, 120) + for i := 1; i <= cap(tssTunnels); i++ { + tunnel, err := s.GetMockTSSTunnel(uint64(i)) + s.Require().NoError(err) + + tssTunnels = append(tssTunnels, &tunnel) + } + + // expected result from tssTunnels + expectedRes := make([]bandclienttypes.Tunnel, 0, len(tssTunnels)) + for _, tunnel := range tssTunnels { + routeI, err := tunnel.GetRouteValue() + s.Require().NoError(err) + + tssRoute, ok := routeI.(*tunneltypes.TSSRoute) + s.Require().True(ok) + + expectedRes = append(expectedRes, *bandclienttypes.NewTunnel( + tunnel.ID, + tunnel.Sequence, + tssRoute.DestinationContractAddress, + tssRoute.DestinationChainID, + tunnel.IsActive, + )) + } + + // create mock ibc tunnel + ibcTunnel, err := s.GetMockIBCTunnel(uint64(121)) + s.Require().NoError(err) + + testcases := []struct { + name string + preprocess func(c context.Context) + out []bandclienttypes.Tunnel + err error + }{ + { + name: "success", + preprocess: func(c context.Context) { + // expect response from bandQueryClient + s.tunnelQueryClient.EXPECT().Tunnels(s.ctx, &tunneltypes.QueryTunnelsRequest{ + Pagination: &querytypes.PageRequest{ + Key: nil, + }, + }).Return(&tunneltypes.QueryTunnelsResponse{ + Tunnels: tssTunnels[:100], + Pagination: &querytypes.PageResponse{ + NextKey: []byte("next-key"), + }, + }, nil) + + s.tunnelQueryClient.EXPECT().Tunnels(s.ctx, &tunneltypes.QueryTunnelsRequest{ + Pagination: &querytypes.PageRequest{ + Key: []byte("next-key"), + }, + }).Return(&tunneltypes.QueryTunnelsResponse{ + Tunnels: tssTunnels[100:], + Pagination: &querytypes.PageResponse{ + NextKey: []byte(""), + }, + }, nil) + }, + out: expectedRes, + }, + { + name: "filter out unrelated tunnel", + preprocess: func(c context.Context) { + s.tunnelQueryClient.EXPECT().Tunnels(s.ctx, &tunneltypes.QueryTunnelsRequest{ + Pagination: &querytypes.PageRequest{ + Key: nil, + }, + }).Return(&tunneltypes.QueryTunnelsResponse{ + Tunnels: []*tunneltypes.Tunnel{tssTunnels[0], &ibcTunnel}, + Pagination: &querytypes.PageResponse{ + NextKey: []byte(""), + }, + }, nil) + }, + out: []bandclienttypes.Tunnel{expectedRes[0]}, + }, + } + + for _, tc := range testcases { + s.T().Run(tc.name, func(t *testing.T) { + if tc.preprocess != nil { + tc.preprocess(s.ctx) + } + + actual, err := s.client.GetTunnels(s.ctx) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(tc.out, actual) + } + }) + } +} diff --git a/relayer/band/types/signing.go b/relayer/band/types/signing.go index 1074ecf..0163157 100644 --- a/relayer/band/types/signing.go +++ b/relayer/band/types/signing.go @@ -26,12 +26,12 @@ func NewEVMSignature( // Signing contains information of a requested message and group signature. type Signing struct { ID uint64 `json:"id"` - Message cmbytes.HexBytes `json:"messsage"` + Message cmbytes.HexBytes `json:"message"` EVMSignature *EVMSignature `json:"evm_signature"` Status string `json:"signing_status"` } -// ConvertSigning converts tsstypes.SigningResult and return .Signing +// ConvertSigning converts tsstypes.SigningResult and return Signing type. func ConvertSigning(res *tsstypes.SigningResult) *Signing { if res == nil { return nil diff --git a/relayer/chains/evm/keys.go b/relayer/chains/evm/keys.go index 06d82a8..70bd672 100644 --- a/relayer/chains/evm/keys.go +++ b/relayer/chains/evm/keys.go @@ -42,9 +42,22 @@ func (cp *EVMChainProvider) AddKey( index uint, passphrase string, ) (*chainstypes.Key, error) { + if cp.IsKeyNameExist(keyName) { + return nil, fmt.Errorf("duplicate key name") + } + if privateKey != "" { return cp.AddKeyWithPrivateKey(keyName, privateKey, homePath, passphrase) } + + var err error + // Generate mnemonic if not provided + if mnemonic == "" { + mnemonic, err = hdwallet.NewMnemonic(mnemonicSize) + if err != nil { + return nil, err + } + } return cp.AddKeyWithMnemonic(keyName, mnemonic, homePath, coinType, account, index, passphrase) } @@ -58,16 +71,6 @@ func (cp *EVMChainProvider) AddKeyWithMnemonic( index uint, passphrase string, ) (*chainstypes.Key, error) { - var err error - - // Generate mnemonic if not provided - if mnemonic == "" { - mnemonic, err = hdwallet.NewMnemonic(mnemonicSize) - if err != nil { - return nil, err - } - } - // Generate private key using mnemonic priv, err := cp.generatePrivateKey(mnemonic, coinType, account, index) if err != nil { @@ -85,7 +88,7 @@ func (cp *EVMChainProvider) AddKeyWithPrivateKey( passphrase string, ) (*chainstypes.Key, error) { // Convert private key from hex - priv, err := crypto.HexToECDSA(ConvertPrivateKeyStrToHex(privateKey)) + priv, err := crypto.HexToECDSA(StripPrivateKeyPrefix(privateKey)) if err != nil { return nil, err } @@ -125,14 +128,12 @@ func (cp *EVMChainProvider) finalizeKeyAddition( return chainstypes.NewKey(mnemonic, addressHex, ""), nil } -// IsKeyNameExist checks whether the given key name is already in use. -func (cp *EVMChainProvider) IsKeyNameExist(keyName string) bool { - _, ok := cp.KeyInfo[keyName] - return ok -} - // DeleteKey deletes the given key name from the key store and removes its information. func (cp *EVMChainProvider) DeleteKey(homePath, keyName, passphrase string) error { + if !cp.IsKeyNameExist(keyName) { + return fmt.Errorf("key name does not exist: %s", keyName) + } + address, err := HexToAddress(cp.KeyInfo[keyName]) if err != nil { return err @@ -148,7 +149,11 @@ func (cp *EVMChainProvider) DeleteKey(homePath, keyName, passphrase string) erro // ExportPrivateKey exports private key of given key name. func (cp *EVMChainProvider) ExportPrivateKey(keyName, passphrase string) (string, error) { - key, err := cp.getKeyFromKeyName(keyName, passphrase) + if !cp.IsKeyNameExist(keyName) { + return "", fmt.Errorf("key name does not exist: %s", keyName) + } + + key, err := cp.GetKeyFromKeyName(keyName, passphrase) if err != nil { return "", err } @@ -166,8 +171,18 @@ func (cp *EVMChainProvider) ListKeys() []*chainstypes.Key { } // ShowKey shows key by the given name. -func (cp *EVMChainProvider) ShowKey(keyName string) string { - return cp.KeyInfo[keyName] +func (cp *EVMChainProvider) ShowKey(keyName string) (string, error) { + if !cp.IsKeyNameExist(keyName) { + return "", fmt.Errorf("key name does not exist: %s", keyName) + } + + return cp.KeyInfo[keyName], nil +} + +// IsKeyNameExist checks whether the given key name is already in use. +func (cp *EVMChainProvider) IsKeyNameExist(keyName string) bool { + _, ok := cp.KeyInfo[keyName] + return ok } // storePrivateKey stores private key to keyStore. @@ -236,9 +251,7 @@ func (cp *EVMChainProvider) generatePrivateKey( return privatekey, nil } -func (cp *EVMChainProvider) getKeyFromKeyName( - keyName, passphrase string, -) (*keyStore.Key, error) { +func (cp *EVMChainProvider) GetKeyFromKeyName(keyName, passphrase string) (*keyStore.Key, error) { address, err := HexToAddress(cp.KeyInfo[keyName]) if err != nil { return nil, err diff --git a/relayer/chains/evm/keys_test.go b/relayer/chains/evm/keys_test.go new file mode 100644 index 0000000..a514c23 --- /dev/null +++ b/relayer/chains/evm/keys_test.go @@ -0,0 +1,358 @@ +package evm_test + +import ( + "encoding/hex" + "fmt" + "testing" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + "github.com/bandprotocol/falcon/relayer/chains/evm" + chaintypes "github.com/bandprotocol/falcon/relayer/chains/types" +) + +const ( + testPrivateKey = "0x72d4772a70645a5a5ec3fdc27afda98d2860a6f7903bff5fd45c0a23d7982121" + testAddress = "0x990Ec0f6dFc9e8eE20dec3Ab855D03007A9dD946" + testMnemonic = "repeat sugar clarify visa chief soon walnut kangaroo rude parrot height piano spoil desk basket swim income catalog more plunge supreme above later worry" +) + +type KeysTestSuite struct { + suite.Suite + + chainProvider *evm.EVMChainProvider + log *zap.Logger + homePath string +} + +func TestKeysTestSuite(t *testing.T) { + suite.Run(t, new(KeysTestSuite)) +} + +// SetupTest sets up the test suite by creating a temporary directory and declare mock objects. +func (s *KeysTestSuite) SetupTest() { + s.homePath = s.T().TempDir() + s.log = zap.NewNop() + + chainName := "testnet" + client := evm.NewClient(chainName, evmCfg, s.log) + + chainProvider, err := evm.NewEVMChainProvider(chainName, client, evmCfg, s.log, s.homePath) + s.Require().NoError(err) + s.chainProvider = chainProvider +} + +func (s *KeysTestSuite) TestAddKeyByPrivateKey() { + type Input struct { + keyName string + privKey string + } + testcases := []struct { + name string + input Input + err error + out *chaintypes.Key + }{ + { + name: "success", + input: Input{ + keyName: "testkey", + privKey: testPrivateKey, + }, + out: chaintypes.NewKey("", testAddress, ""), + }, + { + name: "invalid private key", + input: Input{ + keyName: "testkey2", + privKey: "x72d4772a70645a5a5ec3fdc27afda98d2860a6f7903bff5fd45c0a23d7982121", + }, + err: fmt.Errorf("invalid hex character"), + }, + { + name: "duplicate private key", + input: Input{ + keyName: "testkey3", + privKey: testPrivateKey, + }, + err: fmt.Errorf("account already exists"), + }, + } + + for _, tc := range testcases { + s.T().Run(tc.name, func(t *testing.T) { + key, err := s.chainProvider.AddKey( + tc.input.keyName, + "", + tc.input.privKey, + s.homePath, + 0, + 0, + 0, + "", + ) + + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(tc.out, key) + + // check that key info actually stored in local disk + keyInfo, err := evm.LoadKeyInfo(s.homePath, s.chainProvider.ChainName) + s.Require().NoError(err) + + _, exist := keyInfo[tc.input.keyName] + s.Require().True(exist) + } + }) + } +} + +func (s *KeysTestSuite) TestAddKeyByMnemonic() { + type Input struct { + keyName string + mnemonic string + coinType uint32 + account uint + index uint + } + testcases := []struct { + name string + input Input + err error + out *chaintypes.Key + }{ + { + name: "success", + input: Input{ + keyName: "testkey", + mnemonic: testMnemonic, + coinType: 60, + account: 0, + index: 0, + }, + out: chaintypes.NewKey(testMnemonic, testAddress, ""), + }, + { + name: "success with different index", + input: Input{ + keyName: "testkey2", + mnemonic: testMnemonic, + coinType: 60, + account: 0, + index: 1, + }, + out: chaintypes.NewKey(testMnemonic, "0x01AF9badF97c97C9444E0b7fa94b69b8CB3C28e7", ""), + }, + { + name: "success with no mnemonic", + input: Input{ + keyName: "testkey3", + mnemonic: "", + coinType: 60, + account: 0, + index: 0, + }, + }, + { + name: "duplicate key name", + input: Input{ + keyName: "testkey", + mnemonic: "", + coinType: 60, + account: 0, + index: 0, + }, + err: fmt.Errorf("duplicate key name"), + }, + { + name: "invalid mnemonic", + input: Input{ + keyName: "testkey4", + mnemonic: "mnemonic", + coinType: 60, + account: 0, + index: 0, + }, + err: fmt.Errorf("mnemonic is invalid"), + }, + } + + for _, tc := range testcases { + s.T().Run(tc.name, func(t *testing.T) { + key, err := s.chainProvider.AddKey( + tc.input.keyName, + tc.input.mnemonic, + "", + s.homePath, + tc.input.coinType, + tc.input.account, + tc.input.index, + "", + ) + + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + + if tc.out != nil { + s.Require().Equal(tc.out, key) + } + + // check that key info actually stored in local disk + keyInfo, err := evm.LoadKeyInfo(s.homePath, s.chainProvider.ChainName) + s.Require().NoError(err) + + _, exist := keyInfo[tc.input.keyName] + s.Require().True(exist) + } + }) + } +} + +func (s *KeysTestSuite) TestDeleteKey() { + keyName := "deletablekey" + privatekeyHex := testPrivateKey + + // Add a key to delete + _, err := s.chainProvider.AddKeyWithPrivateKey(keyName, privatekeyHex, s.homePath, "") + s.Require().NoError(err) + + // Delete the key + err = s.chainProvider.DeleteKey(s.homePath, keyName, "") + s.Require().NoError(err) + + // Ensure the key is no longer in the KeyInfo or KeyStore + s.Require().False(s.chainProvider.IsKeyNameExist(keyName)) + + addr, err := evm.HexToAddress(testAddress) + s.Require().NoError(err) + s.Require().False(s.chainProvider.KeyStore.HasAddress(addr)) + + // Delete the key again should return error + err = s.chainProvider.DeleteKey(s.homePath, keyName, "") + s.Require().ErrorContains(err, "key name does not exist") +} + +func (s *KeysTestSuite) TestExportPrivateKey() { + keyName := "exportkey" + privatekeyHex := testPrivateKey + + // Add a key to export + _, err := s.chainProvider.AddKeyWithPrivateKey(keyName, privatekeyHex, s.homePath, "") + s.Require().NoError(err) + + // Export the private key + exportedKey, err := s.chainProvider.ExportPrivateKey(keyName, "") + s.Require().NoError(err) + + s.Require().Equal(evm.StripPrivateKeyPrefix(privatekeyHex), evm.StripPrivateKeyPrefix(exportedKey)) +} + +func (s *KeysTestSuite) TestListKeys() { + // Add multiple keys + keyName1 := "key1" + keyName2 := "key2" + mnemonic := "" + privateKey := "" + coinType := 60 + account := 0 + index := 0 + passphrase := "" + + key1, err := s.chainProvider.AddKey( + keyName1, + mnemonic, + privateKey, + s.homePath, + uint32(coinType), + uint(account), + uint(index), + passphrase, + ) + s.Require().NoError(err) + + key2, err := s.chainProvider.AddKey( + keyName2, + mnemonic, + privateKey, + s.homePath, + uint32(coinType), + uint(account), + uint(index), + passphrase, + ) + s.Require().NoError(err) + + // List all keys + actual := s.chainProvider.ListKeys() + s.Require().Equal(2, len(actual)) + + expected1 := chaintypes.NewKey("", key1.Address, keyName1) + expected2 := chaintypes.NewKey("", key2.Address, keyName2) + + // Check if expected1 and expected2 are in actual + foundExpected1 := false + foundExpected2 := false + + for _, key := range actual { + if key.Address == expected1.Address { + foundExpected1 = true + } + if key.Address == expected2.Address { + foundExpected2 = true + } + } + + s.Require().True(foundExpected1) + s.Require().True(foundExpected2) +} + +func (s *KeysTestSuite) TestShowKey() { + keyName := "showkey" + privatekeyHex := testPrivateKey + + // Add a key to show + _, err := s.chainProvider.AddKeyWithPrivateKey(keyName, privatekeyHex, s.homePath, "") + s.Require().NoError(err) + + // Show the key + address, err := s.chainProvider.ShowKey(keyName) + s.Require().Equal(address, address) + s.Require().NoError(err) +} + +func (s *KeysTestSuite) TestIsKeyNameExist() { + s.chainProvider.KeyInfo["testkey1"] = testAddress + expected := s.chainProvider.IsKeyNameExist("testkey1") + + s.Require().Equal(expected, true) + + expected = s.chainProvider.IsKeyNameExist("testkey2") + s.Require().Equal(expected, false) +} + +func (s *KeysTestSuite) TestGetKeyFromKeyName() { + keyName := "testkeyname" + privatekeyHex := testPrivateKey + + // Add a key to test retrieval + _, err := s.chainProvider.AddKeyWithPrivateKey(keyName, privatekeyHex, s.homePath, "") + s.Require().NoError(err) + + // Retrieve the key using the key name + key, err := s.chainProvider.GetKeyFromKeyName(keyName, "") + s.Require().NoError(err) + s.Require().NotNil(key) + + // Verify that the retrieved private key matches the original private key + s.Require().Equal(testPrivateKey[2:], hex.EncodeToString(crypto.FromECDSA(key.PrivateKey))) // Remove "0x" + + // Retrieve the key using the invalid passphrase should return error + _, err = s.chainProvider.GetKeyFromKeyName(keyName, "invalid") + s.Require().ErrorContains(err, "could not decrypt key with given password") +} diff --git a/relayer/chains/evm/provider.go b/relayer/chains/evm/provider.go index d2730c0..9b4f1e3 100644 --- a/relayer/chains/evm/provider.go +++ b/relayer/chains/evm/provider.go @@ -202,7 +202,7 @@ func (cp *EVMChainProvider) RelayPacket(ctx context.Context, packet *bandtypes.P var txStatus TxStatus checkTxLogic: for time.Since(createdAt) < cp.Config.WaitingTxDuration { - result, err := cp.checkConfirmedTx(ctx, txHash) + result, err := cp.CheckConfirmedTx(ctx, txHash) if err != nil { log.Debug( "Failed to check tx status", @@ -228,7 +228,12 @@ func (cp *EVMChainProvider) RelayPacket(ctx context.Context, packet *bandtypes.P ) return nil case TX_STATUS_FAILED: - retryCount += 1 + log.Debug( + "Transaction failed during relay attempt", + zap.Error(err), + zap.String("tx_hash", txHash), + zap.Int("retry_count", retryCount), + ) break checkTxLogic case TX_STATUS_UNMINED: log.Debug( @@ -269,12 +274,12 @@ func (cp *EVMChainProvider) createAndSignRelayTx( sender *Sender, gasInfo GasInfo, ) (*gethtypes.Transaction, error) { - calldata, err := cp.createCalldata(packet) + calldata, err := cp.CreateCalldata(packet) if err != nil { return nil, fmt.Errorf("failed to create calldata: %w", err) } - tx, err := cp.newRelayTx(ctx, calldata, sender.Address, gasInfo) + tx, err := cp.NewRelayTx(ctx, calldata, sender.Address, gasInfo) if err != nil { return nil, fmt.Errorf("failed to create an evm transaction: %w", err) } @@ -287,8 +292,8 @@ func (cp *EVMChainProvider) createAndSignRelayTx( return signedTx, nil } -// checkConfirmedTx checks the confirmed transaction status. -func (cp *EVMChainProvider) checkConfirmedTx( +// CheckConfirmedTx checks the confirmed transaction status. +func (cp *EVMChainProvider) CheckConfirmedTx( ctx context.Context, txHash string, ) (*ConfirmTxResult, error) { @@ -296,7 +301,6 @@ func (cp *EVMChainProvider) checkConfirmedTx( txHash, TX_STATUS_UNMINED, decimal.NullDecimal{}, - cp.GasType, ) receipt, err := cp.Client.GetTxReceipt(ctx, txHash) @@ -320,7 +324,7 @@ func (cp *EVMChainProvider) checkConfirmedTx( // calculate gas used and effective gas price gasUsed := decimal.NewNullDecimal(decimal.New(int64(receipt.GasUsed), 0)) - return NewConfirmTxResult(txHash, TX_STATUS_SUCCESS, gasUsed, cp.GasType), nil + return NewConfirmTxResult(txHash, TX_STATUS_SUCCESS, gasUsed), nil } // EstimateGasFee estimates the gas for the transaction. @@ -430,8 +434,8 @@ func (cp *EVMChainProvider) queryTunnelInfo( return &output.Info, nil } -// newRelayTx creates a new relay transaction. -func (cp *EVMChainProvider) newRelayTx( +// NewRelayTx creates a new relay transaction. +func (cp *EVMChainProvider) NewRelayTx( ctx context.Context, data []byte, sender gethcommon.Address, @@ -492,8 +496,8 @@ func (cp *EVMChainProvider) newRelayTx( return tx, nil } -// createCalldata creates the calldata for the relay transaction. -func (cp *EVMChainProvider) createCalldata(packet *bandtypes.Packet) ([]byte, error) { +// CreateCalldata creates the calldata for the relay transaction. +func (cp *EVMChainProvider) CreateCalldata(packet *bandtypes.Packet) ([]byte, error) { var signing *bandtypes.Signing // get signing from packet; prefer to use signing from diff --git a/relayer/chains/evm/provider_eip1559_test.go b/relayer/chains/evm/provider_eip1559_test.go new file mode 100644 index 0000000..7abf9de --- /dev/null +++ b/relayer/chains/evm/provider_eip1559_test.go @@ -0,0 +1,333 @@ +package evm_test + +import ( + "context" + "encoding/hex" + "fmt" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum" + gethtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + "go.uber.org/zap" + + "github.com/bandprotocol/falcon/internal/relayertest/mocks" + bandtypes "github.com/bandprotocol/falcon/relayer/band/types" + "github.com/bandprotocol/falcon/relayer/chains/evm" +) + +type EIP1559ProviderTestSuite struct { + suite.Suite + ctrl *gomock.Controller + + chainProvider *evm.EVMChainProvider + client *mocks.MockEVMClient + homePath string + chainName string + + relayingPacket bandtypes.Packet + relayingCalldata []byte + gasInfo evm.GasInfo + mockSender evm.Sender +} + +func TestEIP1559ProviderTestSuite(t *testing.T) { + suite.Run(t, new(EIP1559ProviderTestSuite)) +} + +func (s *EIP1559ProviderTestSuite) SetupTest() { + s.ctrl = gomock.NewController(s.T()) + s.client = mocks.NewMockEVMClient(s.ctrl) + + evmConfig := *baseEVMCfg + evmConfig.GasType = evm.GasTypeEIP1559 + s.chainName = "testnet" + s.homePath = s.T().TempDir() + chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), s.homePath) + s.Require().NoError(err) + + s.chainProvider = chainProvider + s.chainProvider.FreeSenders = make(chan *evm.Sender, 1) + s.chainProvider.FreeSenders <- &s.mockSender + + s.mockSender, err = mockSender() + s.Require().NoError(err) + + s.relayingPacket = mockPacket() + s.relayingCalldata, err = s.chainProvider.CreateCalldata(&s.relayingPacket) + s.Require().NoError(err) + + s.gasInfo = evm.NewGasEIP1559Info(big.NewInt(10_000_000_000), big.NewInt(8_000_000_000)) +} + +func (s *EIP1559ProviderTestSuite) MockDefaultResponses() { + gasInfoCalldata, err := hex.DecodeString("658612e9") + s.Require().NoError(err) + gasInfoResponse, err := hex.DecodeString(uint256ToHex(big.NewInt(12_000_000_000))) + s.Require().NoError(err) + + mockCtx := gomock.Any() + s.client.EXPECT().CheckAndConnect(mockCtx).Return(nil).AnyTimes() + s.client.EXPECT().EstimateGasTipCap(mockCtx).Return(s.gasInfo.GasPriorityFee, nil).AnyTimes() + s.client.EXPECT().EstimateBaseFee(mockCtx).Return(s.gasInfo.GasBaseFee, nil).AnyTimes() + s.client.EXPECT().PendingNonceAt(mockCtx, s.mockSender.Address).Return(uint64(100), nil).AnyTimes() + s.client.EXPECT(). + Query(mockCtx, s.chainProvider.TunnelRouterAddress, gasInfoCalldata). + Return(gasInfoResponse, nil). + AnyTimes() +} + +func (s *EIP1559ProviderTestSuite) TestRelayPacketSuccess() { + // mock client responses + s.client.EXPECT().EstimateGas(gomock.Any(), ethereum.CallMsg{ + From: s.mockSender.Address, + To: &s.chainProvider.TunnelRouterAddress, + Data: s.relayingCalldata, + GasFeeCap: s.gasInfo.GasFeeCap, + GasTipCap: s.gasInfo.GasPriorityFee, + }).Return(uint64(200_000), nil) + + txHash := "0xabc123" + s.client.EXPECT().BroadcastTx(gomock.Any(), gomock.Any()).Return(txHash, nil) + s.client.EXPECT().GetTxReceipt(gomock.Any(), txHash).Return(&gethtypes.Receipt{ + Status: gethtypes.ReceiptStatusSuccessful, + GasUsed: 21000, + BlockNumber: big.NewInt(100), + }, nil) + + s.client.EXPECT().GetBlockHeight(gomock.Any()).Return(uint64(105), nil) + s.MockDefaultResponses() + + err := s.chainProvider.RelayPacket(context.Background(), &s.relayingPacket) + s.Require().NoError(err) +} + +func (s *EIP1559ProviderTestSuite) TestRelayPacketSuccessWithoutQueryMaxGasFee() { + s.chainProvider.Config.MaxBaseFee = 2_000_000_000 + s.chainProvider.Config.MaxPriorityFee = 3_000_000_000 + + // mock client responses + s.client.EXPECT().EstimateGas(gomock.Any(), ethereum.CallMsg{ + From: s.mockSender.Address, + To: &s.chainProvider.TunnelRouterAddress, + Data: s.relayingCalldata, + GasFeeCap: big.NewInt(5_000_000_000), + GasTipCap: big.NewInt(3_000_000_000), + }).Return(uint64(200_000), nil) + + txHash := "0xabc123" + s.client.EXPECT().BroadcastTx(gomock.Any(), gomock.Any()).Return(txHash, nil) + s.client.EXPECT().GetTxReceipt(gomock.Any(), txHash).Return(&gethtypes.Receipt{ + Status: gethtypes.ReceiptStatusSuccessful, + GasUsed: 21000, + BlockNumber: big.NewInt(100), + }, nil) + + s.client.EXPECT().GetBlockHeight(gomock.Any()).Return(uint64(105), nil) + s.MockDefaultResponses() + + err := s.chainProvider.RelayPacket(context.Background(), &s.relayingPacket) + s.Require().NoError(err) +} + +func (s *EIP1559ProviderTestSuite) TestRelayPacketFailedConnect() { + // mock client responses + s.client.EXPECT().CheckAndConnect(gomock.Any()).Return(fmt.Errorf("failed to connect client")) + s.MockDefaultResponses() + + err := s.chainProvider.RelayPacket(context.Background(), &s.relayingPacket) + s.Require().ErrorContains(err, "failed to connect client") +} + +func (s *EIP1559ProviderTestSuite) TestRelayPacketFailedGasEstimation() { + // mock client responses + s.client.EXPECT().EstimateGasTipCap(gomock.Any()).Return(nil, fmt.Errorf("failed to estimate gas tip cap")) + s.MockDefaultResponses() + + err := s.chainProvider.RelayPacket(context.Background(), &s.relayingPacket) + s.Require().ErrorContains(err, "failed to estimate gas tip cap") +} + +func (s *EIP1559ProviderTestSuite) TestRelayPacketFailedBroadcastTx() { + // mock client responses + s.client.EXPECT().EstimateGas(gomock.Any(), ethereum.CallMsg{ + From: s.mockSender.Address, + To: &s.chainProvider.TunnelRouterAddress, + Data: s.relayingCalldata, + GasFeeCap: s.gasInfo.GasFeeCap, + GasTipCap: s.gasInfo.GasPriorityFee, + }).Return(uint64(200_000), nil).Times(s.chainProvider.Config.MaxRetry) + + s.client.EXPECT(). + BroadcastTx(gomock.Any(), gomock.Any()). + Return("", fmt.Errorf("failed to broadcast an evm transaction")). + Times(s.chainProvider.Config.MaxRetry) + s.MockDefaultResponses() + + err := s.chainProvider.RelayPacket(context.Background(), &s.relayingPacket) + s.Require().ErrorContains(err, "failed to relay packet after") +} + +func (s *EIP1559ProviderTestSuite) TestRelayPacketFailedTxReceiptStatus() { + // mock client responses + s.client.EXPECT().EstimateGas(gomock.Any(), ethereum.CallMsg{ + From: s.mockSender.Address, + To: &s.chainProvider.TunnelRouterAddress, + Data: s.relayingCalldata, + GasFeeCap: s.gasInfo.GasFeeCap, + GasTipCap: s.gasInfo.GasPriorityFee, + }).Return(uint64(200_000), nil).Times(s.chainProvider.Config.MaxRetry) + + txHash := "0xabc123" + s.client.EXPECT(). + BroadcastTx(gomock.Any(), gomock.Any()). + Return(txHash, nil). + Times(s.chainProvider.Config.MaxRetry) + + s.client.EXPECT(). + GetTxReceipt(gomock.Any(), txHash). + Return(&gethtypes.Receipt{ + Status: gethtypes.ReceiptStatusFailed, + GasUsed: 21000, + BlockNumber: big.NewInt(100), + }, nil). + Times(s.chainProvider.Config.MaxRetry) + s.MockDefaultResponses() + + err := s.chainProvider.RelayPacket(context.Background(), &s.relayingPacket) + s.Require().ErrorContains(err, "failed to relay packet after") +} + +func (s *EIP1559ProviderTestSuite) TestBumpAndBoundGas() { + s.MockDefaultResponses() + + // Test cases + testCases := []struct { + name string + maxPriorityFee uint64 + maxBaseFee uint64 + initialPriorityFee int64 + initialBaseFee int64 + multiplier float64 + expectedPriorityFee int64 + expectedBaseFee int64 + }{ + { + name: "Priority and base fee within limits", + maxPriorityFee: 10_000_000_000, + maxBaseFee: 20_000_000_000, + initialPriorityFee: 5_000_000_000, + initialBaseFee: 15_000_000_000, + multiplier: 1.2, + expectedPriorityFee: 6_000_000_000, // due to big.Float imprecision + expectedBaseFee: 15_000_000_000, // Unchanged + }, + { + name: "Priority fee exceeds cap", + maxPriorityFee: 8_000_000_000, + maxBaseFee: 20_000_000_000, + initialPriorityFee: 7_000_000_000, + initialBaseFee: 15_000_000_000, + multiplier: 1.2, + expectedPriorityFee: 8_000_000_000, // Capped at maxPriorityFee + expectedBaseFee: 15_000_000_000, + }, + { + name: "Base fee exceeds cap", + maxPriorityFee: 10_000_000_000, + maxBaseFee: 18_000_000_000, + initialPriorityFee: 5_000_000_000, + initialBaseFee: 19_000_000_000, + multiplier: 1.2, + expectedPriorityFee: 6_000_000_000, // due to big.Float imprecision + expectedBaseFee: 18_000_000_000, + }, + { + name: "No priority fee cap, use relayer fee", + maxPriorityFee: 0, + maxBaseFee: 0, + initialPriorityFee: 11_000_000_000, + initialBaseFee: 18_000_000_000, + multiplier: 1.2, + expectedPriorityFee: 12_000_000_000, // due to big.Float imprecision + expectedBaseFee: 18_000_000_000, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + s.chainProvider.Config.MaxPriorityFee = tc.maxPriorityFee + s.chainProvider.Config.MaxBaseFee = tc.maxBaseFee + + actual, err := s.chainProvider.BumpAndBoundGas( + context.Background(), + evm.NewGasEIP1559Info(big.NewInt(tc.initialPriorityFee), big.NewInt(tc.initialBaseFee)), + tc.multiplier, + ) + s.Require().NoError(err) + + expected := evm.NewGasEIP1559Info(big.NewInt(tc.expectedPriorityFee), big.NewInt(tc.expectedBaseFee)) + s.Require().Equal(expected, actual, "Failed test case: %s", tc.name) + }) + } +} + +func (s *EIP1559ProviderTestSuite) TestEstimateGas() { + s.client.EXPECT().EstimateGasTipCap(gomock.Any()).Return(big.NewInt(5_000_000_000), nil) + s.client.EXPECT().EstimateBaseFee(gomock.Any()).Return(big.NewInt(10_000_000_000), nil) + s.MockDefaultResponses() + + actual, err := s.chainProvider.EstimateGasFee(context.Background()) + s.Require().NoError(err) + + expected := evm.GasInfo{ + Type: evm.GasTypeEIP1559, + GasPrice: nil, + GasPriorityFee: big.NewInt(5_000_000_000), + GasBaseFee: big.NewInt(10_000_000_000), + GasFeeCap: big.NewInt(15_000_000_000), + } + + s.Require().Equal(expected, actual) +} + +func (s *EIP1559ProviderTestSuite) TestNewRelayTx() { + data := []byte("mock calldata") + + callMsg := ethereum.CallMsg{ + From: s.mockSender.Address, + To: &s.chainProvider.TunnelRouterAddress, + Data: data, + GasFeeCap: s.gasInfo.GasFeeCap, + GasTipCap: s.gasInfo.GasPriorityFee, + } + s.client.EXPECT().EstimateGas(gomock.Any(), callMsg).Return(uint64(100_000), nil) + s.client.EXPECT().PendingNonceAt(gomock.Any(), s.mockSender.Address).Return(uint64(1), nil) + + actual, err := s.chainProvider.NewRelayTx(context.Background(), data, s.mockSender.Address, s.gasInfo) + s.Require().NoError(err) + + expected := gethtypes.NewTx(&gethtypes.DynamicFeeTx{ + ChainID: big.NewInt(int64(s.chainProvider.Config.ChainID)), + Nonce: uint64(1), + To: &s.chainProvider.TunnelRouterAddress, + Value: decimal.NewFromInt(0).BigInt(), + Data: data, + Gas: 100_000, + GasFeeCap: s.gasInfo.GasFeeCap, + GasTipCap: s.gasInfo.GasPriorityFee, + }) + + // check only some parts of the received tx. + s.Require().Equal(expected.Nonce(), actual.Nonce(), "Nonce mismatch") + s.Require().Equal(expected.To(), actual.To(), "To address mismatch") + s.Require().Equal(expected.Data(), actual.Data(), "Data mismatch") + s.Require().Equal(expected.Gas(), actual.Gas(), "Gas limit mismatch") + s.Require().Equal(expected.GasPrice(), actual.GasPrice(), "GasPrice mismatch") + s.Require().Equal(expected.GasTipCap(), actual.GasTipCap(), "GasTipCap mismatch") + s.Require().Equal(expected.GasFeeCap(), actual.GasFeeCap(), "GasFeeCap mismatch") + s.Require().Equal(expected.ChainId(), actual.ChainId(), "ChainID mismatch") +} diff --git a/relayer/chains/evm/provider_legacy_test.go b/relayer/chains/evm/provider_legacy_test.go new file mode 100644 index 0000000..ffb0d64 --- /dev/null +++ b/relayer/chains/evm/provider_legacy_test.go @@ -0,0 +1,242 @@ +package evm_test + +import ( + "context" + "encoding/hex" + "fmt" + "math/big" + "testing" + + "github.com/ethereum/go-ethereum" + gethtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + "go.uber.org/zap" + + "github.com/bandprotocol/falcon/internal/relayertest/mocks" + bandtypes "github.com/bandprotocol/falcon/relayer/band/types" + "github.com/bandprotocol/falcon/relayer/chains/evm" +) + +type LegacyProviderTestSuite struct { + suite.Suite + ctrl *gomock.Controller + + chainProvider *evm.EVMChainProvider + client *mocks.MockEVMClient + homePath string + chainName string + + relayingPacket bandtypes.Packet + relayingCalldata []byte + gasInfo evm.GasInfo + mockSender evm.Sender +} + +func TestLegacyProviderTestSuite(t *testing.T) { + suite.Run(t, new(LegacyProviderTestSuite)) +} + +func (s *LegacyProviderTestSuite) SetupTest() { + s.ctrl = gomock.NewController(s.T()) + s.client = mocks.NewMockEVMClient(s.ctrl) + + evmConfig := *baseEVMCfg + evmConfig.GasType = evm.GasTypeLegacy + s.chainName = "testnet" + s.homePath = s.T().TempDir() + chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), s.homePath) + s.Require().NoError(err) + + s.chainProvider = chainProvider + s.chainProvider.FreeSenders = make(chan *evm.Sender, 1) + s.chainProvider.FreeSenders <- &s.mockSender + + s.mockSender, err = mockSender() + s.Require().NoError(err) + + s.relayingPacket = mockPacket() + s.relayingCalldata, err = s.chainProvider.CreateCalldata(&s.relayingPacket) + s.Require().NoError(err) + + s.gasInfo = evm.NewGasLegacyInfo(big.NewInt(10_000_000_000)) +} + +func (s *LegacyProviderTestSuite) MockDefaultResponses() { + gasInfoCalldata, err := hex.DecodeString("658612e9") + s.Require().NoError(err) + gasInfoResponse, err := hex.DecodeString(uint256ToHex(big.NewInt(12_000_000_000))) + s.Require().NoError(err) + + mockCtx := gomock.Any() + s.client.EXPECT().CheckAndConnect(mockCtx).Return(nil).AnyTimes() + s.client.EXPECT().EstimateGasPrice(mockCtx).Return(s.gasInfo.GasPrice, nil).AnyTimes() + s.client.EXPECT().PendingNonceAt(mockCtx, s.mockSender.Address).Return(uint64(100), nil).AnyTimes() + s.client.EXPECT(). + Query(mockCtx, s.chainProvider.TunnelRouterAddress, gasInfoCalldata). + Return(gasInfoResponse, nil). + AnyTimes() +} + +func (s *LegacyProviderTestSuite) TestRelayPacketSuccess() { + // mock client responses + s.client.EXPECT().EstimateGas(gomock.Any(), ethereum.CallMsg{ + From: s.mockSender.Address, + To: &s.chainProvider.TunnelRouterAddress, + Data: s.relayingCalldata, + GasPrice: s.gasInfo.GasPrice, + }).Return(uint64(200_000), nil) + + txHash := "0xabc123" + s.client.EXPECT().BroadcastTx(gomock.Any(), gomock.Any()).Return(txHash, nil) + s.client.EXPECT().GetTxReceipt(gomock.Any(), txHash).Return(&gethtypes.Receipt{ + Status: gethtypes.ReceiptStatusSuccessful, + GasUsed: 21000, + BlockNumber: big.NewInt(100), + }, nil) + + s.client.EXPECT().GetBlockHeight(gomock.Any()).Return(uint64(105), nil) + s.MockDefaultResponses() + + err := s.chainProvider.RelayPacket(context.Background(), &s.relayingPacket) + s.Require().NoError(err) +} + +func (s *LegacyProviderTestSuite) TestRelayPacketSuccessWithoutQueryMaxGasFee() { + s.chainProvider.Config.MaxGasPrice = 2_000_000_000 + + // mock client responses + s.client.EXPECT().EstimateGas(gomock.Any(), ethereum.CallMsg{ + From: s.mockSender.Address, + To: &s.chainProvider.TunnelRouterAddress, + Data: s.relayingCalldata, + GasPrice: big.NewInt(2_000_000_000), + }).Return(uint64(200_000), nil) + + txHash := "0xabc123" + s.client.EXPECT().BroadcastTx(gomock.Any(), gomock.Any()).Return(txHash, nil) + s.client.EXPECT().GetTxReceipt(gomock.Any(), txHash).Return(&gethtypes.Receipt{ + Status: gethtypes.ReceiptStatusSuccessful, + GasUsed: 21000, + BlockNumber: big.NewInt(100), + }, nil) + + s.client.EXPECT().GetBlockHeight(gomock.Any()).Return(uint64(105), nil) + s.MockDefaultResponses() + + err := s.chainProvider.RelayPacket(context.Background(), &s.relayingPacket) + s.Require().NoError(err) +} + +func (s *LegacyProviderTestSuite) TestRelayPacketFailedGasEstimation() { + // mock client responses + s.client.EXPECT().EstimateGasPrice(gomock.Any()).Return(nil, fmt.Errorf("failed to estimate gas price")) + s.MockDefaultResponses() + + err := s.chainProvider.RelayPacket(context.Background(), &s.relayingPacket) + s.Require().ErrorContains(err, "failed to estimate gas price") +} + +func (s *LegacyProviderTestSuite) TestBumpAndBoundGas() { + s.MockDefaultResponses() + + testCases := []struct { + name string + maxGasPrice uint64 + initialGasPrice int64 + multiplier float64 + expectedGasPrice int64 + }{ + { + name: "Gas price within limit", + maxGasPrice: 15_000_000_000, + initialGasPrice: 10_000_000_000, + multiplier: 1.2, + expectedGasPrice: 12_000_000_000, + }, + { + name: "Gas price exceeding limit", + maxGasPrice: 15_000_000_000, + initialGasPrice: 14_000_000_000, + multiplier: 1.2, + expectedGasPrice: 15_000_000_000, + }, + { + name: "No gas price cap, use relayer fee", + maxGasPrice: 0, + initialGasPrice: 11_000_000_000, + multiplier: 1.2, + expectedGasPrice: 12_000_000_000, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + s.chainProvider.Config.MaxGasPrice = tc.maxGasPrice + + actual, err := s.chainProvider.BumpAndBoundGas( + context.Background(), + evm.NewGasLegacyInfo(big.NewInt(tc.initialGasPrice)), + tc.multiplier, + ) + s.Require().NoError(err) + + expected := evm.NewGasLegacyInfo(big.NewInt(tc.expectedGasPrice)) + s.Require().Equal(expected, actual, "Failed test case: %s", tc.name) + }) + } +} + +func (s *LegacyProviderTestSuite) TestEstimateGas() { + s.client.EXPECT().EstimateGasPrice(gomock.Any()).Return(big.NewInt(5_000_000_000), nil) + s.MockDefaultResponses() + + actual, err := s.chainProvider.EstimateGasFee(context.Background()) + s.Require().NoError(err) + + expected := evm.GasInfo{ + Type: evm.GasTypeLegacy, + GasPrice: big.NewInt(5_000_000_000), + GasPriorityFee: nil, + GasBaseFee: nil, + GasFeeCap: nil, + } + + s.Require().Equal(expected, actual) +} + +func (s *LegacyProviderTestSuite) TestNewRelayTx() { + data := []byte("mock calldata") + callMsg := ethereum.CallMsg{ + From: s.mockSender.Address, + To: &s.chainProvider.TunnelRouterAddress, + Data: data, + GasPrice: s.gasInfo.GasPrice, + } + + s.client.EXPECT().EstimateGas(gomock.Any(), callMsg).Return(uint64(100), nil) + s.client.EXPECT().PendingNonceAt(gomock.Any(), s.mockSender.Address).Return(uint64(1), nil) + + actual, err := s.chainProvider.NewRelayTx(context.Background(), data, s.mockSender.Address, s.gasInfo) + s.Require().NoError(err) + + expected := gethtypes.NewTx(&gethtypes.LegacyTx{ + Nonce: uint64(1), + To: &s.chainProvider.TunnelRouterAddress, + Value: decimal.NewFromInt(0).BigInt(), + Data: data, + Gas: uint64(100), + GasPrice: s.gasInfo.GasPrice, + }) + + // check only some parts of the received tx. + s.Require().Equal(expected.Nonce(), actual.Nonce(), "Nonce mismatch") + s.Require().Equal(expected.To(), actual.To(), "To address mismatch") + s.Require().Equal(expected.Data(), actual.Data(), "Data mismatch") + s.Require().Equal(expected.Gas(), actual.Gas(), "Gas limit mismatch") + s.Require().Equal(expected.GasPrice(), actual.GasPrice(), "GasPrice mismatch") + s.Require().Equal(expected.GasTipCap(), actual.GasTipCap(), "GasTipCap mismatch") + s.Require().Equal(expected.GasFeeCap(), actual.GasFeeCap(), "GasFeeCap mismatch") + s.Require().Equal(expected.ChainId(), actual.ChainId(), "ChainID mismatch") +} diff --git a/relayer/chains/evm/provider_test.go b/relayer/chains/evm/provider_test.go new file mode 100644 index 0000000..0ad145f --- /dev/null +++ b/relayer/chains/evm/provider_test.go @@ -0,0 +1,301 @@ +package evm_test + +import ( + "context" + "encoding/hex" + "fmt" + "math/big" + "testing" + "time" + + cmbytes "github.com/cometbft/cometbft/libs/bytes" + gethcommon "github.com/ethereum/go-ethereum/common" + gethtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/crypto" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + "go.uber.org/zap" + + "github.com/bandprotocol/falcon/internal/relayertest/mocks" + bandtypes "github.com/bandprotocol/falcon/relayer/band/types" + "github.com/bandprotocol/falcon/relayer/chains" + "github.com/bandprotocol/falcon/relayer/chains/evm" + chaintypes "github.com/bandprotocol/falcon/relayer/chains/types" +) + +var baseEVMCfg = &evm.EVMChainProviderConfig{ + BaseChainProviderConfig: chains.BaseChainProviderConfig{ + Endpoints: []string{"http://localhost:8545"}, + ChainType: chains.ChainTypeEVM, + MaxRetry: 3, + ChainID: 31337, + TunnelRouterAddress: "0xDc64a140Aa3E981100a9becA4E685f962f0cF6C9", + QueryTimeout: 3 * time.Second, + ExecuteTimeout: 3 * time.Second, + }, + BlockConfirmation: 5, + WaitingTxDuration: time.Second * 3, + CheckingTxInterval: time.Second, + LivelinessCheckingInterval: 15 * time.Minute, + GasMultiplier: 1, +} + +func mockPacket() bandtypes.Packet { + relatedMsg := cmbytes.HexBytes("0xdeadbeef") + rAddr := gethcommon.HexToAddress("0xfad9c8855b740a0b7ed4c221dbad0f33a83a49ca") + signature := cmbytes.HexBytes("0xabcd") + + evmSignature := bandtypes.NewEVMSignature(rAddr.Bytes(), signature) + signingInfo := bandtypes.NewSigning( + 1, + relatedMsg, + evmSignature, + "SIGNING_STATUS_SUCCESS", + ) + + return bandtypes.Packet{ + TunnelID: 1, + Sequence: 42, + SignalPrices: []bandtypes.SignalPrice{ + {SignalID: "signal1", Price: 100}, + {SignalID: "signal2", Price: 200}, + }, + CurrentGroupSigning: signingInfo, + IncomingGroupSigning: nil, + } +} + +func mockSender() (evm.Sender, error) { + addr, err := evm.HexToAddress(testAddress) + if err != nil { + return evm.Sender{}, err + } + + priv, err := crypto.HexToECDSA(evm.StripPrivateKeyPrefix(testPrivateKey)) + if err != nil { + return evm.Sender{}, err + } + + return evm.Sender{ + Address: addr, + PrivateKey: priv, + }, nil +} + +func uint256ToHex(value *big.Int) string { + return fmt.Sprintf("%064x", value) +} + +type ProviderTestSuite struct { + suite.Suite + + ctrl *gomock.Controller + chainProvider *evm.EVMChainProvider + client *mocks.MockEVMClient + log *zap.Logger + homePath string + chainName string +} + +func TestProviderTestSuite(t *testing.T) { + suite.Run(t, new(ProviderTestSuite)) +} + +// SetupTest sets up the test suite by creating a temporary directory and declare mock objects. +func (s *ProviderTestSuite) SetupTest() { + var err error + tmpDir := s.T().TempDir() + + s.ctrl = gomock.NewController(s.T()) + s.client = mocks.NewMockEVMClient(s.ctrl) + + // mock objects. + s.log = zap.NewNop() + + chainName := "testnet" + s.chainName = chainName + + s.chainProvider, err = evm.NewEVMChainProvider(s.chainName, s.client, baseEVMCfg, s.log, s.homePath) + s.Require().NoError(err) + + s.chainProvider.Client = s.client + s.homePath = tmpDir +} + +func (s *ProviderTestSuite) TestQueryTunnelInfo() { + queryTunnelCalldata, err := hex.DecodeString( + "077071ef0000000000000000000000000000000000000000000000000000000000000001000000000000000000000000e688b84b23f322a994a53dbf8e15fa82cdb71127", + ) + s.Require().NoError(err) + + // abi-encoded from {"isActive": True,"latestSequence": 1,"balance": 1000000000000000000} + queryTunnelResponse, err := hex.DecodeString( + "000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000de0b6b3a7640000", + ) + s.Require().NoError(err) + + type Input struct { + tunnelID uint64 + tunnelAddr string + } + testcases := []struct { + name string + input Input + preProcess func() + err error + out *chaintypes.Tunnel + }{ + { + name: "success", + input: Input{1, "0xe688b84b23f322a994A53dbF8E15FA82CDB71127"}, + preProcess: func() { + s.client.EXPECT().CheckAndConnect(gomock.Any()).Return(nil) + s.client.EXPECT(). + Query(gomock.Any(), s.chainProvider.TunnelRouterAddress, queryTunnelCalldata). + Return(queryTunnelResponse, nil) + }, + out: &chaintypes.Tunnel{ + ID: 1, + TargetAddress: "0xe688b84b23f322a994A53dbF8E15FA82CDB71127", + IsActive: true, + LatestSequence: 1, + Balance: big.NewInt(1000000000000000000), + }, + }, + { + name: "failed to connect client", + input: Input{1, "0xe688b84b23f322a994A53dbF8E15FA82CDB71127"}, + preProcess: func() { + s.client.EXPECT().CheckAndConnect(gomock.Any()).Return(fmt.Errorf("Connect client error")) + }, + err: fmt.Errorf("Connect client error"), + }, + { + name: "invalid target address", + input: Input{1, "0xincorrect"}, + preProcess: func() { + s.client.EXPECT().CheckAndConnect(gomock.Any()).Return(nil) + }, + err: fmt.Errorf("invalid address"), + }, + { + name: "cannot unpack data", + input: Input{1, "0xe688b84b23f322a994A53dbF8E15FA82CDB71127"}, + preProcess: func() { + s.client.EXPECT().CheckAndConnect(gomock.Any()).Return(nil) + s.client.EXPECT(). + Query(gomock.Any(), s.chainProvider.TunnelRouterAddress, queryTunnelCalldata). + Return([]uint8{0, 124}, nil) + }, + err: fmt.Errorf("failed to unpack data"), + }, + } + + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preProcess != nil { + tc.preProcess() + } + defer s.ctrl.Finish() + + tunnel, err := s.chainProvider.QueryTunnelInfo( + context.Background(), + tc.input.tunnelID, + tc.input.tunnelAddr, + ) + + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(tc.out, tunnel) + } + }) + } +} + +func (s *ProviderTestSuite) TestEstimateGasUnsupportedGas() { + _, err := s.chainProvider.EstimateGasFee(context.Background()) + s.Require().ErrorContains(err, "unsupported gas type:") +} + +func (s *ProviderTestSuite) TestCheckConfirmedTx() { + txHash := "0xabc123" + txBlock := int64(100) + + testcases := []struct { + name string + preProcess func() + err error + out *evm.ConfirmTxResult + }{ + { + name: "success", + preProcess: func() { + currentBlock := txBlock + int64(s.chainProvider.Config.BlockConfirmation) + 10 + + s.client.EXPECT().GetTxReceipt(gomock.Any(), txHash).Return(&gethtypes.Receipt{ + Status: gethtypes.ReceiptStatusSuccessful, + GasUsed: 21000, + BlockNumber: big.NewInt(txBlock), + }, nil) + s.client.EXPECT().GetBlockHeight(gomock.Any()).Return(uint64(currentBlock), nil) + }, + out: evm.NewConfirmTxResult( + txHash, + evm.TX_STATUS_SUCCESS, + decimal.NewNullDecimal(decimal.New(21000, 0)), + ), + }, + { + name: "get tx receipt with failed status", + preProcess: func() { + s.client.EXPECT().GetTxReceipt(gomock.Any(), txHash).Return(&gethtypes.Receipt{ + Status: gethtypes.ReceiptStatusFailed, + GasUsed: 21000, + BlockNumber: big.NewInt(txBlock), + }, nil) + }, + out: evm.NewConfirmTxResult( + txHash, + evm.TX_STATUS_FAILED, + decimal.NullDecimal{}, + ), + }, + { + name: "get tx receipt but not confirmed block", + preProcess: func() { + currentBlock := txBlock + int64(s.chainProvider.Config.BlockConfirmation) - 1 + + s.client.EXPECT().GetTxReceipt(gomock.Any(), txHash).Return(&gethtypes.Receipt{ + Status: gethtypes.ReceiptStatusSuccessful, + GasUsed: 21000, + BlockNumber: big.NewInt(txBlock), + }, nil) + s.client.EXPECT().GetBlockHeight(gomock.Any()).Return(uint64(currentBlock), nil) + }, + out: evm.NewConfirmTxResult( + txHash, + evm.TX_STATUS_UNMINED, + decimal.NullDecimal{}, + ), + }, + } + + for _, tc := range testcases { + s.Run(tc.name, func() { + if tc.preProcess != nil { + tc.preProcess() + } + + expect, err := s.chainProvider.CheckConfirmedTx(context.Background(), txHash) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + s.Require().Equal(tc.out, expect) + } + }) + } +} diff --git a/relayer/chains/evm/sender.go b/relayer/chains/evm/sender.go index 7ce9333..b55b054 100644 --- a/relayer/chains/evm/sender.go +++ b/relayer/chains/evm/sender.go @@ -39,7 +39,7 @@ func (cp *EVMChainProvider) LoadFreeSenders( freeSenders := make(chan *Sender, len(cp.KeyInfo)) for keyName := range cp.KeyInfo { - key, err := cp.getKeyFromKeyName(keyName, passphrase) + key, err := cp.GetKeyFromKeyName(keyName, passphrase) if err != nil { return err } diff --git a/relayer/chains/evm/sender_test.go b/relayer/chains/evm/sender_test.go new file mode 100644 index 0000000..fd803c4 --- /dev/null +++ b/relayer/chains/evm/sender_test.go @@ -0,0 +1,158 @@ +package evm_test + +import ( + "context" + "encoding/hex" + "os" + "path" + "testing" + "time" + + "github.com/ethereum/go-ethereum/crypto" + "github.com/pelletier/go-toml/v2" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + "github.com/bandprotocol/falcon/relayer/chains" + "github.com/bandprotocol/falcon/relayer/chains/evm" +) + +const ( + privateKey1 = "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80" + address1 = "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266" + privateKey2 = "0x59c6995e998f97a5a0044966f0945389dc9e86dae88c7a8412f4603b6b78690d" + address2 = "0x70997970C51812dc3A010C7d01b50e0d17dc79C8" +) + +var evmCfg = &evm.EVMChainProviderConfig{ + BaseChainProviderConfig: chains.BaseChainProviderConfig{ + Endpoints: []string{"http://localhost:8545"}, + ChainType: chains.ChainTypeEVM, + MaxRetry: 3, + ChainID: 31337, + TunnelRouterAddress: "0xDc64a140Aa3E981100a9becA4E685f962f0cF6C9", + QueryTimeout: 3 * time.Second, + ExecuteTimeout: 3 * time.Second, + }, + BlockConfirmation: 5, + WaitingTxDuration: time.Second * 3, + CheckingTxInterval: time.Second, + LivelinessCheckingInterval: 15 * time.Minute, + GasType: evm.GasTypeEIP1559, + GasMultiplier: 1.1, +} + +type SenderTestSuite struct { + suite.Suite + + ctx context.Context + chainProvider *evm.EVMChainProvider + log *zap.Logger + homePath string +} + +func TestSenderTestSuite(t *testing.T) { + suite.Run(t, new(SenderTestSuite)) +} + +// SetupTest sets up the test suite by creating a temporary directory and declare mock objects. +func (s *SenderTestSuite) SetupTest() { + var err error + tmpDir := s.T().TempDir() + + log, err := zap.NewDevelopment() + s.Require().NoError(err) + + // mock objects. + s.log = zap.NewNop() + + chainName := "testnet" + + client := evm.NewClient(chainName, evmCfg, log) + + s.chainProvider, err = evm.NewEVMChainProvider(chainName, client, evmCfg, log, tmpDir) + s.Require().NoError(err) + + s.ctx = context.Background() + s.homePath = tmpDir +} + +func TestLoadKeyInfo(t *testing.T) { + tmpDir := t.TempDir() + chainName := "testnet" + + // write mock keyInfo at keyInfo's path + keyInfo := make(evm.KeyInfo) + keyInfo["key1"] = "" + keyInfo["key2"] = "" + b, err := toml.Marshal(&keyInfo) + require.NoError(t, err) + + keyInfoDir := path.Join(tmpDir, "keys", chainName, "info") + keyInfoPath := path.Join(keyInfoDir, "info.toml") + // Create the info folder if doesn't exist + err = os.MkdirAll(keyInfoDir, os.ModePerm) + require.NoError(t, err) + // Create the file and write the default config to the given location. + f, err := os.Create(keyInfoPath) + require.NoError(t, err) + defer f.Close() + + _, err = f.Write(b) + require.NoError(t, err) + + // load keyInfo + actual, err := evm.LoadKeyInfo(tmpDir, chainName) + require.NoError(t, err) + + require.Equal(t, keyInfo, actual) +} + +func (s *SenderTestSuite) TestLoadFreeSenders() { + keyName1 := "key1" + keyName2 := "key2" + + // Add two mock keys to the chain provider + _, err := s.chainProvider.AddKeyWithPrivateKey(keyName1, privateKey1, s.homePath, "") + s.Require().NoError(err) + + _, err = s.chainProvider.AddKeyWithPrivateKey(keyName2, privateKey2, s.homePath, "") + s.Require().NoError(err) + + // Load free senders + err = s.chainProvider.LoadFreeSenders(s.homePath, "") + s.Require().NoError(err) + + // Validate the FreeSenders channel is populated correctly + count := len(s.chainProvider.KeyInfo) + s.Require(). + Equal(count, len(s.chainProvider.FreeSenders)) + + // Create a map to check properties of retrieved senders + expectedSenders := map[string]string{ + address1: privateKey1, + address2: privateKey2, + } + + // Check all senders in the channel + for i := 0; i < count; i++ { + sender := <-s.chainProvider.FreeSenders + s.Require().NotNil(sender) + + actualAddress := sender.Address.Hex() + actualPrivateKey := evm.StripPrivateKeyPrefix( + hex.EncodeToString(crypto.FromECDSA(sender.PrivateKey)), + ) + + expectedPrivateKey, exists := expectedSenders[actualAddress] + s.Require().True(exists, "Unexpected sender address: %s", actualAddress) + + // Validate the private key matches + s.Require(). + Equal(evm.StripPrivateKeyPrefix(expectedPrivateKey), evm.StripPrivateKeyPrefix(actualPrivateKey)) + + // Remove the validated sender from the map + delete(expectedSenders, actualAddress) + } +} diff --git a/relayer/chains/evm/types.go b/relayer/chains/evm/types.go index f363122..62c6f53 100644 --- a/relayer/chains/evm/types.go +++ b/relayer/chains/evm/types.go @@ -38,20 +38,17 @@ type ConfirmTxResult struct { TxHash string Status TxStatus GasUsed decimal.NullDecimal - GasType GasType } func NewConfirmTxResult( txHash string, status TxStatus, gasUsed decimal.NullDecimal, - gasType GasType, ) *ConfirmTxResult { return &ConfirmTxResult{ TxHash: txHash, Status: status, GasUsed: gasUsed, - GasType: gasType, } } diff --git a/relayer/chains/evm/utils.go b/relayer/chains/evm/utils.go index 5ad1c39..dd305f4 100644 --- a/relayer/chains/evm/utils.go +++ b/relayer/chains/evm/utils.go @@ -20,17 +20,19 @@ func HexToAddress(s string) (gethcommon.Address, error) { return gethcommon.HexToAddress(s), nil } -// ConvertPrivateKeyStrToHex removes the "0x" prefix from the given private key string, if present. -func ConvertPrivateKeyStrToHex(privateKey string) string { +// StripPrivateKeyPrefix removes the "0x" prefix from the given private key string, if present. +func StripPrivateKeyPrefix(privateKey string) string { return strings.TrimPrefix(privateKey, privateKeyPrefix) } // MultiplyWithFloat64 multiplies a big.Int value with a float64 multiplier and convert back to big.Int. func MultiplyBigIntWithFloat64(value *big.Int, multiplier float64) *big.Int { - multiplierBig := big.NewFloat(multiplier) - valueBig := new(big.Float).SetInt(value) - valueBig.Mul(valueBig, multiplierBig) + // Define precision scale + scale := 1_000_000 - valueBigInt, _ := valueBig.Int(nil) - return valueBigInt + multiplierScaled := int64(multiplier * float64(scale)) + valueScaled := new(big.Int).Mul(value, big.NewInt(multiplierScaled)) + result := new(big.Int).Div(valueScaled, big.NewInt(int64(scale))) + + return result } diff --git a/relayer/chains/evm/utils_test.go b/relayer/chains/evm/utils_test.go index 95d242a..0d603d6 100644 --- a/relayer/chains/evm/utils_test.go +++ b/relayer/chains/evm/utils_test.go @@ -2,6 +2,7 @@ package evm_test import ( "fmt" + "math/big" "testing" "github.com/stretchr/testify/require" @@ -66,3 +67,79 @@ func TestHexToAddress(t *testing.T) { }) } } + +func TestStripPrivateKeyPrefix(t *testing.T) { + tests := []struct { + name string + privateKey string + expected string + }{ + { + name: "With 0x prefix", + privateKey: "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", + expected: "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", + }, + { + name: "Without 0x prefix", + privateKey: "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", + expected: "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := evm.StripPrivateKeyPrefix(tt.privateKey) + require.Equal(t, tt.expected, result, "unexpected result for privateKey %s", tt.privateKey) + }) + } +} + +func TestMultiplyBigIntWithFloat64(t *testing.T) { + tests := []struct { + name string + input *big.Int + multiplier float64 + expected *big.Int + }{ + { + name: "Multiply positive value", + input: big.NewInt(100), + multiplier: 2.5, + expected: big.NewInt(250), + }, + { + name: "Multiply by zero", + input: big.NewInt(100), + multiplier: 0, + expected: big.NewInt(0), + }, + { + name: "Multiply negative value", + input: big.NewInt(-100), + multiplier: 1.5, + expected: big.NewInt(-150), + }, + { + name: "Multiply large number", + input: big.NewInt(1e6), + multiplier: 1.1, + expected: big.NewInt(1100000), + }, + { + name: "Multiply by fractional multiplier", + input: big.NewInt(100), + multiplier: 0.333, + expected: big.NewInt(33), // Rounded down + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := evm.MultiplyBigIntWithFloat64(tt.input, tt.multiplier) + if result.Cmp(tt.expected) != 0 { + t.Errorf("MultiplyBigIntWithFloat64(%v, %f) = %v, want %v", + tt.input, tt.multiplier, result, tt.expected) + } + }) + } +} diff --git a/relayer/chains/provider.go b/relayer/chains/provider.go index ad49740..ffb586c 100644 --- a/relayer/chains/provider.go +++ b/relayer/chains/provider.go @@ -44,20 +44,20 @@ type KeyProvider interface { passphrase string, ) (*chainstypes.Key, error) - // IsKeyNameExist checks whether a key with the specified keyName already exists in storage. - IsKeyNameExist(keyName string) bool + // DeleteKey deletes the key information and private key + DeleteKey(homePath, keyName, passphrase string) error // ExportPrivateKey exports private key of specified key name. ExportPrivateKey(keyName string, passphrase string) (string, error) - // DeleteKey deletes the key information and private key - DeleteKey(homePath, keyName, passphrase string) error - // ListKeys lists all keys ListKeys() []*chainstypes.Key // ShowKey shows the address of the given key - ShowKey(keyName string) string + ShowKey(keyName string) (string, error) + + // IsKeyNameExist checks whether a key with the specified keyName already exists in storage. + IsKeyNameExist(keyName string) bool // LoadFreeSenders loads key info to prepare to relay the packet LoadFreeSenders(homePath, passphrase string) error diff --git a/relayer/config_test.go b/relayer/config_test.go index eb9dafe..2a98860 100644 --- a/relayer/config_test.go +++ b/relayer/config_test.go @@ -1,6 +1,7 @@ package relayer_test import ( + "fmt" "os" "path" "testing" @@ -17,80 +18,128 @@ import ( func TestLoadConfig(t *testing.T) { tmpDir := t.TempDir() - customConfigPath := "" cfgPath := path.Join(tmpDir, "config", "config.toml") - app := relayer.NewApp(nil, tmpDir, false, nil) - - // Prepare config before test - err := app.InitConfigFile(tmpDir, customConfigPath) - require.NoError(t, err) - - actual, err := relayer.LoadConfig(cfgPath) - require.NoError(t, err) - expect := relayer.DefaultConfig() - require.Equal(t, expect, actual) -} - -func TestLoadConfigNotFound(t *testing.T) { - tmpDir := t.TempDir() - cfgPath := path.Join(tmpDir, "config", "config.toml") - - _, err := relayer.LoadConfig(cfgPath) - require.ErrorContains(t, err, "no such file or directory") -} - -func TestLoadConfigInvalidChainProviderConfig(t *testing.T) { - tmpDir := t.TempDir() - cfgPath := path.Join(tmpDir, "config.toml") - - // create new toml config file - cfgText := `[target_chains.testnet] -chain_type = 'evms' -` - - err := os.WriteFile(cfgPath, []byte(cfgText), 0o600) - require.NoError(t, err) - - _, err = relayer.LoadConfig(cfgPath) - require.ErrorContains(t, err, "unsupported chain type: evms") -} - -func TestParseChainProviderConfigTypeEVM(t *testing.T) { - w := relayer.ChainProviderConfigWrapper{ - "chain_type": "evm", - "endpoints": []string{"http://localhost:8545"}, - } - - cfg, err := relayer.ParseChainProviderConfig(w) - - expect := &evm.EVMChainProviderConfig{ - BaseChainProviderConfig: chains.BaseChainProviderConfig{ - Endpoints: []string{"http://localhost:8545"}, - ChainType: chains.ChainTypeEVM, + testcases := []struct { + name string + preProcess func(t *testing.T) + postProcess func(t *testing.T) + out *relayer.Config + err error + }{ + { + name: "read default config", + preProcess: func(t *testing.T) { + app := relayer.NewApp(nil, tmpDir, false, nil) + err := app.InitConfigFile(tmpDir, "") + require.NoError(t, err) + }, + out: relayer.DefaultConfig(), + postProcess: func(t *testing.T) { + err := os.Remove(cfgPath) + require.NoError(t, err) + }, + }, + { + name: "no config file", + err: fmt.Errorf("no such file or directory"), + }, + { + name: "invalid config file; invalid chain type", + preProcess: func(t *testing.T) { + // create new toml config file + cfgText := `[target_chains.testnet] + chain_type = 'evms' + ` + + err := os.WriteFile(cfgPath, []byte(cfgText), 0o600) + require.NoError(t, err) + }, + err: fmt.Errorf("unsupported chain type: evms"), + postProcess: func(t *testing.T) { + err := os.Remove(cfgPath) + require.NoError(t, err) + }, }, } - require.NoError(t, err) - require.Equal(t, expect, cfg) -} -func TestParseChainProviderConfigTypeNotFound(t *testing.T) { - w := relayer.ChainProviderConfigWrapper{ - "chain_type": "evms", - "endpoints": []string{"http://localhost:8545"}, + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + if tc.preProcess != nil { + tc.preProcess(t) + } + + if tc.postProcess != nil { + defer tc.postProcess(t) + } + + actual, err := relayer.LoadConfig(cfgPath) + if tc.err != nil { + require.ErrorContains(t, err, tc.err.Error()) + } else { + require.NoError(t, err) + require.Equal(t, tc.out, actual) + } + }) } - - _, err := relayer.ParseChainProviderConfig(w) - require.ErrorContains(t, err, "unsupported chain type: evms") } -func TestParseChainProviderConfigNoChainType(t *testing.T) { - w := relayer.ChainProviderConfigWrapper{ - "endpoints": []string{"http://localhost:8545"}, +func TestParseChainProviderConfig(t *testing.T) { + testcases := []struct { + name string + in relayer.ChainProviderConfigWrapper + out chains.ChainProviderConfig + err error + }{ + { + name: "valid evm chain", + in: relayer.ChainProviderConfigWrapper{ + "chain_type": "evm", + "endpoints": []string{"http://localhost:8545"}, + }, + out: &evm.EVMChainProviderConfig{ + BaseChainProviderConfig: chains.BaseChainProviderConfig{ + Endpoints: []string{"http://localhost:8545"}, + ChainType: chains.ChainTypeEVM, + }, + }, + }, + { + name: "chain type not found", + in: relayer.ChainProviderConfigWrapper{ + "chain_type": "evms", + "endpoints": []string{"http://localhost:8545"}, + }, + err: fmt.Errorf("unsupported chain type: evms"), + }, + { + name: "missing chain type", + in: relayer.ChainProviderConfigWrapper{ + "endpoints": []string{"http://localhost:8545"}, + }, + err: fmt.Errorf("chain_type is required"), + }, + { + name: "chain type not string", + in: relayer.ChainProviderConfigWrapper{ + "chain_type": []string{"evm"}, + "endpoints": []string{"http://localhost:8545"}, + }, + err: fmt.Errorf("chain_type is required"), + }, } - _, err := relayer.ParseChainProviderConfig(w) - require.ErrorContains(t, err, "chain_type is required") + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + actual, err := relayer.ParseChainProviderConfig(tc.in) + if tc.err != nil { + require.ErrorContains(t, err, tc.err.Error()) + } else { + require.NoError(t, err) + require.Equal(t, tc.out, actual) + } + }) + } } func TestParseConfigInvalidChainProviderConfig(t *testing.T) { diff --git a/relayer/tunnel_relayer_test.go b/relayer/tunnel_relayer_test.go new file mode 100644 index 0000000..ddfe79a --- /dev/null +++ b/relayer/tunnel_relayer_test.go @@ -0,0 +1,258 @@ +package relayer_test + +import ( + "context" + "fmt" + "math/big" + "testing" + "time" + + cmbytes "github.com/cometbft/cometbft/libs/bytes" + "github.com/stretchr/testify/suite" + "go.uber.org/mock/gomock" + "go.uber.org/zap" + + "github.com/bandprotocol/falcon/internal/relayertest/mocks" + "github.com/bandprotocol/falcon/relayer" + bandtypes "github.com/bandprotocol/falcon/relayer/band/types" + chaintypes "github.com/bandprotocol/falcon/relayer/chains/types" +) + +type TunnelRelayerTestSuite struct { + suite.Suite + + app *relayer.App + ctx context.Context + chainProvider *mocks.MockChainProvider + client *mocks.MockClient + tunnelRelayer *relayer.TunnelRelayer +} + +const ( + defaultTunnelID = uint64(1) + defaultContractAddress = "" + defaultCheckingPacketInterval = time.Minute + defaultBandLatestSequence = uint64(1) + defaultTargetChainSequence = uint64(0) +) + +// SetupTest sets up the test suite by creating mock objects and initializing the TunnelRelayer. +func (s *TunnelRelayerTestSuite) SetupTest() { + ctrl := gomock.NewController(s.T()) + + s.chainProvider = mocks.NewMockChainProvider(ctrl) + s.client = mocks.NewMockClient(ctrl) + s.ctx = context.Background() + + tunnelRelayer := relayer.NewTunnelRelayer( + zap.NewNop(), + defaultTunnelID, + defaultContractAddress, + defaultCheckingPacketInterval, + s.client, + s.chainProvider, + ) + s.tunnelRelayer = &tunnelRelayer +} + +func TestTunnelRelayerTestSuite(t *testing.T) { + suite.Run(t, new(TunnelRelayerTestSuite)) +} + +// Helper function to mock GetTunnel. +func (s *TunnelRelayerTestSuite) mockGetTunnel(bandLatestSequence uint64) { + s.client.EXPECT().GetTunnel(s.ctx, s.tunnelRelayer.TunnelID).Return(bandtypes.NewTunnel( + s.tunnelRelayer.TunnelID, + bandLatestSequence, + "", + "", + true, + ), nil) +} + +// Helper function to mock QueryTunnelInfo. +func (s *TunnelRelayerTestSuite) mockQueryTunnelInfo(sequence uint64, isActive bool) { + s.chainProvider.EXPECT(). + QueryTunnelInfo(s.ctx, s.tunnelRelayer.TunnelID, s.tunnelRelayer.ContractAddress). + Return(&chaintypes.Tunnel{ + ID: s.tunnelRelayer.TunnelID, + TargetAddress: s.tunnelRelayer.ContractAddress, + IsActive: isActive, + LatestSequence: sequence, + Balance: big.NewInt(1), + }, nil) +} + +// Helper function to create a mock Packet. +func createMockPacket(tunnelID, sequence uint64, status string) *bandtypes.Packet { + signalPrices := []bandtypes.SignalPrice{ + {SignalID: "signal1", Price: 100}, + {SignalID: "signal2", Price: 200}, + } + evmSignature := bandtypes.NewEVMSignature( + cmbytes.HexBytes("0x1234"), + cmbytes.HexBytes("0xabcd"), + ) + + signing := bandtypes.NewSigning( + 1, + cmbytes.HexBytes("0xdeadbeef"), + evmSignature, + status, + ) + + return bandtypes.NewPacket( + tunnelID, + sequence, + signalPrices, + signing, + nil, + ) +} + +func (s *TunnelRelayerTestSuite) TestCheckAndRelay() { + testcases := []struct { + name string + preprocess func() + err error + }{ + { + name: "success", + preprocess: func() { + s.mockGetTunnel(defaultBandLatestSequence) + s.mockQueryTunnelInfo(defaultTargetChainSequence, true) + + packet := createMockPacket( + s.tunnelRelayer.TunnelID, + defaultTargetChainSequence+1, + "SIGNING_STATUS_SUCCESS", + ) + s.client.EXPECT(). + GetTunnelPacket(gomock.Any(), s.tunnelRelayer.TunnelID, defaultTargetChainSequence+1). + Return(packet, nil) + s.chainProvider.EXPECT().RelayPacket(gomock.Any(), packet).Return(nil) + + // Check and relay the packet for the second time + s.mockGetTunnel(defaultBandLatestSequence) + s.mockQueryTunnelInfo(defaultTargetChainSequence+1, true) + }, + }, + { + name: "failed to get tunnel on band client", + preprocess: func() { + s.client.EXPECT(). + GetTunnel(s.ctx, s.tunnelRelayer.TunnelID). + Return(nil, fmt.Errorf("failed to get tunnel")) + }, + err: fmt.Errorf("failed to get tunnel"), + }, + { + name: "failed to query chain tunnel info", + preprocess: func() { + s.mockGetTunnel(defaultBandLatestSequence) + s.chainProvider.EXPECT(). + QueryTunnelInfo(gomock.Any(), s.tunnelRelayer.TunnelID, s.tunnelRelayer.ContractAddress). + Return(nil, fmt.Errorf("failed to query tunnel info")) + }, + err: fmt.Errorf("failed to query tunnel info"), + }, + { + name: "target chain not active", + preprocess: func() { + s.mockGetTunnel(defaultBandLatestSequence) + s.mockQueryTunnelInfo(defaultTargetChainSequence, false) + }, + err: nil, + }, + { + name: "no new packet to relay", + preprocess: func() { + s.mockGetTunnel(defaultBandLatestSequence) + s.mockQueryTunnelInfo(defaultTargetChainSequence+1, true) + }, + err: nil, + }, + { + name: "fail to get a new packet", + preprocess: func() { + s.mockGetTunnel(defaultBandLatestSequence) + s.mockQueryTunnelInfo(defaultTargetChainSequence, true) + + s.client.EXPECT(). + GetTunnelPacket(gomock.Any(), s.tunnelRelayer.TunnelID, defaultTargetChainSequence+1). + Return(nil, fmt.Errorf("failed to get packet")) + }, + err: fmt.Errorf("failed to get packet"), + }, + { + name: "signing status fallen", + preprocess: func() { + s.mockGetTunnel(defaultBandLatestSequence) + s.mockQueryTunnelInfo(defaultTargetChainSequence, true) + + packet := createMockPacket( + s.tunnelRelayer.TunnelID, + defaultTargetChainSequence+1, + "SIGNING_STATUS_FALLEN", + ) + + s.client.EXPECT(). + GetTunnelPacket(s.ctx, s.tunnelRelayer.TunnelID, defaultTargetChainSequence+1). + Return(packet, nil) + }, + err: fmt.Errorf(("signing status is not success")), + }, + { + name: "signing status is waiting", + preprocess: func() { + s.mockGetTunnel(defaultBandLatestSequence) + s.mockQueryTunnelInfo(defaultTargetChainSequence, true) + + packet := createMockPacket( + s.tunnelRelayer.TunnelID, + defaultTargetChainSequence+1, + "SIGNING_STATUS_WAITING", + ) + + s.client.EXPECT(). + GetTunnelPacket(s.ctx, s.tunnelRelayer.TunnelID, defaultTargetChainSequence+1). + Return(packet, nil) + }, + err: nil, + }, + { + name: "failed to relay packet", + preprocess: func() { + s.mockGetTunnel(defaultBandLatestSequence) + s.mockQueryTunnelInfo(defaultTargetChainSequence, true) + + packet := createMockPacket( + s.tunnelRelayer.TunnelID, + defaultTargetChainSequence+1, + "SIGNING_STATUS_SUCCESS", + ) + + s.client.EXPECT(). + GetTunnelPacket(s.ctx, s.tunnelRelayer.TunnelID, defaultTargetChainSequence+1). + Return(packet, nil) + s.chainProvider.EXPECT().RelayPacket(s.ctx, packet).Return(fmt.Errorf("failed to relay packet")) + }, + err: fmt.Errorf("failed to relay packet"), + }, + } + + for _, tc := range testcases { + s.T().Run(tc.name, func(t *testing.T) { + if tc.preprocess != nil { + tc.preprocess() + } + + err := s.tunnelRelayer.CheckAndRelay(s.ctx) + if tc.err != nil { + s.Require().ErrorContains(err, tc.err.Error()) + } else { + s.Require().NoError(err) + } + }) + } +} diff --git a/scripts/mockgen.sh b/scripts/mockgen.sh index 44ed803..e644941 100755 --- a/scripts/mockgen.sh +++ b/scripts/mockgen.sh @@ -3,6 +3,7 @@ mockgen_cmd="mockgen" $mockgen_cmd -source=relayer/chains/config.go -package mocks -destination internal/relayertest/mocks/chain_provider_config.go $mockgen_cmd -source=relayer/chains/provider.go -package mocks -destination internal/relayertest/mocks/chain_provider.go +$mockgen_cmd -source=relayer/chains/evm/client.go -mock_names Client=MockEVMClient -package mocks -destination internal/relayertest/mocks/chain_evm_client.go $mockgen_cmd -source=relayer/band/client.go -package mocks -destination internal/relayertest/mocks/band_client.go $mockgen_cmd -package mocks -mock_names QueryClient=MockTunnelQueryClient -destination internal/relayertest/mocks/tunnel_query_client.go github.com/bandprotocol/chain/v3/x/tunnel/types QueryClient $mockgen_cmd -package mocks -mock_names QueryClient=MockBandtssQueryClient -destination internal/relayertest/mocks/bandtss_query_client.go github.com/bandprotocol/chain/v3/x/bandtss/types QueryClient