diff --git a/relayer/log.go b/cmd/log.go similarity index 57% rename from relayer/log.go rename to cmd/log.go index 9290b12..f67172f 100644 --- a/relayer/log.go +++ b/cmd/log.go @@ -1,4 +1,4 @@ -package relayer +package cmd import ( "fmt" @@ -6,11 +6,13 @@ import ( "time" zaplogfmt "github.com/jsternberg/zap-logfmt" + "github.com/spf13/viper" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) -func newRootLogger(format string, logLevel string) (*zap.Logger, error) { +// newLogger creates a new root logger with the given log format and log level. +func newLogger(format string, logLevel string) (*zap.Logger, error) { config := zap.NewProductionEncoderConfig() config.EncodeTime = func(ts time.Time, encoder zapcore.PrimitiveArrayEncoder) { encoder.AppendString(ts.UTC().Format("2006-01-02T15:04:05.000000Z07:00")) @@ -51,3 +53,28 @@ func newRootLogger(format string, logLevel string) (*zap.Logger, error) { return logger, nil } + +// initLogger initializes the logger with the given default log level. +func initLogger(defaultLogLevel string) (log *zap.Logger, err error) { + logFormat := viper.GetString("log-format") + + logLevel := viper.GetString("log-level") + if viper.GetBool("debug") { + logLevel = "debug" + } + if logLevel == "" && defaultLogLevel != "" { + logLevel = defaultLogLevel + } + + // initialize logger only if user run command "start" or log level is "debug" + if os.Args[1] == "start" || logLevel == "debug" { + log, err = newLogger(logFormat, logLevel) + if err != nil { + return nil, err + } + } else { + log = zap.NewNop() + } + + return log, nil +} diff --git a/cmd/root.go b/cmd/root.go index 62e72ad..9c305b9 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -17,18 +17,22 @@ import ( "go.uber.org/zap" falcon "github.com/bandprotocol/falcon/relayer" + "github.com/bandprotocol/falcon/relayer/store" ) const ( appName = "falcon" defaultCoinType = 60 + + PassphraseEnvKey = "PASSPHRASE" ) var defaultHome = filepath.Join(os.Getenv("HOME"), ".falcon") // NewRootCmd returns the root command for falcon. func NewRootCmd(log *zap.Logger) *cobra.Command { - app := falcon.NewApp(log, defaultHome, false, nil) + passphrase := os.Getenv(PassphraseEnvKey) + app := falcon.NewApp(log, defaultHome, false, nil, passphrase, nil) // RootCmd represents the base command when called without any subcommands rootCmd := &cobra.Command{ @@ -45,15 +49,29 @@ func NewRootCmd(log *zap.Logger) *cobra.Command { } rootCmd.PersistentPreRunE = func(cmd *cobra.Command, _ []string) error { - // retrieve log level from viper - logLevelViper := viper.GetString("log-level") - if viper.GetBool("debug") { - logLevelViper = "debug" + // set up store + app.Store = store.NewFileSystem(app.HomePath, app.Passphrase) + + // load configuration + var err error + app.Config, err = app.Store.GetConfig() + if err != nil { + return err } - logFormat := viper.GetString("log-format") + // retrieve log level from config + configLogLevel := "" + if app.Config != nil { + configLogLevel = app.Config.Global.LogLevel + } + + // init log object + app.Log, err = initLogger(configLogLevel) + if err != nil { + return err + } - return app.Init(rootCmd.Context(), logLevelViper, logFormat) + return app.Init(rootCmd.Context()) } rootCmd.PersistentPostRun = func(cmd *cobra.Command, _ []string) { diff --git a/internal/relayertest/constants.go b/internal/relayertest/constants.go index e89dcb0..81f2e3a 100644 --- a/internal/relayertest/constants.go +++ b/internal/relayertest/constants.go @@ -4,10 +4,10 @@ import ( _ "embed" "time" - falcon "github.com/bandprotocol/falcon/relayer" "github.com/bandprotocol/falcon/relayer/band" "github.com/bandprotocol/falcon/relayer/chains" "github.com/bandprotocol/falcon/relayer/chains/evm" + "github.com/bandprotocol/falcon/relayer/config" ) //go:embed testdata/default_config.toml @@ -19,8 +19,8 @@ var CustomCfgText string //go:embed testdata/custom_config_with_time_str.toml var CustomCfgTextWithTimeStr string -var CustomCfg = falcon.Config{ - Global: falcon.GlobalConfig{ +var CustomCfg = config.Config{ + Global: config.GlobalConfig{ CheckingPacketInterval: 1 * time.Minute, SyncTunnelsInterval: 5 * time.Minute, MaxCheckingPacketPenaltyDuration: 1 * time.Hour, @@ -32,7 +32,7 @@ var CustomCfg = falcon.Config{ Timeout: 3 * time.Second, LivelinessCheckingInterval: 5 * time.Minute, }, - TargetChains: chains.ChainProviderConfigs{ + TargetChains: config.ChainProviderConfigs{ "testnet": &evm.EVMChainProviderConfig{ BaseChainProviderConfig: chains.BaseChainProviderConfig{ Endpoints: []string{"http://localhost:8545"}, diff --git a/internal/relayertest/mocks/band_chain_query.go b/internal/relayertest/mocks/band_chain_query.go index d77f786..3e11396 100644 --- a/internal/relayertest/mocks/band_chain_query.go +++ b/internal/relayertest/mocks/band_chain_query.go @@ -13,8 +13,8 @@ import ( context "context" reflect "reflect" - types "github.com/bandprotocol/falcon/internal/bandchain/bandtss" - types0 "github.com/bandprotocol/falcon/internal/bandchain/tunnel" + bandtss "github.com/bandprotocol/falcon/internal/bandchain/bandtss" + tunnel "github.com/bandprotocol/falcon/internal/bandchain/tunnel" gomock "go.uber.org/mock/gomock" grpc "google.golang.org/grpc" ) @@ -44,14 +44,14 @@ func (m *MockQueryClient) EXPECT() *MockQueryClientMockRecorder { } // Packet mocks base method. -func (m *MockQueryClient) Packet(ctx context.Context, in *types0.QueryPacketRequest, opts ...grpc.CallOption) (*types0.QueryPacketResponse, error) { +func (m *MockQueryClient) Packet(ctx context.Context, in *tunnel.QueryPacketRequest, opts ...grpc.CallOption) (*tunnel.QueryPacketResponse, error) { m.ctrl.T.Helper() varargs := []any{ctx, in} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Packet", varargs...) - ret0, _ := ret[0].(*types0.QueryPacketResponse) + ret0, _ := ret[0].(*tunnel.QueryPacketResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -64,14 +64,14 @@ func (mr *MockQueryClientMockRecorder) Packet(ctx, in any, opts ...any) *gomock. } // Signing mocks base method. -func (m *MockQueryClient) Signing(ctx context.Context, in *types.QuerySigningRequest, opts ...grpc.CallOption) (*types.QuerySigningResponse, error) { +func (m *MockQueryClient) Signing(ctx context.Context, in *bandtss.QuerySigningRequest, opts ...grpc.CallOption) (*bandtss.QuerySigningResponse, error) { m.ctrl.T.Helper() varargs := []any{ctx, in} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Signing", varargs...) - ret0, _ := ret[0].(*types.QuerySigningResponse) + ret0, _ := ret[0].(*bandtss.QuerySigningResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -84,14 +84,14 @@ func (mr *MockQueryClientMockRecorder) Signing(ctx, in any, opts ...any) *gomock } // Tunnel mocks base method. -func (m *MockQueryClient) Tunnel(ctx context.Context, in *types0.QueryTunnelRequest, opts ...grpc.CallOption) (*types0.QueryTunnelResponse, error) { +func (m *MockQueryClient) Tunnel(ctx context.Context, in *tunnel.QueryTunnelRequest, opts ...grpc.CallOption) (*tunnel.QueryTunnelResponse, error) { m.ctrl.T.Helper() varargs := []any{ctx, in} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Tunnel", varargs...) - ret0, _ := ret[0].(*types0.QueryTunnelResponse) + ret0, _ := ret[0].(*tunnel.QueryTunnelResponse) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -104,14 +104,14 @@ func (mr *MockQueryClientMockRecorder) Tunnel(ctx, in any, opts ...any) *gomock. } // Tunnels mocks base method. -func (m *MockQueryClient) Tunnels(ctx context.Context, in *types0.QueryTunnelsRequest, opts ...grpc.CallOption) (*types0.QueryTunnelsResponse, error) { +func (m *MockQueryClient) Tunnels(ctx context.Context, in *tunnel.QueryTunnelsRequest, opts ...grpc.CallOption) (*tunnel.QueryTunnelsResponse, error) { m.ctrl.T.Helper() varargs := []any{ctx, in} for _, a := range opts { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Tunnels", varargs...) - ret0, _ := ret[0].(*types0.QueryTunnelsResponse) + ret0, _ := ret[0].(*tunnel.QueryTunnelsResponse) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/internal/relayertest/mocks/chain_provider_config.go b/internal/relayertest/mocks/chain_provider_config.go index 9793146..573b809 100644 --- a/internal/relayertest/mocks/chain_provider_config.go +++ b/internal/relayertest/mocks/chain_provider_config.go @@ -13,6 +13,7 @@ import ( reflect "reflect" chains "github.com/bandprotocol/falcon/relayer/chains" + wallet "github.com/bandprotocol/falcon/relayer/wallet" gomock "go.uber.org/mock/gomock" zap "go.uber.org/zap" ) @@ -41,19 +42,33 @@ func (m *MockChainProviderConfig) EXPECT() *MockChainProviderConfigMockRecorder return m.recorder } +// GetChainType mocks base method. +func (m *MockChainProviderConfig) GetChainType() chains.ChainType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChainType") + ret0, _ := ret[0].(chains.ChainType) + return ret0 +} + +// GetChainType indicates an expected call of GetChainType. +func (mr *MockChainProviderConfigMockRecorder) GetChainType() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChainType", reflect.TypeOf((*MockChainProviderConfig)(nil).GetChainType)) +} + // NewChainProvider mocks base method. -func (m *MockChainProviderConfig) NewChainProvider(chainName string, log *zap.Logger, homePath string, debug bool) (chains.ChainProvider, error) { +func (m *MockChainProviderConfig) NewChainProvider(chainName string, log *zap.Logger, homePath string, debug bool, wallet wallet.Wallet) (chains.ChainProvider, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewChainProvider", chainName, log, homePath, debug) + ret := m.ctrl.Call(m, "NewChainProvider", chainName, log, homePath, debug, wallet) ret0, _ := ret[0].(chains.ChainProvider) ret1, _ := ret[1].(error) return ret0, ret1 } // NewChainProvider indicates an expected call of NewChainProvider. -func (mr *MockChainProviderConfigMockRecorder) NewChainProvider(chainName, log, homePath, debug any) *gomock.Call { +func (mr *MockChainProviderConfigMockRecorder) NewChainProvider(chainName, log, homePath, debug, wallet any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewChainProvider", reflect.TypeOf((*MockChainProviderConfig)(nil).NewChainProvider), chainName, log, homePath, debug) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewChainProvider", reflect.TypeOf((*MockChainProviderConfig)(nil).NewChainProvider), chainName, log, homePath, debug, wallet) } // Validate mocks base method. diff --git a/main.go b/main.go index 37db008..b2baaa4 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,20 @@ package main import ( + "fmt" "os" + "github.com/joho/godotenv" + "github.com/bandprotocol/falcon/cmd" ) func main() { + // loading .env file + if err := godotenv.Load(); err != nil && !os.IsNotExist(err) { + panic(fmt.Sprintf("Error due to loading .env file; %v", err)) + } + if err := cmd.Execute(); err != nil { os.Exit(1) } diff --git a/relayer/app.go b/relayer/app.go index 1e4e441..759b041 100644 --- a/relayer/app.go +++ b/relayer/app.go @@ -7,17 +7,15 @@ import ( "fmt" "math/big" "os" - "path" - "github.com/joho/godotenv" - "github.com/pelletier/go-toml/v2" "go.uber.org/zap" - "github.com/bandprotocol/falcon/internal" "github.com/bandprotocol/falcon/relayer/band" bandtypes "github.com/bandprotocol/falcon/relayer/band/types" "github.com/bandprotocol/falcon/relayer/chains" chainstypes "github.com/bandprotocol/falcon/relayer/chains/types" + "github.com/bandprotocol/falcon/relayer/config" + "github.com/bandprotocol/falcon/relayer/store" "github.com/bandprotocol/falcon/relayer/types" ) @@ -25,7 +23,6 @@ const ( ConfigFolderName = "config" ConfigFileName = "config.toml" PassphraseFileName = "passphrase.hash" - PassphraseEnvKey = "PASSPHRASE" ) // App is the main application struct. @@ -33,11 +30,12 @@ type App struct { Log *zap.Logger HomePath string Debug bool - Config *Config + Config *config.Config + Store store.Store - TargetChains chains.ChainProviders - BandClient band.Client - EnvPassphrase string + TargetChains chains.ChainProviders + BandClient band.Client + Passphrase string } // NewApp creates a new App instance. @@ -45,35 +43,23 @@ func NewApp( log *zap.Logger, homePath string, debug bool, - config *Config, + config *config.Config, + passphrase string, + store store.Store, ) *App { app := App{ - Log: log, - HomePath: homePath, - Debug: debug, - Config: config, + Log: log, + HomePath: homePath, + Debug: debug, + Config: config, + Store: store, + Passphrase: passphrase, } return &app } // Init initialize the application. -func (a *App) Init(ctx context.Context, logLevel, logFormat string) error { - if a.Config == nil { - if err := a.LoadConfigFile(); err != nil { - return err - } - } - - // initialize logger, if not already initialized - if a.Log == nil { - if err := a.initLogger(logLevel, logFormat); err != nil { - return err - } - } - - // load passphrase from .env file or system environment variables - a.EnvPassphrase = a.loadEnvPassphrase() - +func (a *App) Init(ctx context.Context) error { // if config is not initialized, return if a.Config == nil { return nil @@ -105,32 +91,21 @@ func (a *App) initBandClient(ctx context.Context) error { return nil } -// initLogger initializes the logger with the given log level. -func (a *App) initLogger(logLevel, logFormat string) error { - if logLevel == "" && a.Config != nil { - logLevel = a.Config.Global.LogLevel - } - - // initialize logger only if user run command "start" or log level is "debug" - if os.Args[1] == "start" || logLevel == "debug" { - log, err := newRootLogger(logFormat, logLevel) - if err != nil { - return err - } - a.Log = log - } else { - a.Log = zap.NewNop() - } - - return nil -} - // initTargetChains initializes the target chains. func (a *App) initTargetChains() error { a.TargetChains = make(chains.ChainProviders) for chainName, chainConfig := range a.Config.TargetChains { - cp, err := chainConfig.NewChainProvider(chainName, a.Log, a.HomePath, a.Debug) + wallet, err := a.Store.NewWallet(chainConfig.GetChainType(), chainName) + if err != nil { + a.Log.Error("Wallet registry not found", + zap.Error(err), + zap.String("chain_name", chainName), + ) + return err + } + + cp, err := chainConfig.NewChainProvider(chainName, a.Log, a.HomePath, a.Debug, wallet) if err != nil { a.Log.Error("Cannot create chain provider", zap.Error(err), @@ -145,107 +120,43 @@ func (a *App) initTargetChains() error { return nil } -// LoadConfigFile reads config file into a.Config if file is present. -func (a *App) LoadConfigFile() error { - cfgPath := path.Join(a.HomePath, ConfigFolderName, ConfigFileName) - - // check if file doesn't exist, exit the function as the config may not be initialized. - if _, err := os.Stat(cfgPath); os.IsNotExist(err) { - return nil - } else if err != nil { - return err - } - - // read the config from config path - cfg, err := LoadConfig(cfgPath) - if err != nil { - return err - } - - // save configuration - a.Config = cfg - - return nil -} - // InitConfigFile initializes the configuration to the given path. func (a *App) InitConfigFile(homePath string, customFilePath string) error { - cfgDir := path.Join(homePath, ConfigFolderName) - cfgPath := path.Join(cfgDir, ConfigFileName) - - // check if the config file already exists - // https://stackoverflow.com/questions/12518876/how-to-check-if-a-file-exists-in-go - if _, err := os.Stat(cfgPath); err == nil { - return fmt.Errorf("config already exists: %s", cfgPath) - } else if !os.IsNotExist(err) { + // Check if config already exists + if ok, err := a.Store.HasConfig(); err != nil { return err + } else if ok { + return fmt.Errorf("config already exists") } // Load config from given custom file path if exists - var cfg *Config - var err error + var cfg *config.Config switch { case customFilePath != "": - cfg, err = LoadConfig(customFilePath) // Initialize with CustomConfig if file is provided + b, err := os.ReadFile(customFilePath) if err != nil { - return fmt.Errorf("LoadConfig file %v error %v", customFilePath, err) + return fmt.Errorf("cannot read a config file %s: %w", customFilePath, err) } - default: - cfg = DefaultConfig() // Initialize with DefaultConfig if no file is provided - } - - // Marshal config object into bytes - b, err := toml.Marshal(cfg) - if err != nil { - return err - } - - // Create the home folder if doesn't exist - if err := internal.CheckAndCreateFolder(homePath); err != nil { - return err - } - - // Create the config folder if doesn't exist - if err := internal.CheckAndCreateFolder(cfgDir); err != nil { - return err - } - - // Create the file and write the default config to the given location. - f, err := os.Create(cfgPath) - if err != nil { - return err - } - defer f.Close() - if _, err = f.Write(b); err != nil { - return err + cfg, err = config.ParseConfig(b) + if err != nil { + return fmt.Errorf("parsing config error %w", err) + } + default: + cfg = config.DefaultConfig() // Initialize with DefaultConfig if no file is provided } - return nil + return a.Store.SaveConfig(cfg) } // InitPassphrase hashes the provided passphrase and saves it to the given path. func (a *App) InitPassphrase() error { // Load and hash the passphrase h := sha256.New() - h.Write([]byte(a.EnvPassphrase)) - b := h.Sum(nil) - - cfgDir := path.Join(a.HomePath, ConfigFolderName) - passphrasePath := path.Join(cfgDir, PassphraseFileName) - - // Create the file and write the hashed passphrase to the given location. - f, err := os.Create(passphrasePath) - if err != nil { - return err - } - defer f.Close() + h.Write([]byte(a.Passphrase)) + passphrase := h.Sum(nil) - if _, err = f.Write(b); err != nil { - return err - } - - return nil + return a.Store.SavePassphrase(passphrase) } // QueryTunnelInfo queries tunnel information by given tunnel ID @@ -303,23 +214,13 @@ func (a *App) AddChainConfig(chainName string, filePath string) error { return fmt.Errorf("existing chain name : %s", chainName) } - chainProviderConfig, err := LoadChainConfig(filePath) + chainProviderConfig, err := config.LoadChainConfig(filePath) if err != nil { return err } a.Config.TargetChains[chainName] = chainProviderConfig - - cfgDir := path.Join(a.HomePath, ConfigFolderName) - cfgPath := path.Join(cfgDir, ConfigFileName) - - // Marshal config object into bytes - b, err := toml.Marshal(a.Config) - if err != nil { - return err - } - - return os.WriteFile(cfgPath, b, 0o600) + return a.Store.SaveConfig(a.Config) } // DeleteChainConfig deletes the chain configuration from the config file. @@ -333,17 +234,7 @@ func (a *App) DeleteChainConfig(chainName string) error { } delete(a.Config.TargetChains, chainName) - - cfgDir := path.Join(a.HomePath, ConfigFolderName) - cfgPath := path.Join(cfgDir, ConfigFileName) - - // Marshal config object into bytes - b, err := toml.Marshal(a.Config) - if err != nil { - return err - } - - return os.WriteFile(cfgPath, b, 0o600) + return a.Store.SaveConfig(a.Config) } // GetChainConfig retrieves the chain configuration by given chain name. @@ -375,7 +266,7 @@ func (a *App) AddKey( return nil, fmt.Errorf("config does not exist: %s", a.HomePath) } - if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.Passphrase); err != nil { return nil, err } @@ -384,7 +275,7 @@ func (a *App) AddKey( return nil, fmt.Errorf("chain name does not exist: %s", chainName) } - keyOutput, err := cp.AddKey(keyName, mnemonic, privateKey, a.HomePath, coinType, account, index, a.EnvPassphrase) + keyOutput, err := cp.AddKey(keyName, mnemonic, privateKey, a.HomePath, coinType, account, index, a.Passphrase) if err != nil { return nil, err } @@ -398,7 +289,7 @@ func (a *App) DeleteKey(chainName string, keyName string) error { return fmt.Errorf("config does not exist: %s", a.HomePath) } - if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.Passphrase); err != nil { return err } @@ -407,7 +298,7 @@ func (a *App) DeleteKey(chainName string, keyName string) error { return fmt.Errorf("chain name does not exist: %s", chainName) } - return cp.DeleteKey(a.HomePath, keyName, a.EnvPassphrase) + return cp.DeleteKey(a.HomePath, keyName, a.Passphrase) } // ExportKey exports the private key from the chain provider. @@ -416,7 +307,7 @@ func (a *App) ExportKey(chainName string, keyName string) (string, error) { return "", fmt.Errorf("config does not exist: %s", a.HomePath) } - if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.Passphrase); err != nil { return "", err } @@ -425,7 +316,7 @@ func (a *App) ExportKey(chainName string, keyName string) (string, error) { return "", fmt.Errorf("chain name does not exist: %s", chainName) } - privateKey, err := cp.ExportPrivateKey(keyName, a.EnvPassphrase) + privateKey, err := cp.ExportPrivateKey(keyName, a.Passphrase) if err != nil { return "", err } @@ -480,22 +371,6 @@ func (a *App) QueryBalance(ctx context.Context, chainName string, keyName string return cp.QueryBalance(ctx, keyName) } -// loadEnvPassphrase retrieves the passphrase string from the .env file or system environment variables. -// It first attempts to load the .env file. If the file is not found or cannot be loaded, -// it falls back to retrieving the "PASSPHRASE" variable from the system environment variables. -func (a *App) loadEnvPassphrase() string { - // load passphrase from .env first. if not present, use env variable from command - if err := godotenv.Load(); err != nil { - a.Log.Debug( - ".env file not found, attempting to use system environment variables", - zap.Error(err), - ) - } else { - a.Log.Debug("Loaded .env file successfully, attempting to use variable from .env file") - } - return os.Getenv(PassphraseEnvKey) -} - // ValidatePassphrase checks if the provided passphrase (from the environment) // matches the hashed passphrase stored on disk. func (a *App) ValidatePassphrase(envPassphrase string) error { @@ -505,15 +380,12 @@ func (a *App) ValidatePassphrase(envPassphrase string) error { envb := h.Sum(nil) // load passphrase from local disk - cfgDir := path.Join(a.HomePath, ConfigFolderName) - passphrasePath := path.Join(cfgDir, PassphraseFileName) - - b, err := os.ReadFile(passphrasePath) + storedPassphrase, err := a.Store.GetPassphrase() if err != nil { return err } - if !bytes.Equal(envb, b) { + if !bytes.Equal(envb, storedPassphrase) { return fmt.Errorf("invalid passphrase: the provided passphrase does not match the stored passphrase") } @@ -531,13 +403,13 @@ func (a *App) Start(ctx context.Context, tunnelIDs []uint64) error { } // validate passphrase - if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.Passphrase); err != nil { return err } // initialize target chain providers for chainName, chainProvider := range a.TargetChains { - if err := chainProvider.LoadFreeSenders(a.HomePath, a.EnvPassphrase); err != nil { + if err := chainProvider.LoadFreeSenders(a.HomePath, a.Passphrase); err != nil { a.Log.Error("Cannot load keys in target chain", zap.Error(err), zap.String("chain_name", chainName), @@ -598,7 +470,7 @@ func (a *App) Relay(ctx context.Context, tunnelID uint64) error { return err } - if err := a.ValidatePassphrase(a.EnvPassphrase); err != nil { + if err := a.ValidatePassphrase(a.Passphrase); err != nil { return err } @@ -607,7 +479,7 @@ func (a *App) Relay(ctx context.Context, tunnelID uint64) error { return fmt.Errorf("target chain provider not found: %s", tunnel.TargetChainID) } - if err := chainProvider.LoadFreeSenders(a.HomePath, a.EnvPassphrase); err != nil { + if err := chainProvider.LoadFreeSenders(a.HomePath, a.Passphrase); err != nil { a.Log.Error("Cannot load keys in target chain", zap.Error(err), zap.String("chain_name", tunnel.TargetChainID), diff --git a/relayer/app_test.go b/relayer/app_test.go index 33d2f3b..f931287 100644 --- a/relayer/app_test.go +++ b/relayer/app_test.go @@ -23,6 +23,8 @@ import ( "github.com/bandprotocol/falcon/relayer/chains" "github.com/bandprotocol/falcon/relayer/chains/evm" chainstypes "github.com/bandprotocol/falcon/relayer/chains/types" + "github.com/bandprotocol/falcon/relayer/config" + "github.com/bandprotocol/falcon/relayer/store" "github.com/bandprotocol/falcon/relayer/types" ) @@ -46,7 +48,7 @@ func (s *AppTestSuite) SetupTest() { s.chainProvider = mocks.NewMockChainProvider(ctrl) s.client = mocks.NewMockClient(ctrl) - cfg := relayer.Config{ + cfg := config.Config{ BandChain: band.Config{ RpcEndpoints: []string{"http://localhost:26659"}, LivelinessCheckingInterval: 5 * time.Minute, @@ -54,7 +56,7 @@ func (s *AppTestSuite) SetupTest() { TargetChains: map[string]chains.ChainProviderConfig{ "testnet_evm": s.chainProviderConfig, }, - Global: relayer.GlobalConfig{}, + Global: config.GlobalConfig{}, } cfgFolder := path.Join(tmpDir, relayer.ConfigFolderName) @@ -65,11 +67,12 @@ func (s *AppTestSuite) SetupTest() { Log: log, HomePath: tmpDir, Config: &cfg, + Store: store.NewFileSystem(tmpDir, "secret"), TargetChains: map[string]chains.ChainProvider{ "testnet_evm": s.chainProvider, }, - BandClient: s.client, - EnvPassphrase: "secret", + BandClient: s.client, + Passphrase: "secret", } // Call InitPassphrase @@ -86,13 +89,13 @@ func (s *AppTestSuite) TestInitConfig() { name string preprocess func() in string - out *relayer.Config + out *config.Config err error }{ { name: "success - default", in: "", - out: relayer.DefaultConfig(), + out: config.DefaultConfig(), }, { name: "config already exists", @@ -101,7 +104,7 @@ func (s *AppTestSuite) TestInitConfig() { s.Require().NoError(err) }, in: "", - err: fmt.Errorf("config already exists:"), + err: fmt.Errorf("config already exists"), }, { name: "init config from specific file", @@ -122,13 +125,13 @@ func (s *AppTestSuite) TestInitConfig() { s.Require().NoError(err) }, in: path.Join(s.app.HomePath, "custom.toml"), - out: &relayer.Config{ + out: &config.Config{ BandChain: band.Config{ RpcEndpoints: []string{"http://localhost:26659"}, Timeout: 50, }, TargetChains: map[string]chains.ChainProviderConfig{}, - Global: relayer.GlobalConfig{ + Global: config.GlobalConfig{ CheckingPacketInterval: time.Minute, }, }, @@ -175,7 +178,7 @@ func (s *AppTestSuite) TestAddChainConfig() { type Input struct { chainName string cfgPath string - existingCfg *relayer.Config + existingCfg *config.Config } testcases := []struct { name string @@ -215,7 +218,7 @@ func (s *AppTestSuite) TestAddChainConfig() { in: Input{ chainName: "testnet", cfgPath: path.Join(newHomePath, "chain_config.toml"), - existingCfg: &relayer.Config{ + existingCfg: &config.Config{ TargetChains: map[string]chains.ChainProviderConfig{ "testnet": &evm.EVMChainProviderConfig{}, }, @@ -237,15 +240,16 @@ func (s *AppTestSuite) TestAddChainConfig() { } // init app - app := relayer.NewApp(nil, newHomePath, false, tc.in.existingCfg) + fs := store.NewFileSystem(newHomePath, s.app.Passphrase) + + app := relayer.NewApp(nil, newHomePath, false, tc.in.existingCfg, "", fs) if app.Config == nil { err := app.InitConfigFile(newHomePath, "") s.Require().NoError(err) s.Require().FileExists(path.Join(newHomePath, "config", "config.toml")) - err = app.LoadConfigFile() + app.Config, err = fs.GetConfig() s.Require().NoError(err) - s.Require().NotNil(app.Config) } err = app.AddChainConfig(tc.in.chainName, tc.in.cfgPath) @@ -298,12 +302,13 @@ func (s *AppTestSuite) TestDeleteChainConfig() { for _, tc := range testcases { s.Run(tc.name, func() { - app := relayer.NewApp(nil, newHomePath, false, nil) + fs := store.NewFileSystem(newHomePath, s.app.Passphrase) + app := relayer.NewApp(nil, newHomePath, false, nil, "", fs) err := app.InitConfigFile(newHomePath, customCfgPath) s.Require().NoError(err) // load config file - err = app.LoadConfigFile() + app.Config, err = fs.GetConfig() s.Require().NoError(err) err = app.DeleteChainConfig(tc.in) @@ -487,7 +492,7 @@ func (s *AppTestSuite) TestInitPassphrase() { // Verify file content hasher := sha256.New() - hasher.Write([]byte(s.app.EnvPassphrase)) + hasher.Write([]byte(s.app.Passphrase)) expectedHash := hasher.Sum(nil) actualContent, err := os.ReadFile(passphrasePath) @@ -526,7 +531,7 @@ func (s *AppTestSuite) TestAddKey() { uint32(60), uint(0), uint(0), - s.app.EnvPassphrase, + s.app.Passphrase, ). Return(chainstypes.NewKey("", "0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266", ""), nil) }, @@ -547,7 +552,7 @@ func (s *AppTestSuite) TestAddKey() { uint32(60), uint(0), uint(0), - s.app.EnvPassphrase, + s.app.Passphrase, ). Return(nil, fmt.Errorf("add key error")) }, @@ -603,7 +608,7 @@ func (s *AppTestSuite) TestDeleteKey() { keyName: "testkey", preprocess: func() { s.chainProvider.EXPECT(). - DeleteKey(s.app.HomePath, "testkey", s.app.EnvPassphrase). + DeleteKey(s.app.HomePath, "testkey", s.app.Passphrase). Return(nil) }, }, @@ -613,7 +618,7 @@ func (s *AppTestSuite) TestDeleteKey() { keyName: "testkey", preprocess: func() { s.chainProvider.EXPECT(). - DeleteKey(s.app.HomePath, "testkey", s.app.EnvPassphrase). + DeleteKey(s.app.HomePath, "testkey", s.app.Passphrase). Return(fmt.Errorf("delete key error")) }, err: fmt.Errorf("delete key error"), @@ -658,7 +663,7 @@ func (s *AppTestSuite) TestExportKey() { keyName: "testkey", preprocess: func() { s.chainProvider.EXPECT(). - ExportPrivateKey("testkey", s.app.EnvPassphrase). + ExportPrivateKey("testkey", s.app.Passphrase). Return("0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", nil) }, out: "0xac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80", @@ -669,7 +674,7 @@ func (s *AppTestSuite) TestExportKey() { keyName: "testkey", preprocess: func() { s.chainProvider.EXPECT(). - ExportPrivateKey("testkey", s.app.EnvPassphrase). + ExportPrivateKey("testkey", s.app.Passphrase). Return("", fmt.Errorf("export key error")) }, err: fmt.Errorf("export key error"), diff --git a/relayer/chains/config.go b/relayer/chains/config.go index 3ea25d2..47e9073 100644 --- a/relayer/chains/config.go +++ b/relayer/chains/config.go @@ -4,6 +4,8 @@ import ( "time" "go.uber.org/zap" + + "github.com/bandprotocol/falcon/relayer/wallet" ) // BaseChainProviderConfig contains common field for particular chain provider. @@ -21,9 +23,6 @@ type BaseChainProviderConfig struct { // ChainProviderConfigs is a collection of ChainProviderConfig interfaces (mapped by chainName) type ChainProviderConfigs map[string]ChainProviderConfig -// ChainProviders is a collection of ChainProvider interfaces (mapped by chainName) -type ChainProviders map[string]ChainProvider - // ChainProviderConfig defines the interface for creating a chain provider object. type ChainProviderConfig interface { NewChainProvider( @@ -31,6 +30,9 @@ type ChainProviderConfig interface { log *zap.Logger, homePath string, debug bool, + wallet wallet.Wallet, ) (ChainProvider, error) + + GetChainType() ChainType Validate() error } diff --git a/relayer/chains/evm/config.go b/relayer/chains/evm/config.go index 9bf5c0f..daa25ad 100644 --- a/relayer/chains/evm/config.go +++ b/relayer/chains/evm/config.go @@ -6,6 +6,7 @@ import ( "go.uber.org/zap" "github.com/bandprotocol/falcon/relayer/chains" + "github.com/bandprotocol/falcon/relayer/wallet" ) var _ chains.ChainProviderConfig = &EVMChainProviderConfig{} @@ -34,13 +35,18 @@ func (cpc *EVMChainProviderConfig) NewChainProvider( log *zap.Logger, homePath string, debug bool, + wallet wallet.Wallet, ) (chains.ChainProvider, error) { client := NewClient(chainName, cpc, log) - return NewEVMChainProvider(chainName, client, cpc, log, homePath) + return NewEVMChainProvider(chainName, client, cpc, log, homePath, wallet) } // Validate validates the EVM chain provider configuration. func (cpc *EVMChainProviderConfig) Validate() error { return nil } + +func (cpc *EVMChainProviderConfig) GetChainType() chains.ChainType { + return chains.ChainTypeEVM +} diff --git a/relayer/chains/evm/keys.go b/relayer/chains/evm/keys.go index 70bd672..1b46a06 100644 --- a/relayer/chains/evm/keys.go +++ b/relayer/chains/evm/keys.go @@ -4,17 +4,12 @@ import ( "crypto/ecdsa" "encoding/hex" "fmt" - "os" - "path" - "github.com/ethereum/go-ethereum/accounts" - keyStore "github.com/ethereum/go-ethereum/accounts/keystore" "github.com/ethereum/go-ethereum/crypto" hdwallet "github.com/miguelmota/go-ethereum-hdwallet" - "github.com/pelletier/go-toml/v2" - "github.com/bandprotocol/falcon/internal" chainstypes "github.com/bandprotocol/falcon/relayer/chains/types" + "github.com/bandprotocol/falcon/relayer/wallet" ) const ( @@ -77,7 +72,7 @@ func (cp *EVMChainProvider) AddKeyWithMnemonic( return nil, err } - return cp.finalizeKeyAddition(keyName, priv, mnemonic, homePath, passphrase) + return cp.finalizeKeyAddition(keyName, priv, mnemonic) } // AddKeyWithPrivateKey adds a key using a raw private key. @@ -94,7 +89,7 @@ func (cp *EVMChainProvider) AddKeyWithPrivateKey( } // No mnemonic is used, so pass an empty string - return cp.finalizeKeyAddition(keyName, priv, "", homePath, passphrase) + return cp.finalizeKeyAddition(keyName, priv, "") } // finalizeKeyAddition stores the private key and initializes the sender. @@ -102,49 +97,18 @@ func (cp *EVMChainProvider) finalizeKeyAddition( keyName string, priv *ecdsa.PrivateKey, mnemonic string, - homePath string, - passphrase string, ) (*chainstypes.Key, error) { - // Get public key from private key - publicKeyECDSA, ok := priv.Public().(*ecdsa.PublicKey) - if !ok { - return nil, fmt.Errorf("cannot assert type to *ecdsa.PublicKey") - } - - // Store private key and get account info - _, err := cp.storePrivateKey(priv, passphrase) + addr, err := cp.Wallet.SavePrivateKey(keyName, priv) if err != nil { return nil, err } - addressHex := crypto.PubkeyToAddress(*publicKeyECDSA).String() - - // Store key info and finalize - cp.KeyInfo[keyName] = addressHex - if err := cp.storeKeyInfo(homePath); err != nil { - return nil, err - } - - return chainstypes.NewKey(mnemonic, addressHex, ""), nil + return chainstypes.NewKey(mnemonic, addr, ""), nil } // DeleteKey deletes the given key name from the key store and removes its information. func (cp *EVMChainProvider) DeleteKey(homePath, keyName, passphrase string) error { - if !cp.IsKeyNameExist(keyName) { - return fmt.Errorf("key name does not exist: %s", keyName) - } - - address, err := HexToAddress(cp.KeyInfo[keyName]) - if err != nil { - return err - } - if err := cp.KeyStore.Delete(accounts.Account{Address: address}, passphrase); err != nil { - return err - } - - delete(cp.KeyInfo, keyName) - - return cp.storeKeyInfo(homePath) + return cp.Wallet.DeletePrivateKey(keyName) } // ExportPrivateKey exports private key of given key name. @@ -153,7 +117,7 @@ func (cp *EVMChainProvider) ExportPrivateKey(keyName, passphrase string) (string return "", fmt.Errorf("key name does not exist: %s", keyName) } - key, err := cp.GetKeyFromKeyName(keyName, passphrase) + key, err := cp.GetKeyFromKeyName(keyName) if err != nil { return "", err } @@ -162,70 +126,34 @@ func (cp *EVMChainProvider) ExportPrivateKey(keyName, passphrase string) (string // ListKeys lists all keys. func (cp *EVMChainProvider) ListKeys() []*chainstypes.Key { - res := make([]*chainstypes.Key, 0, len(cp.KeyInfo)) - for keyName, address := range cp.KeyInfo { + keyNames := cp.Wallet.GetNames() + + res := make([]*chainstypes.Key, 0, len(keyNames)) + for _, keyName := range keyNames { + address, _ := cp.Wallet.GetAddress(keyName) key := chainstypes.NewKey("", address, keyName) res = append(res, key) } + return res } // ShowKey shows key by the given name. func (cp *EVMChainProvider) ShowKey(keyName string) (string, error) { - if !cp.IsKeyNameExist(keyName) { + address, ok := cp.Wallet.GetAddress(keyName) + if !ok { return "", fmt.Errorf("key name does not exist: %s", keyName) } - return cp.KeyInfo[keyName], nil + return address, nil } // IsKeyNameExist checks whether the given key name is already in use. func (cp *EVMChainProvider) IsKeyNameExist(keyName string) bool { - _, ok := cp.KeyInfo[keyName] + _, ok := cp.Wallet.GetAddress(keyName) return ok } -// storePrivateKey stores private key to keyStore. -func (cp *EVMChainProvider) storePrivateKey( - priv *ecdsa.PrivateKey, - passphrase string, -) (*accounts.Account, error) { - accs, err := cp.KeyStore.ImportECDSA(priv, passphrase) - if err != nil { - return nil, err - } - return &accs, nil -} - -// storeKeyInfo stores key information. -func (cp *EVMChainProvider) storeKeyInfo(homePath string) error { - b, err := toml.Marshal(cp.KeyInfo) - if err != nil { - return err - } - - keyInfoDir := path.Join(homePath, keyDir, cp.ChainName, infoDir) - keyInfoPath := path.Join(keyInfoDir, infoFileName) - - // Create the info folder if doesn't exist - if err := internal.CheckAndCreateFolder(keyInfoDir); err != nil { - return err - } - - // Create the file and write the default config to the given location. - f, err := os.Create(keyInfoPath) - if err != nil { - return err - } - defer f.Close() - - if _, err = f.Write(b); err != nil { - return err - } - - return nil -} - // generatePrivateKey generates private key from given mnemonic. func (cp *EVMChainProvider) generatePrivateKey( mnemonic string, @@ -251,23 +179,6 @@ func (cp *EVMChainProvider) generatePrivateKey( return privatekey, nil } -func (cp *EVMChainProvider) GetKeyFromKeyName(keyName, passphrase string) (*keyStore.Key, error) { - address, err := HexToAddress(cp.KeyInfo[keyName]) - if err != nil { - return nil, err - } - - accs, err := cp.KeyStore.Find(accounts.Account{Address: address}) - if err != nil { - return nil, err - } - b, err := cp.KeyStore.Export(accs, passphrase, passphrase) - if err != nil { - return nil, err - } - key, err := keyStore.DecryptKey(b, passphrase) - if err != nil { - return nil, err - } - return key, nil +func (cp *EVMChainProvider) GetKeyFromKeyName(keyName string) (*wallet.Key, error) { + return cp.Wallet.GetKey(keyName) } diff --git a/relayer/chains/evm/keys_test.go b/relayer/chains/evm/keys_test.go index a514c23..8bb52ee 100644 --- a/relayer/chains/evm/keys_test.go +++ b/relayer/chains/evm/keys_test.go @@ -11,6 +11,7 @@ import ( "github.com/bandprotocol/falcon/relayer/chains/evm" chaintypes "github.com/bandprotocol/falcon/relayer/chains/types" + "github.com/bandprotocol/falcon/relayer/wallet" ) const ( @@ -39,7 +40,10 @@ func (s *KeysTestSuite) SetupTest() { chainName := "testnet" client := evm.NewClient(chainName, evmCfg, s.log) - chainProvider, err := evm.NewEVMChainProvider(chainName, client, evmCfg, s.log, s.homePath) + wallet, err := wallet.NewGethKeyStoreWallet("", s.homePath, chainName) + s.Require().NoError(err) + + chainProvider, err := evm.NewEVMChainProvider(chainName, client, evmCfg, s.log, s.homePath, wallet) s.Require().NoError(err) s.chainProvider = chainProvider } @@ -229,10 +233,6 @@ func (s *KeysTestSuite) TestDeleteKey() { // Ensure the key is no longer in the KeyInfo or KeyStore s.Require().False(s.chainProvider.IsKeyNameExist(keyName)) - addr, err := evm.HexToAddress(testAddress) - s.Require().NoError(err) - s.Require().False(s.chainProvider.KeyStore.HasAddress(addr)) - // Delete the key again should return error err = s.chainProvider.DeleteKey(s.homePath, keyName, "") s.Require().ErrorContains(err, "key name does not exist") @@ -327,7 +327,12 @@ func (s *KeysTestSuite) TestShowKey() { } func (s *KeysTestSuite) TestIsKeyNameExist() { - s.chainProvider.KeyInfo["testkey1"] = testAddress + priv, err := crypto.HexToECDSA(evm.StripPrivateKeyPrefix(testPrivateKey)) + s.Require().NoError(err) + + _, err = s.chainProvider.Wallet.SavePrivateKey("testkey1", priv) + s.Require().NoError(err) + expected := s.chainProvider.IsKeyNameExist("testkey1") s.Require().Equal(expected, true) @@ -345,7 +350,7 @@ func (s *KeysTestSuite) TestGetKeyFromKeyName() { s.Require().NoError(err) // Retrieve the key using the key name - key, err := s.chainProvider.GetKeyFromKeyName(keyName, "") + key, err := s.chainProvider.GetKeyFromKeyName(keyName) s.Require().NoError(err) s.Require().NotNil(key) @@ -353,6 +358,12 @@ func (s *KeysTestSuite) TestGetKeyFromKeyName() { s.Require().Equal(testPrivateKey[2:], hex.EncodeToString(crypto.FromECDSA(key.PrivateKey))) // Remove "0x" // Retrieve the key using the invalid passphrase should return error - _, err = s.chainProvider.GetKeyFromKeyName(keyName, "invalid") + _, err = s.chainProvider.GetKeyFromKeyName(keyName) + s.Require().NoError(err) + + s.chainProvider.Wallet, err = wallet.NewGethKeyStoreWallet("invalid", s.homePath, s.chainProvider.ChainName) + s.Require().NoError(err) + + _, err = s.chainProvider.GetKeyFromKeyName(keyName) s.Require().ErrorContains(err, "could not decrypt key with given password") } diff --git a/relayer/chains/evm/provider.go b/relayer/chains/evm/provider.go index 9b4f1e3..884fa42 100644 --- a/relayer/chains/evm/provider.go +++ b/relayer/chains/evm/provider.go @@ -4,13 +4,11 @@ import ( "context" "fmt" "math/big" - "path" "strings" "time" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts/abi" - keyStore "github.com/ethereum/go-ethereum/accounts/keystore" gethcommon "github.com/ethereum/go-ethereum/common" gethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/shopspring/decimal" @@ -19,6 +17,7 @@ import ( bandtypes "github.com/bandprotocol/falcon/relayer/band/types" "github.com/bandprotocol/falcon/relayer/chains" chainstypes "github.com/bandprotocol/falcon/relayer/chains/types" + "github.com/bandprotocol/falcon/relayer/wallet" ) var _ chains.ChainProvider = (*EVMChainProvider)(nil) @@ -31,7 +30,6 @@ type EVMChainProvider struct { Client Client GasType GasType - KeyInfo KeyInfo FreeSenders chan *Sender TunnelRouterAddress gethcommon.Address @@ -39,7 +37,7 @@ type EVMChainProvider struct { Log *zap.Logger - KeyStore *keyStore.KeyStore + Wallet wallet.Wallet } // NewEVMChainProvider creates a new EVM chain provider. @@ -49,6 +47,7 @@ func NewEVMChainProvider( cfg *EVMChainProviderConfig, log *zap.Logger, homePath string, + wallet wallet.Wallet, ) (*EVMChainProvider, error) { // load abis here abi, err := abi.JSON(strings.NewReader(gasPriceTunnelRouterABI)) @@ -69,24 +68,15 @@ func NewEVMChainProvider( return nil, fmt.Errorf("[EVMProvider] incorrect address: %w", err) } - keyStoreDir := path.Join(homePath, keyDir, chainName, privateKeyDir) - keyStore := keyStore.NewKeyStore(keyStoreDir, keyStore.StandardScryptN, keyStore.StandardScryptP) - - keyInfo, err := LoadKeyInfo(homePath, chainName) - if err != nil { - return nil, err - } - return &EVMChainProvider{ Config: cfg, ChainName: chainName, Client: client, GasType: cfg.GasType, - KeyInfo: keyInfo, TunnelRouterAddress: addr, TunnelRouterABI: abi, Log: log.With(zap.String("chain_name", chainName)), - KeyStore: keyStore, + Wallet: wallet, }, nil } @@ -555,7 +545,13 @@ func (cp *EVMChainProvider) QueryBalance( return nil, fmt.Errorf("[EVMProvider] failed to connect client: %w", err) } - address, err := HexToAddress(cp.KeyInfo[keyName]) + hexAddr, ok := cp.Wallet.GetAddress(keyName) + if !ok { + cp.Log.Error("Key name does not exist", zap.String("key_name", keyName)) + return nil, fmt.Errorf("key name does not exist: %s", keyName) + } + + address, err := HexToAddress(hexAddr) if err != nil { return nil, err } diff --git a/relayer/chains/evm/provider_eip1559_test.go b/relayer/chains/evm/provider_eip1559_test.go index 7abf9de..fd98e5f 100644 --- a/relayer/chains/evm/provider_eip1559_test.go +++ b/relayer/chains/evm/provider_eip1559_test.go @@ -17,6 +17,7 @@ import ( "github.com/bandprotocol/falcon/internal/relayertest/mocks" bandtypes "github.com/bandprotocol/falcon/relayer/band/types" "github.com/bandprotocol/falcon/relayer/chains/evm" + "github.com/bandprotocol/falcon/relayer/wallet" ) type EIP1559ProviderTestSuite struct { @@ -46,7 +47,11 @@ func (s *EIP1559ProviderTestSuite) SetupTest() { evmConfig.GasType = evm.GasTypeEIP1559 s.chainName = "testnet" s.homePath = s.T().TempDir() - chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), s.homePath) + + wallet, err := wallet.NewGethKeyStoreWallet("", s.homePath, s.chainName) + s.Require().NoError(err) + + chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), s.homePath, wallet) s.Require().NoError(err) s.chainProvider = chainProvider diff --git a/relayer/chains/evm/provider_legacy_test.go b/relayer/chains/evm/provider_legacy_test.go index ffb0d64..4640c62 100644 --- a/relayer/chains/evm/provider_legacy_test.go +++ b/relayer/chains/evm/provider_legacy_test.go @@ -17,6 +17,7 @@ import ( "github.com/bandprotocol/falcon/internal/relayertest/mocks" bandtypes "github.com/bandprotocol/falcon/relayer/band/types" "github.com/bandprotocol/falcon/relayer/chains/evm" + "github.com/bandprotocol/falcon/relayer/wallet" ) type LegacyProviderTestSuite struct { @@ -46,7 +47,11 @@ func (s *LegacyProviderTestSuite) SetupTest() { evmConfig.GasType = evm.GasTypeLegacy s.chainName = "testnet" s.homePath = s.T().TempDir() - chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), s.homePath) + + wallet, err := wallet.NewGethKeyStoreWallet("", s.homePath, s.chainName) + s.Require().NoError(err) + + chainProvider, err := evm.NewEVMChainProvider(s.chainName, s.client, &evmConfig, zap.NewNop(), s.homePath, wallet) s.Require().NoError(err) s.chainProvider = chainProvider diff --git a/relayer/chains/evm/provider_test.go b/relayer/chains/evm/provider_test.go index 0ad145f..287d4c8 100644 --- a/relayer/chains/evm/provider_test.go +++ b/relayer/chains/evm/provider_test.go @@ -22,6 +22,7 @@ import ( "github.com/bandprotocol/falcon/relayer/chains" "github.com/bandprotocol/falcon/relayer/chains/evm" chaintypes "github.com/bandprotocol/falcon/relayer/chains/types" + "github.com/bandprotocol/falcon/relayer/wallet" ) var baseEVMCfg = &evm.EVMChainProviderConfig{ @@ -106,6 +107,7 @@ func TestProviderTestSuite(t *testing.T) { func (s *ProviderTestSuite) SetupTest() { var err error tmpDir := s.T().TempDir() + s.homePath = tmpDir s.ctrl = gomock.NewController(s.T()) s.client = mocks.NewMockEVMClient(s.ctrl) @@ -116,11 +118,13 @@ func (s *ProviderTestSuite) SetupTest() { chainName := "testnet" s.chainName = chainName - s.chainProvider, err = evm.NewEVMChainProvider(s.chainName, s.client, baseEVMCfg, s.log, s.homePath) + wallet, err := wallet.NewGethKeyStoreWallet("", s.homePath, s.chainName) + s.Require().NoError(err) + + s.chainProvider, err = evm.NewEVMChainProvider(s.chainName, s.client, baseEVMCfg, s.log, s.homePath, wallet) s.Require().NoError(err) s.chainProvider.Client = s.client - s.homePath = tmpDir } func (s *ProviderTestSuite) TestQueryTunnelInfo() { diff --git a/relayer/chains/evm/sender.go b/relayer/chains/evm/sender.go index b55b054..38f40d0 100644 --- a/relayer/chains/evm/sender.go +++ b/relayer/chains/evm/sender.go @@ -36,15 +36,17 @@ func (cp *EVMChainProvider) LoadFreeSenders( return nil } - freeSenders := make(chan *Sender, len(cp.KeyInfo)) + keyNames := cp.Wallet.GetNames() + freeSenders := make(chan *Sender, len(keyNames)) - for keyName := range cp.KeyInfo { - key, err := cp.GetKeyFromKeyName(keyName, passphrase) + for _, keyName := range keyNames { + key, err := cp.GetKeyFromKeyName(keyName) if err != nil { return err } - freeSenders <- NewSender(key.PrivateKey, key.Address) + addr := gethcommon.HexToAddress(key.Address) + freeSenders <- NewSender(key.PrivateKey, addr) } cp.FreeSenders = freeSenders diff --git a/relayer/chains/evm/sender_test.go b/relayer/chains/evm/sender_test.go index fd803c4..9bf24a9 100644 --- a/relayer/chains/evm/sender_test.go +++ b/relayer/chains/evm/sender_test.go @@ -16,6 +16,7 @@ import ( "github.com/bandprotocol/falcon/relayer/chains" "github.com/bandprotocol/falcon/relayer/chains/evm" + "github.com/bandprotocol/falcon/relayer/wallet" ) const ( @@ -60,6 +61,7 @@ func TestSenderTestSuite(t *testing.T) { func (s *SenderTestSuite) SetupTest() { var err error tmpDir := s.T().TempDir() + s.homePath = tmpDir log, err := zap.NewDevelopment() s.Require().NoError(err) @@ -71,11 +73,13 @@ func (s *SenderTestSuite) SetupTest() { client := evm.NewClient(chainName, evmCfg, log) - s.chainProvider, err = evm.NewEVMChainProvider(chainName, client, evmCfg, log, tmpDir) + wallet, err := wallet.NewGethKeyStoreWallet("", s.homePath, chainName) + s.Require().NoError(err) + + s.chainProvider, err = evm.NewEVMChainProvider(chainName, client, evmCfg, log, tmpDir, wallet) s.Require().NoError(err) s.ctx = context.Background() - s.homePath = tmpDir } func TestLoadKeyInfo(t *testing.T) { @@ -125,7 +129,7 @@ func (s *SenderTestSuite) TestLoadFreeSenders() { s.Require().NoError(err) // Validate the FreeSenders channel is populated correctly - count := len(s.chainProvider.KeyInfo) + count := len(s.chainProvider.Wallet.GetNames()) s.Require(). Equal(count, len(s.chainProvider.FreeSenders)) diff --git a/relayer/chains/provider.go b/relayer/chains/provider.go index ffb586c..2828ee0 100644 --- a/relayer/chains/provider.go +++ b/relayer/chains/provider.go @@ -10,6 +10,9 @@ import ( chainstypes "github.com/bandprotocol/falcon/relayer/chains/types" ) +// ChainProviders is a collection of ChainProvider interfaces (mapped by chainName) +type ChainProviders map[string]ChainProvider + // ChainProvider defines the interface for the chain interaction with the destination chain. type ChainProvider interface { KeyProvider diff --git a/relayer/config.go b/relayer/config/config.go similarity index 86% rename from relayer/config.go rename to relayer/config/config.go index a0fe8c5..84884b0 100644 --- a/relayer/config.go +++ b/relayer/config/config.go @@ -1,4 +1,4 @@ -package relayer +package config import ( "fmt" @@ -14,6 +14,9 @@ import ( "github.com/bandprotocol/falcon/relayer/chains/evm" ) +// ChainProviderConfigs is a collection of ChainProviderConfig interfaces (mapped by chainName) +type ChainProviderConfigs map[string]chains.ChainProviderConfig + // GlobalConfig is the global configuration for the falcon tunnel relayer type GlobalConfig struct { LogLevel string `mapstructure:"log_level" toml:"log_level"` @@ -25,9 +28,9 @@ type GlobalConfig struct { // Config defines the configuration for the falcon tunnel relayer. type Config struct { - Global GlobalConfig `mapstructure:"global" toml:"global"` - BandChain band.Config `mapstructure:"bandchain" toml:"bandchain"` - TargetChains chains.ChainProviderConfigs `mapstructure:"target_chains" toml:"target_chains"` + Global GlobalConfig `mapstructure:"global" toml:"global"` + BandChain band.Config `mapstructure:"bandchain" toml:"bandchain"` + TargetChains ChainProviderConfigs `mapstructure:"target_chains" toml:"target_chains"` } // ChainProviderConfigWrapper is an intermediary type for parsing any object from config.toml file @@ -106,9 +109,9 @@ func DecodeConfigInputWrapperTOML(data []byte, cw *ConfigInputWrapper) error { return nil } -// ParseConfig converts a ConfigInputWrapper object to a Config object. -func ParseConfig(wrappedCfg *ConfigInputWrapper) (*Config, error) { - targetChains := make(chains.ChainProviderConfigs) +// ParseConfigInputWrapper converts a ConfigInputWrapper object to a Config object. +func ParseConfigInputWrapper(wrappedCfg *ConfigInputWrapper) (*Config, error) { + targetChains := make(ChainProviderConfigs) for name, provCfg := range wrappedCfg.TargetChains { newProvCfg, err := ParseChainProviderConfig(provCfg) if err != nil { @@ -143,20 +146,14 @@ func DefaultConfig() *Config { } } -// LoadConfig reads config file from given path and return config object -func LoadConfig(cfgPath string) (*Config, error) { - b, err := os.ReadFile(cfgPath) - if err != nil { - return nil, err - } - +func ParseConfig(data []byte) (*Config, error) { var cfgWrapper ConfigInputWrapper - if err := DecodeConfigInputWrapperTOML(b, &cfgWrapper); err != nil { + if err := DecodeConfigInputWrapperTOML(data, &cfgWrapper); err != nil { return nil, err } // convert ConfigWrapperInput to Config - cfg, err := ParseConfig(&cfgWrapper) + cfg, err := ParseConfigInputWrapper(&cfgWrapper) if err != nil { return nil, err } diff --git a/relayer/config_test.go b/relayer/config/config_test.go similarity index 75% rename from relayer/config_test.go rename to relayer/config/config_test.go index 2a98860..f40e4b4 100644 --- a/relayer/config_test.go +++ b/relayer/config/config_test.go @@ -1,4 +1,4 @@ -package relayer_test +package config_test import ( "fmt" @@ -14,9 +14,11 @@ import ( "github.com/bandprotocol/falcon/relayer/band" "github.com/bandprotocol/falcon/relayer/chains" "github.com/bandprotocol/falcon/relayer/chains/evm" + "github.com/bandprotocol/falcon/relayer/config" + "github.com/bandprotocol/falcon/relayer/store" ) -func TestLoadConfig(t *testing.T) { +func TestParseConfig(t *testing.T) { tmpDir := t.TempDir() cfgPath := path.Join(tmpDir, "config", "config.toml") @@ -24,26 +26,23 @@ func TestLoadConfig(t *testing.T) { name string preProcess func(t *testing.T) postProcess func(t *testing.T) - out *relayer.Config + out *config.Config err error }{ { name: "read default config", preProcess: func(t *testing.T) { - app := relayer.NewApp(nil, tmpDir, false, nil) + fs := store.NewFileSystem(tmpDir, "") + app := relayer.NewApp(nil, tmpDir, false, nil, "", fs) err := app.InitConfigFile(tmpDir, "") require.NoError(t, err) }, - out: relayer.DefaultConfig(), + out: config.DefaultConfig(), postProcess: func(t *testing.T) { err := os.Remove(cfgPath) require.NoError(t, err) }, }, - { - name: "no config file", - err: fmt.Errorf("no such file or directory"), - }, { name: "invalid config file; invalid chain type", preProcess: func(t *testing.T) { @@ -73,7 +72,10 @@ func TestLoadConfig(t *testing.T) { defer tc.postProcess(t) } - actual, err := relayer.LoadConfig(cfgPath) + data, err := os.ReadFile(cfgPath) + require.NoError(t, err) + + actual, err := config.ParseConfig(data) if tc.err != nil { require.ErrorContains(t, err, tc.err.Error()) } else { @@ -87,13 +89,13 @@ func TestLoadConfig(t *testing.T) { func TestParseChainProviderConfig(t *testing.T) { testcases := []struct { name string - in relayer.ChainProviderConfigWrapper + in config.ChainProviderConfigWrapper out chains.ChainProviderConfig err error }{ { name: "valid evm chain", - in: relayer.ChainProviderConfigWrapper{ + in: config.ChainProviderConfigWrapper{ "chain_type": "evm", "endpoints": []string{"http://localhost:8545"}, }, @@ -106,7 +108,7 @@ func TestParseChainProviderConfig(t *testing.T) { }, { name: "chain type not found", - in: relayer.ChainProviderConfigWrapper{ + in: config.ChainProviderConfigWrapper{ "chain_type": "evms", "endpoints": []string{"http://localhost:8545"}, }, @@ -114,14 +116,14 @@ func TestParseChainProviderConfig(t *testing.T) { }, { name: "missing chain type", - in: relayer.ChainProviderConfigWrapper{ + in: config.ChainProviderConfigWrapper{ "endpoints": []string{"http://localhost:8545"}, }, err: fmt.Errorf("chain_type is required"), }, { name: "chain type not string", - in: relayer.ChainProviderConfigWrapper{ + in: config.ChainProviderConfigWrapper{ "chain_type": []string{"evm"}, "endpoints": []string{"http://localhost:8545"}, }, @@ -131,7 +133,7 @@ func TestParseChainProviderConfig(t *testing.T) { for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - actual, err := relayer.ParseChainProviderConfig(tc.in) + actual, err := config.ParseChainProviderConfig(tc.in) if tc.err != nil { require.ErrorContains(t, err, tc.err.Error()) } else { @@ -142,34 +144,34 @@ func TestParseChainProviderConfig(t *testing.T) { } } -func TestParseConfigInvalidChainProviderConfig(t *testing.T) { - w := &relayer.ConfigInputWrapper{ - Global: relayer.GlobalConfig{CheckingPacketInterval: 1}, +func TestParseConfigInputWrapperInvalidChainProviderConfig(t *testing.T) { + w := &config.ConfigInputWrapper{ + Global: config.GlobalConfig{CheckingPacketInterval: 1}, BandChain: band.Config{ RpcEndpoints: []string{"http://localhost:26657", "http://localhost:26658"}, Timeout: 0, }, - TargetChains: map[string]relayer.ChainProviderConfigWrapper{ + TargetChains: map[string]config.ChainProviderConfigWrapper{ "testnet": { "chain_type": "evms", }, }, } - _, err := relayer.ParseConfig(w) + _, err := config.ParseConfigInputWrapper(w) require.ErrorContains(t, err, "unsupported chain type: evms") } -func TestUnmarshalConfig(t *testing.T) { +func TestParseConfigInputWrapper(t *testing.T) { // create new toml config file cfgText := relayertest.CustomCfgText // unmarshal them with Config into struct - var cfgWrapper relayer.ConfigInputWrapper - err := relayer.DecodeConfigInputWrapperTOML([]byte(cfgText), &cfgWrapper) + var cfgWrapper config.ConfigInputWrapper + err := config.DecodeConfigInputWrapperTOML([]byte(cfgText), &cfgWrapper) require.NoError(t, err) - cfg, err := relayer.ParseConfig(&cfgWrapper) + cfg, err := config.ParseConfigInputWrapper(&cfgWrapper) require.NoError(t, err) require.Equal(t, &relayertest.CustomCfg, cfg) @@ -191,7 +193,7 @@ func TestLoadChainConfig(t *testing.T) { require.NoError(t, err) // load chain config - actual, err := relayer.LoadChainConfig(cfgPath) + actual, err := config.LoadChainConfig(cfgPath) require.NoError(t, err) expect := relayertest.CustomCfg.TargetChains[chainName] diff --git a/relayer/store/filesystem.go b/relayer/store/filesystem.go new file mode 100644 index 0000000..1642124 --- /dev/null +++ b/relayer/store/filesystem.go @@ -0,0 +1,160 @@ +package store + +import ( + "fmt" + "os" + "path" + + "github.com/pelletier/go-toml/v2" + + "github.com/bandprotocol/falcon/relayer/chains" + "github.com/bandprotocol/falcon/relayer/config" + "github.com/bandprotocol/falcon/relayer/wallet" +) + +var _ Store = &FileSystem{} + +const ( + cfgDir = "config" + cfgFileName = "config.toml" + passphraseFileName = "passphrase.hash" +) + +type FileSystem struct { + HomePath string + Passphrase string +} + +// NewFileSystem creates a new filesystem store. +func NewFileSystem(homePath, passphrase string) *FileSystem { + return &FileSystem{ + HomePath: homePath, + Passphrase: passphrase, + } +} + +// HasConfig checks if the config file exists in the filesystem. +func (fs *FileSystem) HasConfig() (bool, error) { + cfgPath := fs.getConfigPath() + + // check if file doesn't exist, exit the function as the config may not be initialized. + if _, err := os.Stat(cfgPath); os.IsNotExist(err) { + return false, nil + } else if err != nil { + return false, err + } + + return true, nil +} + +// GetConfig reads the config file from the filesystem and returns the config object. +func (fs *FileSystem) GetConfig() (*config.Config, error) { + cfgPath := fs.getConfigPath() + + if ok, err := fs.HasConfig(); err != nil { + return nil, err + } else if !ok { + return nil, nil + } + + b, err := os.ReadFile(cfgPath) + if err != nil { + return nil, err + } + + return config.ParseConfig(b) +} + +// SaveConfig saves the given config object to the filesystem. +func (fs *FileSystem) SaveConfig(cfg *config.Config) error { + // Marshal config object into bytes + b, err := toml.Marshal(cfg) + if err != nil { + return err + } + + // Create the home and config folder if doesn't exist + if err := checkAndCreateFolder(fs.HomePath); err != nil { + return err + } + if err := checkAndCreateFolder(path.Join(fs.HomePath, cfgDir)); err != nil { + return err + } + + // Create the file and write the config to the given location. + f, err := os.Create(fs.getConfigPath()) + if err != nil { + return err + } + defer f.Close() + + if _, err = f.Write(b); err != nil { + return err + } + + return nil +} + +// GetPassphrase reads the passphrase from the filesystem and returns it. +func (fs *FileSystem) GetPassphrase() ([]byte, error) { + return os.ReadFile(fs.getPassphrasePath()) +} + +// SavePassphrase saves the given passphrase to the filesystem. +func (fs *FileSystem) SavePassphrase(passphrase []byte) error { + // Create the home and config folder if doesn't exist + if err := checkAndCreateFolder(fs.HomePath); err != nil { + return err + } + if err := checkAndCreateFolder(path.Join(fs.HomePath, cfgDir)); err != nil { + return err + } + + // Create the file and write the passphrase to the given location. + f, err := os.Create(fs.getPassphrasePath()) + if err != nil { + return err + } + defer f.Close() + + if _, err = f.Write(passphrase); err != nil { + return err + } + + return nil +} + +// getConfigPath returns the path to the config file. +func (fs *FileSystem) getConfigPath() string { + return path.Join(fs.HomePath, cfgDir, cfgFileName) +} + +// getPassphrasePath returns the path to the passphrase file. +func (fs *FileSystem) getPassphrasePath() string { + return path.Join(fs.HomePath, cfgDir, passphraseFileName) +} + +func (fs *FileSystem) NewWallet(chainType chains.ChainType, chainName string) (wallet.Wallet, error) { + switch chainType { + case chains.ChainTypeEVM: + return wallet.NewGethKeyStoreWallet(fs.Passphrase, fs.HomePath, chainName) + default: + return nil, fmt.Errorf("unsupported chain type: %s", chainType) + } +} + +// checkAndCreateFolder checks if the folder exists and creates it if it doesn't. +func checkAndCreateFolder(path string) error { + // If the folder exists and no error, return nil + _, err := os.Stat(path) + if err == nil { + return nil + } + + // If the folder does not exist, create it. + if os.IsNotExist(err) { + return os.Mkdir(path, os.ModePerm) + } + + return err +} diff --git a/relayer/store/filesystem_wallet.go b/relayer/store/filesystem_wallet.go new file mode 100644 index 0000000..1d7d3fb --- /dev/null +++ b/relayer/store/filesystem_wallet.go @@ -0,0 +1,35 @@ +package store + +import ( + "fmt" + + "github.com/bandprotocol/falcon/relayer/chains" + "github.com/bandprotocol/falcon/relayer/wallet" +) + +type WalletFactory interface { + NewWallet(chainType chains.ChainType, chainName string) (wallet.Wallet, error) +} + +var _ WalletFactory = &FileSystemWalletFactory{} + +type FileSystemWalletFactory struct { + HomePath string + Passphrase string +} + +func NewFileSystemWalletFactory(homePath, passphrase string) *FileSystemWalletFactory { + return &FileSystemWalletFactory{ + HomePath: homePath, + Passphrase: passphrase, + } +} + +func (fs *FileSystemWalletFactory) NewWallet(chainType chains.ChainType, chainName string) (wallet.Wallet, error) { + switch chainType { + case chains.ChainTypeEVM: + return wallet.NewGethKeyStoreWallet(fs.Passphrase, fs.HomePath, chainName) + default: + return nil, fmt.Errorf("unsupported chain type: %s", chainType) + } +} diff --git a/relayer/store/store.go b/relayer/store/store.go new file mode 100644 index 0000000..f9aa6c5 --- /dev/null +++ b/relayer/store/store.go @@ -0,0 +1,16 @@ +package store + +import ( + "github.com/bandprotocol/falcon/relayer/chains" + "github.com/bandprotocol/falcon/relayer/config" + "github.com/bandprotocol/falcon/relayer/wallet" +) + +type Store interface { + HasConfig() (bool, error) + GetConfig() (*config.Config, error) + SaveConfig(cfg *config.Config) error + GetPassphrase() ([]byte, error) + SavePassphrase(passphrase []byte) error + NewWallet(chainType chains.ChainType, chainName string) (wallet.Wallet, error) +} diff --git a/relayer/wallet/geth_keystore.go b/relayer/wallet/geth_keystore.go new file mode 100644 index 0000000..0375aca --- /dev/null +++ b/relayer/wallet/geth_keystore.go @@ -0,0 +1,179 @@ +package wallet + +import ( + "crypto/ecdsa" + "fmt" + "os" + "path" + + "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/accounts/keystore" + "github.com/pelletier/go-toml/v2" + + "github.com/bandprotocol/falcon/internal" +) + +var _ Wallet = &GethKeyStoreWallet{} + +type GethKeyStoreWallet struct { + passphrase string + store *keystore.KeyStore + keyNameToHexAddress map[string]string + keyNamePath string +} + +// NewGethKeyStoreWallet creates a new GethKeyStoreWallet instance +func NewGethKeyStoreWallet(passphrase, homeDir, chainName string) (*GethKeyStoreWallet, error) { + // create folders if not exists + if err := internal.CheckAndCreateFolder(homeDir); err != nil { + return nil, err + } + + keyDir := path.Join(homeDir, "keys") + if err := internal.CheckAndCreateFolder(keyDir); err != nil { + return nil, err + } + + keyChainDir := path.Join(keyDir, chainName) + if err := internal.CheckAndCreateFolder(keyChainDir); err != nil { + return nil, err + } + + keyStoreDir := path.Join(keyChainDir, "priv") + if err := internal.CheckAndCreateFolder(keyStoreDir); err != nil { + return nil, err + } + + keyNameDir := path.Join(keyChainDir, "info") + if err := internal.CheckAndCreateFolder(keyNameDir); err != nil { + return nil, err + } + + keyNamePath := path.Join(keyNameDir, "info.toml") + + // create keystore + store := keystore.NewKeyStore(keyStoreDir, keystore.StandardScryptN, keystore.StandardScryptP) + + // load keyNameToHexAddress map + keyNameToHexAddress := make(map[string]string) + if _, err := os.Stat(keyNamePath); err != nil && !os.IsNotExist(err) { + return nil, err + } else if err == nil { + b, err := os.ReadFile(keyNamePath) + if err != nil { + return nil, err + } + + // unmarshal them with Config into struct + err = toml.Unmarshal(b, &keyNameToHexAddress) + if err != nil { + return nil, err + } + } + + return &GethKeyStoreWallet{ + passphrase: passphrase, + store: store, + keyNamePath: keyNamePath, + keyNameToHexAddress: keyNameToHexAddress, + }, nil +} + +// SavePrivateKey saves the private key to the keystore and returns the account and update the keyNameToHexAddress map +func (w *GethKeyStoreWallet) SavePrivateKey(name string, privKey *ecdsa.PrivateKey) (string, error) { + acc, err := w.store.ImportECDSA(privKey, w.passphrase) + if err != nil { + return "", err + } + + addr := acc.Address.Hex() + w.keyNameToHexAddress[name] = addr + if err := w.saveKeyNameToHexAddresses(); err != nil { + return "", err + } + + return addr, nil +} + +// DeletePrivateKey deletes the private key from the keystore and returns the address +func (w *GethKeyStoreWallet) DeletePrivateKey(name string) error { + hexAddr, ok := w.keyNameToHexAddress[name] + if !ok { + return fmt.Errorf("key name does not exist: %s", name) + } + + addr, err := HexToAddress(hexAddr) + if err != nil { + return err + } + + if err := w.store.Delete(accounts.Account{Address: addr}, w.passphrase); err != nil { + return err + } + + delete(w.keyNameToHexAddress, name) + + return w.saveKeyNameToHexAddresses() +} + +// GetAddress returns the address of the given key name +func (w *GethKeyStoreWallet) GetAddress(name string) (string, bool) { + addr, ok := w.keyNameToHexAddress[name] + return addr, ok +} + +// GetNames returns the list of key names +func (w *GethKeyStoreWallet) GetNames() []string { + names := make([]string, 0, len(w.keyNameToHexAddress)) + for name := range w.keyNameToHexAddress { + names = append(names, name) + } + + return names +} + +// GetKey returns the private key and address of the given key name +func (w *GethKeyStoreWallet) GetKey(name string) (*Key, error) { + hexAddr, ok := w.keyNameToHexAddress[name] + if !ok { + return nil, fmt.Errorf("key name does not exist: %s", name) + } + + gethAddr, err := HexToAddress(hexAddr) + if err != nil { + return nil, err + } + + accs, err := w.store.Find(accounts.Account{Address: gethAddr}) + if err != nil { + return nil, err + } + + // need to export the key due to no direct access to the private key + b, err := w.store.Export(accs, w.passphrase, w.passphrase) + if err != nil { + return nil, err + } + + gethKey, err := keystore.DecryptKey(b, w.passphrase) + if err != nil { + return nil, err + } + + return &Key{ + Address: gethAddr.Hex(), + PrivateKey: gethKey.PrivateKey, + }, nil +} + +// saveKeyNameToHexAddresses writes the keyNameToHexAddress map to the file +func (w *GethKeyStoreWallet) saveKeyNameToHexAddresses() error { + f, err := os.Create(w.keyNamePath) + if err != nil { + return err + } + defer f.Close() + + encoder := toml.NewEncoder(f) + return encoder.Encode(w.keyNameToHexAddress) +} diff --git a/relayer/wallet/helper.go b/relayer/wallet/helper.go new file mode 100644 index 0000000..59bb42e --- /dev/null +++ b/relayer/wallet/helper.go @@ -0,0 +1,17 @@ +package wallet + +import ( + "fmt" + + "github.com/ethereum/go-ethereum/common" +) + +// HexToAddress checks a given string and converts it to an geth address. The string must +// be align with the ^(0x)?[0-9a-fA-F]{40}$ regex format, e.g. 0xe688b84b23f322a994A53dbF8E15FA82CDB71127. +func HexToAddress(s string) (common.Address, error) { + if !common.IsHexAddress(s) { + return common.Address{}, fmt.Errorf("invalid address: %s", s) + } + + return common.HexToAddress(s), nil +} diff --git a/relayer/wallet/wallet.go b/relayer/wallet/wallet.go new file mode 100644 index 0000000..107383f --- /dev/null +++ b/relayer/wallet/wallet.go @@ -0,0 +1,18 @@ +package wallet + +import ( + "crypto/ecdsa" +) + +type Key struct { + Address string + PrivateKey *ecdsa.PrivateKey +} + +type Wallet interface { + SavePrivateKey(name string, privKey *ecdsa.PrivateKey) (addr string, err error) + DeletePrivateKey(name string) error + GetNames() []string + GetAddress(name string) (addr string, ok bool) + GetKey(name string) (key *Key, err error) +}