diff --git a/app/app.go b/app/app.go index 5ae20e82c..5147ec151 100644 --- a/app/app.go +++ b/app/app.go @@ -401,7 +401,7 @@ func New( &app.WasmKeeper, ) - app.FeeKeeper = feekeeper.NewKeeper(appCodec, keys[feetypes.StoreKey], memKeys[feetypes.MemStoreKey], app.GetSubspace(feetypes.ModuleName), app.IBCKeeper.ChannelKeeper, app.BankKeeper) + app.FeeKeeper = feekeeper.NewKeeper(appCodec, keys[feetypes.StoreKey], memKeys[feetypes.MemStoreKey], app.GetSubspace(feetypes.ModuleName), app.IBCKeeper.ChannelKeeper, app.BankKeeper, app.FeeGrantKeeper) feeModule := feerefunder.NewAppModule(appCodec, *app.FeeKeeper, app.AccountKeeper, app.BankKeeper) app.FeeBurnerKeeper = feeburnerkeeper.NewKeeper( diff --git a/go.mod b/go.mod index d232461d1..ff15ec2bf 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/CosmWasm/wasmvm v1.1.1 github.com/confio/ics23/go v0.7.0 github.com/cosmos/admin-module v0.0.0-00010101000000-000000000000 + github.com/cosmos/cosmos-proto v1.0.0-beta.1 github.com/cosmos/cosmos-sdk v0.45.7-0.20221104161803-456ca5663c5e github.com/cosmos/ibc-go/v3 v3.0.0 github.com/cosmos/interchain-security v0.2.0 @@ -17,7 +18,6 @@ require ( github.com/grpc-ecosystem/grpc-gateway v1.16.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.12.2 - github.com/regen-network/cosmos-proto v0.3.1 github.com/spf13/cast v1.5.0 github.com/spf13/cobra v1.6.1 github.com/stretchr/testify v1.8.1 @@ -105,6 +105,7 @@ require ( github.com/prometheus/procfs v0.7.3 // indirect github.com/rakyll/statik v0.1.7 // indirect github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 // indirect + github.com/regen-network/cosmos-proto v0.3.1 // indirect github.com/rs/cors v1.8.2 // indirect github.com/rs/zerolog v1.27.0 // indirect github.com/sasha-s/go-deadlock v0.2.1-0.20190427202633-1595213edefa // indirect diff --git a/go.sum b/go.sum index 46b13f8eb..1c2df165a 100644 --- a/go.sum +++ b/go.sum @@ -289,6 +289,8 @@ github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfc github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cosmos/btcutil v1.0.4 h1:n7C2ngKXo7UC9gNyMNLbzqz7Asuf+7Qv4gnX/rOdQ44= github.com/cosmos/btcutil v1.0.4/go.mod h1:Ffqc8Hn6TJUdDgHBwIZLtrLQC1KdJ9jGJl/TvgUaxbU= +github.com/cosmos/cosmos-proto v1.0.0-beta.1 h1:iDL5qh++NoXxG8hSy93FdYJut4XfgbShIocllGaXx/0= +github.com/cosmos/cosmos-proto v1.0.0-beta.1/go.mod h1:8k2GNZghi5sDRFw/scPL8gMSowT1vDA+5ouxL8GjaUE= github.com/cosmos/cosmos-sdk v0.44.2/go.mod h1:fwQJdw+aECatpTvQTo1tSfHEsxACdZYU80QCZUPnHr4= github.com/cosmos/cosmos-sdk v0.44.3/go.mod h1:bA3+VenaR/l/vDiYzaiwbWvRPWHMBX2jG0ygiFtiBp0= github.com/cosmos/cosmos-sdk v0.45.0/go.mod h1:XXS/asyCqWNWkx2rW6pSuen+EVcpAFxq6khrhnZgHaQ= diff --git a/proto/feerefunder/fee.proto b/proto/feerefunder/fee.proto index f0142f7ca..83eb0d181 100644 --- a/proto/feerefunder/fee.proto +++ b/proto/feerefunder/fee.proto @@ -28,6 +28,7 @@ message Fee { (gogoproto.nullable) = false, (gogoproto.castrepeated) = "github.com/cosmos/cosmos-sdk/types.Coins" ]; + string payer = 4; } message PacketID { diff --git a/proto/interchaintxs/v1/tx.proto b/proto/interchaintxs/v1/tx.proto index da652f7da..5e4c7f198 100644 --- a/proto/interchaintxs/v1/tx.proto +++ b/proto/interchaintxs/v1/tx.proto @@ -55,7 +55,7 @@ message MsgSubmitTx { // MsgSubmitTxResponse defines the response for Msg/SubmitTx message MsgSubmitTxResponse { - // channel's sequence_id for outgoing ibc packet. Unique per a channel. + // channel's sequence_id for outgoing ibc packet. Unique per a channel uint64 sequence_id = 1; // channel src channel on neutron side trasaction was submitted from string channel = 2; diff --git a/testutil/feerefunder/keeper/fee.go b/testutil/feerefunder/keeper/fee.go index 972627ee4..b1eb8abf1 100644 --- a/testutil/feerefunder/keeper/fee.go +++ b/testutil/feerefunder/keeper/fee.go @@ -18,7 +18,7 @@ import ( "github.com/neutron-org/neutron/x/feerefunder/types" ) -func FeeKeeper(t testing.TB, channelKeeper types.ChannelKeeper, bankKeeper types.BankKeeper) (*keeper.Keeper, sdk.Context) { +func FeeKeeper(t testing.TB, channelKeeper types.ChannelKeeper, bankKeeper types.BankKeeper, feegrantKeeper types.FeeGrantKeeper) (*keeper.Keeper, sdk.Context) { storeKey := sdk.NewKVStoreKey(types.StoreKey) memStoreKey := storetypes.NewMemoryStoreKey(types.MemStoreKey) @@ -44,6 +44,7 @@ func FeeKeeper(t testing.TB, channelKeeper types.ChannelKeeper, bankKeeper types paramsSubspace, channelKeeper, bankKeeper, + feegrantKeeper, ) ctx := sdk.NewContext(stateStore, tmproto.Header{}, false, log.NewNopLogger()) diff --git a/testutil/mocks/feegrant/types.go b/testutil/mocks/feegrant/types.go new file mode 100644 index 000000000..00f1a9c15 --- /dev/null +++ b/testutil/mocks/feegrant/types.go @@ -0,0 +1,64 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/cosmos/cosmos-sdk/x/feegrant (interfaces: FeeAllowanceI) + +// Package mock_feegrant is a generated GoMock package. +package mock_feegrant + +import ( + reflect "reflect" + + types "github.com/cosmos/cosmos-sdk/types" + gomock "github.com/golang/mock/gomock" +) + +// MockFeeAllowanceI is a mock of FeeAllowanceI interface. +type MockFeeAllowanceI struct { + ctrl *gomock.Controller + recorder *MockFeeAllowanceIMockRecorder +} + +// MockFeeAllowanceIMockRecorder is the mock recorder for MockFeeAllowanceI. +type MockFeeAllowanceIMockRecorder struct { + mock *MockFeeAllowanceI +} + +// NewMockFeeAllowanceI creates a new mock instance. +func NewMockFeeAllowanceI(ctrl *gomock.Controller) *MockFeeAllowanceI { + mock := &MockFeeAllowanceI{ctrl: ctrl} + mock.recorder = &MockFeeAllowanceIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFeeAllowanceI) EXPECT() *MockFeeAllowanceIMockRecorder { + return m.recorder +} + +// Accept mocks base method. +func (m *MockFeeAllowanceI) Accept(arg0 types.Context, arg1 types.Coins, arg2 []types.Msg) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept", arg0, arg1, arg2) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept. +func (mr *MockFeeAllowanceIMockRecorder) Accept(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockFeeAllowanceI)(nil).Accept), arg0, arg1, arg2) +} + +// ValidateBasic mocks base method. +func (m *MockFeeAllowanceI) ValidateBasic() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateBasic") + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateBasic indicates an expected call of ValidateBasic. +func (mr *MockFeeAllowanceIMockRecorder) ValidateBasic() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateBasic", reflect.TypeOf((*MockFeeAllowanceI)(nil).ValidateBasic)) +} diff --git a/testutil/mocks/feerefunder/types/keepers.go b/testutil/mocks/feerefunder/types/keepers.go index e18c9eaba..87a1c22f8 100644 --- a/testutil/mocks/feerefunder/types/keepers.go +++ b/testutil/mocks/feerefunder/types/keepers.go @@ -9,6 +9,7 @@ import ( types "github.com/cosmos/cosmos-sdk/types" types0 "github.com/cosmos/cosmos-sdk/x/auth/types" + feegrant "github.com/cosmos/cosmos-sdk/x/feegrant" types1 "github.com/cosmos/ibc-go/v3/modules/core/04-channel/types" gomock "github.com/golang/mock/gomock" ) @@ -166,3 +167,41 @@ func (mr *MockChannelKeeperMockRecorder) GetChannel(ctx, srcPort, srcChan interf mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChannel", reflect.TypeOf((*MockChannelKeeper)(nil).GetChannel), ctx, srcPort, srcChan) } + +// MockFeeGrantKeeper is a mock of FeeGrantKeeper interface. +type MockFeeGrantKeeper struct { + ctrl *gomock.Controller + recorder *MockFeeGrantKeeperMockRecorder +} + +// MockFeeGrantKeeperMockRecorder is the mock recorder for MockFeeGrantKeeper. +type MockFeeGrantKeeperMockRecorder struct { + mock *MockFeeGrantKeeper +} + +// NewMockFeeGrantKeeper creates a new mock instance. +func NewMockFeeGrantKeeper(ctrl *gomock.Controller) *MockFeeGrantKeeper { + mock := &MockFeeGrantKeeper{ctrl: ctrl} + mock.recorder = &MockFeeGrantKeeperMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFeeGrantKeeper) EXPECT() *MockFeeGrantKeeperMockRecorder { + return m.recorder +} + +// GetAllowance mocks base method. +func (m *MockFeeGrantKeeper) GetAllowance(ctx types.Context, granter, grantee types.AccAddress) (feegrant.FeeAllowanceI, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllowance", ctx, granter, grantee) + ret0, _ := ret[0].(feegrant.FeeAllowanceI) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllowance indicates an expected call of GetAllowance. +func (mr *MockFeeGrantKeeperMockRecorder) GetAllowance(ctx, granter, grantee interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllowance", reflect.TypeOf((*MockFeeGrantKeeper)(nil).GetAllowance), ctx, granter, grantee) +} diff --git a/testutil/mocks/gomock.go b/testutil/mocks/gomock.go index 2a03bba43..54c3177fa 100644 --- a/testutil/mocks/gomock.go +++ b/testutil/mocks/gomock.go @@ -6,3 +6,4 @@ package mocks //go:generate mockgen -source=./../../x/interchainqueries/types/expected_keepers.go -destination ./interchainqueries/types/expected_keepers.go //go:generate mockgen -source=./../../x/interchaintxs/types/expected_keepers.go -destination ./interchaintxs/types/expected_keepers.go //go:generate mockgen -source=./../../x/transfer/types/expected_keepers.go -destination ./transfer/types/expected_keepers.go +//go:generate mockgen -destination ./feegrant/types.go github.com/cosmos/cosmos-sdk/x/feegrant FeeAllowanceI diff --git a/testutil/mocks/interchaintxs/types/expected_keepers.go b/testutil/mocks/interchaintxs/types/expected_keepers.go index a4a17569c..7051d7baf 100644 --- a/testutil/mocks/interchaintxs/types/expected_keepers.go +++ b/testutil/mocks/interchaintxs/types/expected_keepers.go @@ -116,15 +116,15 @@ func (m *MockContractManagerKeeper) EXPECT() *MockContractManagerKeeperMockRecor } // AddContractFailure mocks base method. -func (m *MockContractManagerKeeper) AddContractFailure(ctx types.Context, channelId, address string, ackID uint64, ackType string) { +func (m *MockContractManagerKeeper) AddContractFailure(ctx types.Context, channelID, address string, ackID uint64, ackType string) { m.ctrl.T.Helper() - m.ctrl.Call(m, "AddContractFailure", ctx, channelId, address, ackID, ackType) + m.ctrl.Call(m, "AddContractFailure", ctx, channelID, address, ackID, ackType) } // AddContractFailure indicates an expected call of AddContractFailure. -func (mr *MockContractManagerKeeperMockRecorder) AddContractFailure(ctx, channelId, address, ackID, ackType interface{}) *gomock.Call { +func (mr *MockContractManagerKeeperMockRecorder) AddContractFailure(ctx, channelID, address, ackID, ackType interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddContractFailure", reflect.TypeOf((*MockContractManagerKeeper)(nil).AddContractFailure), ctx, channelId, address, ackID, ackType) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddContractFailure", reflect.TypeOf((*MockContractManagerKeeper)(nil).AddContractFailure), ctx, channelID, address, ackID, ackType) } // HasContractInfo mocks base method. @@ -330,8 +330,23 @@ func (mr *MockFeeRefunderKeeperMockRecorder) DistributeTimeoutFee(ctx, receiver, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DistributeTimeoutFee", reflect.TypeOf((*MockFeeRefunderKeeper)(nil).DistributeTimeoutFee), ctx, receiver, packetID) } +// GetPayerInfo mocks base method. +func (m *MockFeeRefunderKeeper) GetPayerInfo(ctx types.Context, sender, payer string) (types5.PayerInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPayerInfo", ctx, sender, payer) + ret0, _ := ret[0].(types5.PayerInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPayerInfo indicates an expected call of GetPayerInfo. +func (mr *MockFeeRefunderKeeperMockRecorder) GetPayerInfo(ctx, sender, payer interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPayerInfo", reflect.TypeOf((*MockFeeRefunderKeeper)(nil).GetPayerInfo), ctx, sender, payer) +} + // LockFees mocks base method. -func (m *MockFeeRefunderKeeper) LockFees(ctx types.Context, payer types.AccAddress, packetID types5.PacketID, fee types5.Fee) error { +func (m *MockFeeRefunderKeeper) LockFees(ctx types.Context, payer types5.PayerInfo, packetID types5.PacketID, fee types5.Fee) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LockFees", ctx, payer, packetID, fee) ret0, _ := ret[0].(error) diff --git a/testutil/mocks/transfer/types/expected_keepers.go b/testutil/mocks/transfer/types/expected_keepers.go index be7e75f5c..83c04e747 100644 --- a/testutil/mocks/transfer/types/expected_keepers.go +++ b/testutil/mocks/transfer/types/expected_keepers.go @@ -38,15 +38,15 @@ func (m *MockContractManagerKeeper) EXPECT() *MockContractManagerKeeperMockRecor } // AddContractFailure mocks base method. -func (m *MockContractManagerKeeper) AddContractFailure(ctx types.Context, channelId, address string, ackID uint64, ackType string) { +func (m *MockContractManagerKeeper) AddContractFailure(ctx types.Context, channelID, address string, ackID uint64, ackType string) { m.ctrl.T.Helper() - m.ctrl.Call(m, "AddContractFailure", ctx, channelId, address, ackID, ackType) + m.ctrl.Call(m, "AddContractFailure", ctx, channelID, address, ackID, ackType) } // AddContractFailure indicates an expected call of AddContractFailure. -func (mr *MockContractManagerKeeperMockRecorder) AddContractFailure(ctx, channelId, address, ackID, ackType interface{}) *gomock.Call { +func (mr *MockContractManagerKeeperMockRecorder) AddContractFailure(ctx, channelID, address, ackID, ackType interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddContractFailure", reflect.TypeOf((*MockContractManagerKeeper)(nil).AddContractFailure), ctx, channelId, address, ackID, ackType) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddContractFailure", reflect.TypeOf((*MockContractManagerKeeper)(nil).AddContractFailure), ctx, channelID, address, ackID, ackType) } // HasContractInfo mocks base method. @@ -155,8 +155,23 @@ func (mr *MockFeeRefunderKeeperMockRecorder) DistributeTimeoutFee(ctx, receiver, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DistributeTimeoutFee", reflect.TypeOf((*MockFeeRefunderKeeper)(nil).DistributeTimeoutFee), ctx, receiver, packetID) } +// GetPayerInfo mocks base method. +func (m *MockFeeRefunderKeeper) GetPayerInfo(ctx types.Context, sender, payer string) (types2.PayerInfo, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPayerInfo", ctx, sender, payer) + ret0, _ := ret[0].(types2.PayerInfo) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPayerInfo indicates an expected call of GetPayerInfo. +func (mr *MockFeeRefunderKeeperMockRecorder) GetPayerInfo(ctx, sender, payer interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPayerInfo", reflect.TypeOf((*MockFeeRefunderKeeper)(nil).GetPayerInfo), ctx, sender, payer) +} + // LockFees mocks base method. -func (m *MockFeeRefunderKeeper) LockFees(ctx types.Context, payer types.AccAddress, packetID types2.PacketID, fee types2.Fee) error { +func (m *MockFeeRefunderKeeper) LockFees(ctx types.Context, payer types2.PayerInfo, packetID types2.PacketID, fee types2.Fee) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "LockFees", ctx, payer, packetID, fee) ret0, _ := ret[0].(error) diff --git a/x/feerefunder/genesis_test.go b/x/feerefunder/genesis_test.go index a2b84ee3e..c0914041b 100644 --- a/x/feerefunder/genesis_test.go +++ b/x/feerefunder/genesis_test.go @@ -32,7 +32,7 @@ func TestGenesis(t *testing.T) { require.EqualValues(t, genesisState.Params, types.DefaultParams()) - k, ctx := keeper.FeeKeeper(t, nil, nil) + k, ctx := keeper.FeeKeeper(t, nil, nil, nil) feerefunder.InitGenesis(ctx, *k, genesisState) got := feerefunder.ExportGenesis(ctx, *k) diff --git a/x/feerefunder/keeper/grpc_query_params_test.go b/x/feerefunder/keeper/grpc_query_params_test.go index a4527a0c3..de2e791b1 100644 --- a/x/feerefunder/keeper/grpc_query_params_test.go +++ b/x/feerefunder/keeper/grpc_query_params_test.go @@ -12,7 +12,7 @@ import ( ) func TestParamsQuery(t *testing.T) { - keeper, ctx := testkeeper.FeeKeeper(t, nil, nil) + keeper, ctx := testkeeper.FeeKeeper(t, nil, nil, nil) wctx := sdk.WrapSDKContext(ctx) params := types.DefaultParams() keeper.SetParams(ctx, params) diff --git a/x/feerefunder/keeper/keeper.go b/x/feerefunder/keeper/keeper.go index c9a8ba8e0..5523f1dad 100644 --- a/x/feerefunder/keeper/keeper.go +++ b/x/feerefunder/keeper/keeper.go @@ -18,12 +18,13 @@ import ( type ( Keeper struct { - cdc codec.BinaryCodec - bankKeeper types.BankKeeper - storeKey storetypes.StoreKey - memKey storetypes.StoreKey - paramstore paramtypes.Subspace - channelKeeper types.ChannelKeeper + cdc codec.BinaryCodec + bankKeeper types.BankKeeper + storeKey storetypes.StoreKey + memKey storetypes.StoreKey + paramstore paramtypes.Subspace + channelKeeper types.ChannelKeeper + feegrantKeeper types.FeeGrantKeeper } ) @@ -34,6 +35,7 @@ func NewKeeper( ps paramtypes.Subspace, channelKeeper types.ChannelKeeper, bankKeeper types.BankKeeper, + feegrantKeeper types.FeeGrantKeeper, ) *Keeper { // set KeyTable if it has not already been set if !ps.HasKeyTable() { @@ -41,12 +43,13 @@ func NewKeeper( } return &Keeper{ - cdc: cdc, - storeKey: storeKey, - memKey: memKey, - paramstore: ps, - channelKeeper: channelKeeper, - bankKeeper: bankKeeper, + cdc: cdc, + storeKey: storeKey, + memKey: memKey, + paramstore: ps, + channelKeeper: channelKeeper, + bankKeeper: bankKeeper, + feegrantKeeper: feegrantKeeper, } } @@ -54,17 +57,37 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger { return ctx.Logger().With("module", fmt.Sprintf("x/%s", types.ModuleName)) } -func (k Keeper) LockFees(ctx sdk.Context, payer sdk.AccAddress, packetID types.PacketID, fee types.Fee) error { +func (k Keeper) LockFees(ctx sdk.Context, payerInfo types.PayerInfo, packetID types.PacketID, fee types.Fee) error { k.Logger(ctx).Debug("Trying to lock fees", "packetID", packetID, "fee", fee) if _, ok := k.channelKeeper.GetChannel(ctx, packetID.PortId, packetID.ChannelId); !ok { return sdkerrors.Wrapf(channeltypes.ErrChannelNotFound, "channel with id %s and port %s not found", packetID.ChannelId, packetID.PortId) } + payer := payerInfo.Sender if err := k.checkFees(ctx, fee); err != nil { return sdkerrors.Wrapf(err, "failed to lock fees") } + if payerInfo.FeePayer != nil { + allowance, err := k.feegrantKeeper.GetAllowance(ctx, payerInfo.FeePayer, payerInfo.Sender) + if err != nil { + return sdkerrors.Wrapf(err, "failed to get allowance") + } + + if allowance != nil { // otherwise there is no allowance + coins := sdk.NewCoins() + coins = append(coins, fee.TimeoutFee...) + coins = append(coins, fee.AckFee...) + coins = append(coins, fee.RecvFee...) + _, err = allowance.Accept(ctx, coins, []sdk.Msg{}) + if err != nil { + return sdkerrors.Wrapf(err, "failed to accept allowance, it's expired or spent") + } + payer = payerInfo.FeePayer + } + } + feeInfo := types.FeeInfo{ Payer: payer.String(), Fee: fee, @@ -204,6 +227,29 @@ func (k Keeper) StoreFeeInfo(ctx sdk.Context, feeInfo types.FeeInfo) { store.Set(types.GetFeePacketKey(feeInfo.PacketId), bzFeeInfo) } +func (k Keeper) GetPayerInfo(ctx sdk.Context, sender string, payer string) (types.PayerInfo, error) { + senderAddr, err := sdk.AccAddressFromBech32(sender) + if err != nil { + k.Logger(ctx).Debug("Transfer: failed to parse sender address", "sender", sender) + return types.PayerInfo{}, sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "failed to parse address: %s", sender) + } + if payer == "" { + return types.PayerInfo{ + Sender: senderAddr, + FeePayer: nil, + }, nil + } + feePayerAddr, err := sdk.AccAddressFromBech32(payer) + if err != nil { + k.Logger(ctx).Debug("Transfer: failed to parse fee payer address", "fee payer", payer) + return types.PayerInfo{}, sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "failed to parse address: %s", payer) + } + return types.PayerInfo{ + Sender: senderAddr, + FeePayer: feePayerAddr, + }, nil +} + func (k Keeper) removeFeeInfo(ctx sdk.Context, packetID types.PacketID) { store := ctx.KVStore(k.storeKey) diff --git a/x/feerefunder/keeper/keeper_test.go b/x/feerefunder/keeper/keeper_test.go index 434bebaf8..af6a93fe8 100644 --- a/x/feerefunder/keeper/keeper_test.go +++ b/x/feerefunder/keeper/keeper_test.go @@ -5,6 +5,8 @@ import ( "strconv" "testing" + mock_feegrant "github.com/neutron-org/neutron/testutil/mocks/feegrant" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -24,11 +26,12 @@ import ( ) const ( - TestAddress = "neutron17dtl0mjt3t77kpuhg2edqzjpszulwhgzcdvagh" + TestAddress = "neutron17dtl0mjt3t77kpuhg2edqzjpszulwhgzcdvagh" + TestFeePayerAddress = "neutron1m9l358xunhhwds0568za49mzhvuxx9ux8xafx2" ) func TestKeeperCheckFees(t *testing.T) { - k, ctx := testutil_keeper.FeeKeeper(t, nil, nil) + k, ctx := testutil_keeper.FeeKeeper(t, nil, nil, nil) k.SetParams(ctx, types.Params{ MinFee: types.Fee{ @@ -124,9 +127,13 @@ func TestKeeperLockFees(t *testing.T) { defer ctrl.Finish() bankKeeper := mock_types.NewMockBankKeeper(ctrl) channelKeeper := mock_types.NewMockChannelKeeper(ctrl) - k, ctx := testutil_keeper.FeeKeeper(t, channelKeeper, bankKeeper) + feegrantKeeper := mock_types.NewMockFeeGrantKeeper(ctrl) + k, ctx := testutil_keeper.FeeKeeper(t, channelKeeper, bankKeeper, feegrantKeeper) - payer := sdk.MustAccAddressFromBech32(testutil.TestOwnerAddress) + payerInfo := types.PayerInfo{ + Sender: sdk.MustAccAddressFromBech32(testutil.TestOwnerAddress), + FeePayer: nil, + } k.SetParams(ctx, types.Params{ MinFee: types.Fee{ @@ -143,11 +150,11 @@ func TestKeeperLockFees(t *testing.T) { } channelKeeper.EXPECT().GetChannel(ctx, packet.PortId, packet.ChannelId).Return(channeltypes.Channel{}, false) - err := k.LockFees(ctx, payer, packet, types.Fee{}) + err := k.LockFees(ctx, payerInfo, packet, types.Fee{}) require.True(t, channeltypes.ErrChannelNotFound.Is(err)) channelKeeper.EXPECT().GetChannel(ctx, packet.PortId, packet.ChannelId).Return(channeltypes.Channel{}, true) - err = k.LockFees(ctx, payer, packet, types.Fee{}) + err = k.LockFees(ctx, payerInfo, packet, types.Fee{}) require.True(t, sdkerrors.ErrInsufficientFee.Is(err)) validFee := types.Fee{ @@ -156,18 +163,113 @@ func TestKeeperLockFees(t *testing.T) { TimeoutFee: sdk.NewCoins(sdk.NewCoin("denom1", sdk.NewInt(101))), } channelKeeper.EXPECT().GetChannel(ctx, packet.PortId, packet.ChannelId).Return(channeltypes.Channel{}, true) - bankKeeper.EXPECT().SendCoinsFromAccountToModule(ctx, payer, types.ModuleName, validFee.Total()).Return(fmt.Errorf("bank error")) - err = k.LockFees(ctx, payer, packet, validFee) + bankKeeper.EXPECT().SendCoinsFromAccountToModule(ctx, payerInfo.Sender, types.ModuleName, validFee.Total()).Return(fmt.Errorf("bank error")) + err = k.LockFees(ctx, payerInfo, packet, validFee) require.ErrorContains(t, err, "bank error") channelKeeper.EXPECT().GetChannel(ctx, packet.PortId, packet.ChannelId).Return(channeltypes.Channel{}, true) - bankKeeper.EXPECT().SendCoinsFromAccountToModule(ctx, payer, types.ModuleName, validFee.Total()).Return(nil) - err = k.LockFees(ctx, payer, packet, validFee) + bankKeeper.EXPECT().SendCoinsFromAccountToModule(ctx, payerInfo.Sender, types.ModuleName, validFee.Total()).Return(nil) + err = k.LockFees(ctx, payerInfo, packet, validFee) + require.NoError(t, err) + require.Equal(t, sdk.Events{ + sdk.NewEvent( + types.EventTypeLockFees, + sdk.NewAttribute(types.AttributeKeyPayer, payerInfo.Sender.String()), + sdk.NewAttribute(types.AttributeKeyPortID, packet.PortId), + sdk.NewAttribute(types.AttributeKeyChannelID, packet.ChannelId), + sdk.NewAttribute(types.AttributeKeySequence, strconv.FormatUint(packet.Sequence, 10)), + ), + sdk.NewEvent( + sdk.EventTypeMessage, + sdk.NewAttribute(sdk.AttributeKeyModule, types.ModuleName), + ), + }, ctx.EventManager().Events()) +} + +func TestKeeperGetPayerInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + bankKeeper := mock_types.NewMockBankKeeper(ctrl) + channelKeeper := mock_types.NewMockChannelKeeper(ctrl) + feegrantKeeper := mock_types.NewMockFeeGrantKeeper(ctrl) + k, ctx := testutil_keeper.FeeKeeper(t, channelKeeper, bankKeeper, feegrantKeeper) + + _, err := k.GetPayerInfo(ctx, "", "") + require.ErrorContains(t, err, "failed to parse address") + + p, err := k.GetPayerInfo(ctx, TestAddress, "") + require.NoError(t, err) + require.Equal(t, p, types.PayerInfo{ + Sender: sdk.MustAccAddressFromBech32(TestAddress), + FeePayer: nil, + }) + + p, err = k.GetPayerInfo(ctx, TestAddress, TestFeePayerAddress) + require.NoError(t, err) + require.Equal(t, p, types.PayerInfo{ + Sender: sdk.MustAccAddressFromBech32(TestAddress), + FeePayer: sdk.MustAccAddressFromBech32(TestFeePayerAddress), + }) +} + +func TestKeeperLockFeesAtFeePayeer(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + bankKeeper := mock_types.NewMockBankKeeper(ctrl) + channelKeeper := mock_types.NewMockChannelKeeper(ctrl) + feegrantKeeper := mock_types.NewMockFeeGrantKeeper(ctrl) + feeAllowance := mock_feegrant.NewMockFeeAllowanceI(ctrl) + k, ctx := testutil_keeper.FeeKeeper(t, channelKeeper, bankKeeper, feegrantKeeper) + + payerInfo := types.PayerInfo{ + Sender: sdk.MustAccAddressFromBech32(testutil.TestOwnerAddress), + FeePayer: sdk.MustAccAddressFromBech32(TestFeePayerAddress), + } + + k.SetParams(ctx, types.Params{ + MinFee: types.Fee{ + RecvFee: nil, + AckFee: sdk.NewCoins(sdk.NewCoin("denom1", sdk.NewInt(100)), sdk.NewCoin("denom2", sdk.NewInt(100))), + TimeoutFee: sdk.NewCoins(sdk.NewCoin("denom1", sdk.NewInt(100)), sdk.NewCoin("denom2", sdk.NewInt(100))), + }, + }) + + packet := types.PacketID{ + ChannelId: "channel-0", + PortId: "transfer", + Sequence: 111, + } + ackFee := sdk.NewCoins(sdk.NewCoin("denom1", sdk.NewInt(101))) + timeoutFee := sdk.NewCoins(sdk.NewCoin("denom1", sdk.NewInt(101))) + validFee := types.Fee{ + RecvFee: nil, + AckFee: ackFee, + TimeoutFee: timeoutFee, + } + + channelKeeper.EXPECT().GetChannel(ctx, packet.PortId, packet.ChannelId).Return(channeltypes.Channel{}, true) + feegrantKeeper.EXPECT().GetAllowance(ctx, payerInfo.FeePayer, payerInfo.Sender).Return(nil, fmt.Errorf("feegrant error")) + err := k.LockFees(ctx, payerInfo, packet, validFee) + require.ErrorContains(t, err, "feegrant error") + + channelKeeper.EXPECT().GetChannel(ctx, packet.PortId, packet.ChannelId).Return(channeltypes.Channel{}, true) + fees := append(ackFee, timeoutFee...) + feeAllowance.EXPECT().Accept(ctx, fees, []sdk.Msg{}).Return(false, fmt.Errorf("fee allowance accept error")) + feegrantKeeper.EXPECT().GetAllowance(ctx, payerInfo.FeePayer, payerInfo.Sender).Return(feeAllowance, nil) + err = k.LockFees(ctx, payerInfo, packet, validFee) + require.ErrorContains(t, err, "fee allowance accept error") + + channelKeeper.EXPECT().GetChannel(ctx, packet.PortId, packet.ChannelId).Return(channeltypes.Channel{}, true) + feeAllowance.EXPECT().Accept(ctx, fees, []sdk.Msg{}).Return(false, nil) + feegrantKeeper.EXPECT().GetAllowance(ctx, payerInfo.FeePayer, payerInfo.Sender).Return(feeAllowance, nil) + bankKeeper.EXPECT().SendCoinsFromAccountToModule(ctx, payerInfo.FeePayer, types.ModuleName, validFee.Total()).Return(nil) + err = k.LockFees(ctx, payerInfo, packet, validFee) require.NoError(t, err) require.Equal(t, sdk.Events{ sdk.NewEvent( types.EventTypeLockFees, - sdk.NewAttribute(types.AttributeKeyPayer, payer.String()), + sdk.NewAttribute(types.AttributeKeyPayer, payerInfo.FeePayer.String()), sdk.NewAttribute(types.AttributeKeyPortID, packet.PortId), sdk.NewAttribute(types.AttributeKeyChannelID, packet.ChannelId), sdk.NewAttribute(types.AttributeKeySequence, strconv.FormatUint(packet.Sequence, 10)), @@ -184,7 +286,8 @@ func TestDistributeAcknowledgementFee(t *testing.T) { defer ctrl.Finish() bankKeeper := mock_types.NewMockBankKeeper(ctrl) channelKeeper := mock_types.NewMockChannelKeeper(ctrl) - k, ctx := testutil_keeper.FeeKeeper(t, channelKeeper, bankKeeper) + feegrantKeeper := mock_types.NewMockFeeGrantKeeper(ctrl) + k, ctx := testutil_keeper.FeeKeeper(t, channelKeeper, bankKeeper, feegrantKeeper) validFee := types.Fee{ RecvFee: nil, @@ -250,7 +353,8 @@ func TestDistributeTimeoutFee(t *testing.T) { defer ctrl.Finish() bankKeeper := mock_types.NewMockBankKeeper(ctrl) channelKeeper := mock_types.NewMockChannelKeeper(ctrl) - k, ctx := testutil_keeper.FeeKeeper(t, channelKeeper, bankKeeper) + feegrantKeeper := mock_types.NewMockFeeGrantKeeper(ctrl) + k, ctx := testutil_keeper.FeeKeeper(t, channelKeeper, bankKeeper, feegrantKeeper) validFee := types.Fee{ RecvFee: nil, @@ -312,7 +416,7 @@ func TestDistributeTimeoutFee(t *testing.T) { } func TestFeeInfo(t *testing.T) { - k, ctx := testutil_keeper.FeeKeeper(t, nil, nil) + k, ctx := testutil_keeper.FeeKeeper(t, nil, nil, nil) validFee := types.Fee{ RecvFee: nil, AckFee: sdk.NewCoins(sdk.NewCoin("untrn", sdk.NewInt(1001))), diff --git a/x/feerefunder/keeper/params_test.go b/x/feerefunder/keeper/params_test.go index 890d21cb8..eaa1845ea 100644 --- a/x/feerefunder/keeper/params_test.go +++ b/x/feerefunder/keeper/params_test.go @@ -11,7 +11,7 @@ import ( ) func TestGetParams(t *testing.T) { - k, ctx := testkeeper.FeeKeeper(t, nil, nil) + k, ctx := testkeeper.FeeKeeper(t, nil, nil, nil) params := types.DefaultParams() k.SetParams(ctx, params) diff --git a/x/feerefunder/types/expected_keepers.go b/x/feerefunder/types/expected_keepers.go index eebfd51d9..3f14eff3d 100644 --- a/x/feerefunder/types/expected_keepers.go +++ b/x/feerefunder/types/expected_keepers.go @@ -3,6 +3,7 @@ package types import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/x/auth/types" + "github.com/cosmos/cosmos-sdk/x/feegrant" channeltypes "github.com/cosmos/ibc-go/v3/modules/core/04-channel/types" ) @@ -25,3 +26,7 @@ type BankKeeper interface { type ChannelKeeper interface { GetChannel(ctx sdk.Context, srcPort, srcChan string) (channel channeltypes.Channel, found bool) } + +type FeeGrantKeeper interface { + GetAllowance(ctx sdk.Context, granter, grantee sdk.AccAddress) (feegrant.FeeAllowanceI, error) +} diff --git a/x/feerefunder/types/fee.pb.go b/x/feerefunder/types/fee.pb.go index 77720d234..f788fa6e8 100644 --- a/x/feerefunder/types/fee.pb.go +++ b/x/feerefunder/types/fee.pb.go @@ -33,6 +33,7 @@ type Fee struct { AckFee github_com_cosmos_cosmos_sdk_types.Coins `protobuf:"bytes,2,rep,name=ack_fee,json=ackFee,proto3,castrepeated=github.com/cosmos/cosmos-sdk/types.Coins" json:"ack_fee" yaml:"ack_fee"` // the packet timeout fee TimeoutFee github_com_cosmos_cosmos_sdk_types.Coins `protobuf:"bytes,3,rep,name=timeout_fee,json=timeoutFee,proto3,castrepeated=github.com/cosmos/cosmos-sdk/types.Coins" json:"timeout_fee" yaml:"timeout_fee"` + Payer string `protobuf:"bytes,4,opt,name=payer,proto3" json:"payer,omitempty"` } func (m *Fee) Reset() { *m = Fee{} } @@ -89,6 +90,13 @@ func (m *Fee) GetTimeoutFee() github_com_cosmos_cosmos_sdk_types.Coins { return nil } +func (m *Fee) GetPayer() string { + if m != nil { + return m.Payer + } + return "" +} + type PacketID struct { ChannelId string `protobuf:"bytes,1,opt,name=channel_id,json=channelId,proto3" json:"channel_id,omitempty"` PortId string `protobuf:"bytes,2,opt,name=port_id,json=portId,proto3" json:"port_id,omitempty"` @@ -157,31 +165,32 @@ func init() { func init() { proto.RegisterFile("feerefunder/fee.proto", fileDescriptor_0c6cd4ef4b890305) } var fileDescriptor_0c6cd4ef4b890305 = []byte{ - // 380 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x92, 0x4f, 0x6b, 0xe2, 0x40, - 0x18, 0xc6, 0x33, 0xba, 0xf8, 0x67, 0x84, 0x5d, 0xc8, 0xee, 0xb2, 0xae, 0xb0, 0x51, 0x72, 0xca, - 0xc5, 0x0c, 0xba, 0xb7, 0x3d, 0xea, 0x22, 0x08, 0x3d, 0x14, 0x8f, 0x3d, 0x54, 0x92, 0xc9, 0x9b, - 0x18, 0x62, 0x66, 0x6c, 0x32, 0x91, 0x7a, 0xed, 0x27, 0xe8, 0xe7, 0xe8, 0x27, 0xf1, 0xe8, 0xb1, - 0x27, 0x5b, 0xf4, 0x1b, 0xf4, 0x5e, 0x28, 0x93, 0x8c, 0x25, 0x37, 0xf1, 0x34, 0xef, 0x1f, 0x9e, - 0xe7, 0xf7, 0x0c, 0xbc, 0xf8, 0xa7, 0x0f, 0x90, 0x80, 0x9f, 0x31, 0x0f, 0x12, 0xe2, 0x03, 0xd8, - 0xab, 0x84, 0x0b, 0xae, 0x7f, 0x67, 0x90, 0x89, 0x84, 0x33, 0xbb, 0xb4, 0xee, 0x18, 0x94, 0xa7, - 0x31, 0x4f, 0x89, 0xeb, 0xa4, 0x40, 0xd6, 0x03, 0x17, 0x84, 0x33, 0x20, 0x94, 0x87, 0xac, 0x10, - 0x75, 0x7e, 0x04, 0x3c, 0xe0, 0x79, 0x49, 0x64, 0x55, 0x4c, 0xcd, 0xf7, 0x0a, 0xae, 0x4e, 0x00, - 0xf4, 0x0d, 0x6e, 0x24, 0x40, 0xd7, 0x73, 0x1f, 0xa0, 0x8d, 0x7a, 0x55, 0xab, 0x35, 0xfc, 0x6d, - 0x17, 0x86, 0xb6, 0x34, 0xb4, 0x95, 0xa1, 0x3d, 0xe6, 0x21, 0x1b, 0x8d, 0xb7, 0xfb, 0xae, 0xf6, - 0xb6, 0xef, 0x7e, 0xdb, 0x38, 0xf1, 0xf2, 0x9f, 0x79, 0x12, 0x9a, 0x4f, 0x2f, 0x5d, 0x2b, 0x08, - 0xc5, 0x22, 0x73, 0x6d, 0xca, 0x63, 0xa2, 0x02, 0x15, 0x4f, 0x3f, 0xf5, 0x22, 0x22, 0x36, 0x2b, - 0x48, 0x73, 0x8f, 0x74, 0x56, 0x97, 0x32, 0x89, 0x5e, 0xe3, 0xba, 0x43, 0xa3, 0x9c, 0x5c, 0x39, - 0x47, 0x1e, 0x29, 0xf2, 0xd7, 0x82, 0xac, 0x74, 0x97, 0x81, 0x6b, 0x0e, 0x8d, 0x24, 0xf7, 0x01, - 0xe1, 0x96, 0x08, 0x63, 0xe0, 0x99, 0xc8, 0xe1, 0xd5, 0x73, 0xf0, 0x89, 0x82, 0xeb, 0x05, 0xbc, - 0xa4, 0xbd, 0x2c, 0x00, 0x56, 0xca, 0x09, 0x80, 0x79, 0x8b, 0x1b, 0xd7, 0x0e, 0x8d, 0x40, 0x4c, - 0xff, 0xeb, 0x7f, 0x30, 0xa6, 0x0b, 0x87, 0x31, 0x58, 0xce, 0x43, 0xaf, 0x8d, 0x7a, 0xc8, 0x6a, - 0xce, 0x9a, 0x6a, 0x32, 0xf5, 0xf4, 0x5f, 0xb8, 0xbe, 0xe2, 0x89, 0x90, 0xbb, 0x4a, 0xbe, 0xab, - 0xc9, 0x76, 0xea, 0xe9, 0x1d, 0xdc, 0x48, 0xe1, 0x2e, 0x03, 0x46, 0xe5, 0x27, 0x90, 0xf5, 0x65, - 0xf6, 0xd9, 0x8f, 0xae, 0xb6, 0x07, 0x03, 0xed, 0x0e, 0x06, 0x7a, 0x3d, 0x18, 0xe8, 0xf1, 0x68, - 0x68, 0xbb, 0xa3, 0xa1, 0x3d, 0x1f, 0x0d, 0xed, 0x66, 0x58, 0xca, 0xab, 0xee, 0xa9, 0xcf, 0x93, - 0xe0, 0x54, 0x93, 0x7b, 0x52, 0x3e, 0xbe, 0x3c, 0xbf, 0x5b, 0xcb, 0x8f, 0xe6, 0xef, 0x47, 0x00, - 0x00, 0x00, 0xff, 0xff, 0x63, 0x01, 0x6f, 0xe2, 0x98, 0x02, 0x00, 0x00, + // 394 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x92, 0x3f, 0x6e, 0xdb, 0x30, + 0x14, 0xc6, 0xc5, 0x28, 0xb5, 0x1d, 0x06, 0x68, 0x01, 0x35, 0x45, 0x5d, 0x03, 0x95, 0x03, 0x4d, + 0x5a, 0x22, 0x22, 0xe9, 0xd6, 0xd1, 0x29, 0x02, 0x18, 0xe8, 0x50, 0x78, 0xec, 0xd0, 0x80, 0xa2, + 0x9e, 0x14, 0x41, 0x11, 0xa9, 0x52, 0x94, 0x51, 0xad, 0x3d, 0x41, 0x81, 0xde, 0xa2, 0x27, 0xf1, + 0xe8, 0xb1, 0x93, 0x5b, 0xd8, 0x37, 0xe8, 0x09, 0x0a, 0x52, 0x74, 0xa0, 0xcd, 0xf0, 0xc4, 0xf7, + 0x87, 0xdf, 0xf7, 0x7b, 0x0f, 0x78, 0xf8, 0x55, 0x0a, 0x20, 0x21, 0x6d, 0x78, 0x02, 0x92, 0xa4, + 0x00, 0x51, 0x25, 0x85, 0x12, 0xde, 0x4b, 0x0e, 0x8d, 0x92, 0x82, 0x47, 0xbd, 0xf6, 0xc4, 0x67, + 0xa2, 0x2e, 0x45, 0x4d, 0x62, 0x5a, 0x03, 0x59, 0x5e, 0xc7, 0xa0, 0xe8, 0x35, 0x61, 0x22, 0xe7, + 0x9d, 0x68, 0x72, 0x91, 0x89, 0x4c, 0x98, 0x90, 0xe8, 0xa8, 0xab, 0x06, 0x3f, 0x5d, 0xec, 0xde, + 0x01, 0x78, 0x2d, 0x1e, 0x49, 0x60, 0xcb, 0xfb, 0x14, 0x60, 0x8c, 0x2e, 0xdd, 0xf0, 0xfc, 0xe6, + 0x4d, 0xd4, 0x19, 0x46, 0xda, 0x30, 0xb2, 0x86, 0xd1, 0xad, 0xc8, 0xf9, 0xec, 0x76, 0xb5, 0x99, + 0x3a, 0xff, 0x36, 0xd3, 0x17, 0x2d, 0x2d, 0x1f, 0xdf, 0x07, 0x7b, 0x61, 0xf0, 0xeb, 0xcf, 0x34, + 0xcc, 0x72, 0xf5, 0xd0, 0xc4, 0x11, 0x13, 0x25, 0xb1, 0x03, 0x75, 0xcf, 0x55, 0x9d, 0x14, 0x44, + 0xb5, 0x15, 0xd4, 0xc6, 0xa3, 0x5e, 0x0c, 0xb5, 0x4c, 0xa3, 0x97, 0x78, 0x48, 0x59, 0x61, 0xc8, + 0x27, 0x87, 0xc8, 0x33, 0x4b, 0x7e, 0xde, 0x91, 0xad, 0xee, 0x38, 0xf0, 0x80, 0xb2, 0x42, 0x73, + 0xbf, 0x23, 0x7c, 0xae, 0xf2, 0x12, 0x44, 0xa3, 0x0c, 0xdc, 0x3d, 0x04, 0xbf, 0xb3, 0x70, 0xaf, + 0x83, 0xf7, 0xb4, 0xc7, 0x0d, 0x80, 0xad, 0x52, 0x0f, 0x71, 0x81, 0x9f, 0x55, 0xb4, 0x05, 0x39, + 0x3e, 0xbd, 0x44, 0xe1, 0xd9, 0xa2, 0x4b, 0x82, 0x2f, 0x78, 0xf4, 0x89, 0xb2, 0x02, 0xd4, 0xfc, + 0x83, 0xf7, 0x16, 0x63, 0xf6, 0x40, 0x39, 0x87, 0xc7, 0xfb, 0x3c, 0x19, 0x23, 0xf3, 0xed, 0xcc, + 0x56, 0xe6, 0x89, 0xf7, 0x1a, 0x0f, 0x2b, 0x21, 0x95, 0xee, 0x9d, 0x98, 0xde, 0x40, 0xa7, 0xf3, + 0xc4, 0x9b, 0xe0, 0x51, 0x0d, 0x5f, 0x1b, 0xe0, 0x4c, 0xaf, 0x86, 0xc2, 0xd3, 0xc5, 0x53, 0x3e, + 0xfb, 0xb8, 0xda, 0xfa, 0x68, 0xbd, 0xf5, 0xd1, 0xdf, 0xad, 0x8f, 0x7e, 0xec, 0x7c, 0x67, 0xbd, + 0xf3, 0x9d, 0xdf, 0x3b, 0xdf, 0xf9, 0x7c, 0xd3, 0xdb, 0xc2, 0x5e, 0xd9, 0x95, 0x90, 0xd9, 0x3e, + 0x26, 0xdf, 0x48, 0xff, 0x24, 0xcd, 0x56, 0xf1, 0xc0, 0x9c, 0xd2, 0xbb, 0xff, 0x01, 0x00, 0x00, + 0xff, 0xff, 0x37, 0x55, 0xdb, 0x6f, 0xae, 0x02, 0x00, 0x00, } func (m *Fee) Marshal() (dAtA []byte, err error) { @@ -204,6 +213,13 @@ func (m *Fee) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l + if len(m.Payer) > 0 { + i -= len(m.Payer) + copy(dAtA[i:], m.Payer) + i = encodeVarintFee(dAtA, i, uint64(len(m.Payer))) + i-- + dAtA[i] = 0x22 + } if len(m.TimeoutFee) > 0 { for iNdEx := len(m.TimeoutFee) - 1; iNdEx >= 0; iNdEx-- { { @@ -326,6 +342,10 @@ func (m *Fee) Size() (n int) { n += 1 + l + sovFee(uint64(l)) } } + l = len(m.Payer) + if l > 0 { + n += 1 + l + sovFee(uint64(l)) + } return n } @@ -486,6 +506,38 @@ func (m *Fee) Unmarshal(dAtA []byte) error { return err } iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Payer", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowFee + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthFee + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthFee + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Payer = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipFee(dAtA[iNdEx:]) diff --git a/x/feerefunder/types/payerinfo.go b/x/feerefunder/types/payerinfo.go new file mode 100644 index 000000000..f394c1bd1 --- /dev/null +++ b/x/feerefunder/types/payerinfo.go @@ -0,0 +1,8 @@ +package types + +import sdk "github.com/cosmos/cosmos-sdk/types" + +type PayerInfo struct { + Sender sdk.AccAddress + FeePayer sdk.AccAddress +} diff --git a/x/interchaintxs/keeper/msg_server.go b/x/interchaintxs/keeper/msg_server.go index 67d04d6f1..0aa67e467 100644 --- a/x/interchaintxs/keeper/msg_server.go +++ b/x/interchaintxs/keeper/msg_server.go @@ -72,13 +72,12 @@ func (k Keeper) SubmitTx(goCtx context.Context, msg *ictxtypes.MsgSubmitTx) (*ic ctx := sdk.UnwrapSDKContext(goCtx) k.Logger(ctx).Debug("SubmitTx", "connection_id", msg.ConnectionId, "from_address", msg.FromAddress, "interchain_account_id", msg.InterchainAccountId) - senderAddr, err := sdk.AccAddressFromBech32(msg.FromAddress) + payerInfo, err := k.feeKeeper.GetPayerInfo(ctx, msg.FromAddress, msg.Fee.Payer) if err != nil { - k.Logger(ctx).Debug("SubmitTx: failed to parse sender address", "from_address", msg.FromAddress) - return nil, sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "failed to parse address: %s", msg.FromAddress) + return nil, sdkerrors.Wrapf(err, "failed to get payer info for sender: %s, payer: %s", msg.FromAddress, msg.Fee.Payer) } - if !k.contractManagerKeeper.HasContractInfo(ctx, senderAddr) { + if !k.contractManagerKeeper.HasContractInfo(ctx, payerInfo.Sender) { k.Logger(ctx).Debug("SubmitTx: contract not found", "from_address", msg.FromAddress) return nil, sdkerrors.Wrapf(ictxtypes.ErrNotContract, "%s is not a contract address", msg.FromAddress) } @@ -97,7 +96,7 @@ func (k Keeper) SubmitTx(goCtx context.Context, msg *ictxtypes.MsgSubmitTx) (*ic ) } - icaOwner := ictxtypes.NewICAOwnerFromAddress(senderAddr, msg.InterchainAccountId) + icaOwner := ictxtypes.NewICAOwnerFromAddress(payerInfo.Sender, msg.InterchainAccountId) portID, err := icatypes.NewControllerPortID(icaOwner.String()) if err != nil { @@ -137,7 +136,7 @@ func (k Keeper) SubmitTx(goCtx context.Context, msg *ictxtypes.MsgSubmitTx) (*ic ) } - if err := k.feeKeeper.LockFees(ctx, senderAddr, feetypes.NewPacketID(portID, channelID, sequence), msg.Fee); err != nil { + if err := k.feeKeeper.LockFees(ctx, payerInfo, feetypes.NewPacketID(portID, channelID, sequence), msg.Fee); err != nil { return nil, sdkerrors.Wrapf(err, "failed to lock fees to pay for SubmitTx msg: %s", msg) } diff --git a/x/interchaintxs/keeper/msg_server_test.go b/x/interchaintxs/keeper/msg_server_test.go index a53016ecc..c5e051c55 100644 --- a/x/interchaintxs/keeper/msg_server_test.go +++ b/x/interchaintxs/keeper/msg_server_test.go @@ -2,20 +2,17 @@ package keeper_test import ( "fmt" - "testing" - "time" - "github.com/cosmos/cosmos-sdk/codec" + sdk "github.com/cosmos/cosmos-sdk/types" types2 "github.com/cosmos/cosmos-sdk/x/capability/types" icatypes "github.com/cosmos/ibc-go/v3/modules/apps/27-interchain-accounts/types" host "github.com/cosmos/ibc-go/v3/modules/core/24-host" - + "github.com/golang/mock/gomock" feerefundertypes "github.com/neutron-org/neutron/x/feerefunder/types" "github.com/neutron-org/neutron/x/interchaintxs/keeper" - - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + "testing" + "time" codectypes "github.com/cosmos/cosmos-sdk/codec/types" @@ -98,10 +95,14 @@ func TestSubmitTx(t *testing.T) { require.Nil(t, resp) require.ErrorContains(t, err, "empty Msgs field is prohibited") + refundKeeper.EXPECT().GetPayerInfo(ctx, "", "").Return(feerefundertypes.PayerInfo{}, fmt.Errorf("failed to parse address")) resp, err = icak.SubmitTx(goCtx, &types.MsgSubmitTx{Msgs: []*codectypes.Any{&cosmosMsg}}) require.Nil(t, resp) require.ErrorContains(t, err, "failed to parse address") + refundKeeper.EXPECT().GetPayerInfo(ctx, submitMsg.FromAddress, submitMsg.Fee.Payer).Return( + feerefundertypes.PayerInfo{Sender: contractAddress, FeePayer: sdk.AccAddress{}}, nil, + ) cmKeeper.EXPECT().HasContractInfo(ctx, contractAddress).Return(false) resp, err = icak.SubmitTx(goCtx, &submitMsg) require.Nil(t, resp) @@ -110,13 +111,19 @@ func TestSubmitTx(t *testing.T) { params := icak.GetParams(ctx) maxMsgs := params.GetMsgSubmitTxMaxMessages() submitMsg.Msgs = make([]*codectypes.Any, maxMsgs+1) + refundKeeper.EXPECT().GetPayerInfo(ctx, submitMsg.FromAddress, submitMsg.Fee.Payer).Return( + feerefundertypes.PayerInfo{Sender: contractAddress, FeePayer: sdk.AccAddress{}}, nil, + ) cmKeeper.EXPECT().HasContractInfo(ctx, contractAddress).Return(true) resp, err = icak.SubmitTx(goCtx, &submitMsg) require.Nil(t, resp) require.ErrorContains(t, err, "MsgSubmitTx contains more messages than allowed") - submitMsg.Msgs = []*codectypes.Any{&cosmosMsg} + submitMsg.Msgs = []*codectypes.Any{&cosmosMsg} portID := "icacontroller-" + testutil.TestOwnerAddress + ".ica0" + refundKeeper.EXPECT().GetPayerInfo(ctx, submitMsg.FromAddress, submitMsg.Fee.Payer).Return( + feerefundertypes.PayerInfo{Sender: contractAddress, FeePayer: sdk.AccAddress{}}, nil, + ) cmKeeper.EXPECT().HasContractInfo(ctx, contractAddress).Return(true) icaKeeper.EXPECT().GetActiveChannelID(ctx, "connection-0", portID).Return("", false) resp, err = icak.SubmitTx(goCtx, &submitMsg) @@ -125,6 +132,9 @@ func TestSubmitTx(t *testing.T) { activeChannel := "channel-0" cmKeeper.EXPECT().HasContractInfo(ctx, contractAddress).Return(true) + refundKeeper.EXPECT().GetPayerInfo(ctx, submitMsg.FromAddress, submitMsg.Fee.Payer).Return( + feerefundertypes.PayerInfo{Sender: contractAddress, FeePayer: sdk.AccAddress{}}, nil, + ) icaKeeper.EXPECT().GetActiveChannelID(ctx, "connection-0", portID).Return(activeChannel, true) capabilityKeeper.EXPECT().GetCapability(ctx, host.ChannelCapabilityPath(portID, activeChannel)).Return(nil, false) resp, err = icak.SubmitTx(goCtx, &submitMsg) @@ -135,6 +145,9 @@ func TestSubmitTx(t *testing.T) { cmKeeper.EXPECT().HasContractInfo(ctx, contractAddress).Return(true) icaKeeper.EXPECT().GetActiveChannelID(ctx, "connection-0", portID).Return(activeChannel, true) capabilityKeeper.EXPECT().GetCapability(ctx, host.ChannelCapabilityPath(portID, activeChannel)).Return(&capability, true) + refundKeeper.EXPECT().GetPayerInfo(ctx, submitMsg.FromAddress, submitMsg.Fee.Payer).Return( + feerefundertypes.PayerInfo{Sender: contractAddress, FeePayer: sdk.AccAddress{}}, nil, + ) currCodec := icak.Codec icak.Codec = &codec.AminoCodec{} resp, err = icak.SubmitTx(goCtx, &submitMsg) @@ -145,6 +158,9 @@ func TestSubmitTx(t *testing.T) { cmKeeper.EXPECT().HasContractInfo(ctx, contractAddress).Return(true) icaKeeper.EXPECT().GetActiveChannelID(ctx, "connection-0", portID).Return(activeChannel, true) capabilityKeeper.EXPECT().GetCapability(ctx, host.ChannelCapabilityPath(portID, activeChannel)).Return(&capability, true) + refundKeeper.EXPECT().GetPayerInfo(ctx, submitMsg.FromAddress, submitMsg.Fee.Payer).Return( + feerefundertypes.PayerInfo{Sender: contractAddress, FeePayer: sdk.AccAddress{}}, nil, + ) channelKeeper.EXPECT().GetNextSequenceSend(ctx, portID, activeChannel).Return(uint64(0), false) resp, err = icak.SubmitTx(goCtx, &submitMsg) require.Nil(t, resp) @@ -154,8 +170,10 @@ func TestSubmitTx(t *testing.T) { cmKeeper.EXPECT().HasContractInfo(ctx, contractAddress).Return(true) icaKeeper.EXPECT().GetActiveChannelID(ctx, "connection-0", portID).Return(activeChannel, true) capabilityKeeper.EXPECT().GetCapability(ctx, host.ChannelCapabilityPath(portID, activeChannel)).Return(&capability, true) + payer := feerefundertypes.PayerInfo{Sender: contractAddress, FeePayer: sdk.AccAddress{}} + refundKeeper.EXPECT().GetPayerInfo(ctx, submitMsg.FromAddress, submitMsg.Fee.Payer).Return(payer, nil) channelKeeper.EXPECT().GetNextSequenceSend(ctx, portID, activeChannel).Return(sequence, true) - refundKeeper.EXPECT().LockFees(ctx, contractAddress, feerefundertypes.NewPacketID(portID, activeChannel, sequence), submitMsg.Fee).Return(fmt.Errorf("failed to lock fees")) + refundKeeper.EXPECT().LockFees(ctx, payer, feerefundertypes.NewPacketID(portID, activeChannel, sequence), submitMsg.Fee).Return(fmt.Errorf("failed to lock fees")) resp, err = icak.SubmitTx(goCtx, &submitMsg) require.Nil(t, resp) require.ErrorContains(t, err, "failed to lock fees to pay for SubmitTx msg") @@ -173,7 +191,9 @@ func TestSubmitTx(t *testing.T) { icaKeeper.EXPECT().GetActiveChannelID(ctx, "connection-0", portID).Return(activeChannel, true) capabilityKeeper.EXPECT().GetCapability(ctx, host.ChannelCapabilityPath(portID, activeChannel)).Return(&capability, true) channelKeeper.EXPECT().GetNextSequenceSend(ctx, portID, activeChannel).Return(sequence, true) - refundKeeper.EXPECT().LockFees(ctx, contractAddress, feerefundertypes.NewPacketID(portID, activeChannel, sequence), submitMsg.Fee).Return(nil) + payerInfo := feerefundertypes.PayerInfo{Sender: contractAddress, FeePayer: sdk.AccAddress{}} + refundKeeper.EXPECT().GetPayerInfo(ctx, submitMsg.FromAddress, submitMsg.Fee.Payer).Return(payerInfo, nil) + refundKeeper.EXPECT().LockFees(ctx, payerInfo, feerefundertypes.NewPacketID(portID, activeChannel, sequence), submitMsg.Fee).Return(nil) icaKeeper.EXPECT().SendTx(ctx, &capability, "connection-0", portID, packetData, uint64(timeoutTimestamp)).Return(uint64(0), fmt.Errorf("faile to send tx")) resp, err = icak.SubmitTx(goCtx, &submitMsg) require.Nil(t, resp) @@ -183,7 +203,8 @@ func TestSubmitTx(t *testing.T) { icaKeeper.EXPECT().GetActiveChannelID(ctx, "connection-0", portID).Return(activeChannel, true) capabilityKeeper.EXPECT().GetCapability(ctx, host.ChannelCapabilityPath(portID, activeChannel)).Return(&capability, true) channelKeeper.EXPECT().GetNextSequenceSend(ctx, portID, activeChannel).Return(sequence, true) - refundKeeper.EXPECT().LockFees(ctx, contractAddress, feerefundertypes.NewPacketID(portID, activeChannel, sequence), submitMsg.Fee).Return(nil) + refundKeeper.EXPECT().GetPayerInfo(ctx, submitMsg.FromAddress, submitMsg.Fee.Payer).Return(payerInfo, nil) + refundKeeper.EXPECT().LockFees(ctx, payerInfo, feerefundertypes.NewPacketID(portID, activeChannel, sequence), submitMsg.Fee).Return(nil) icaKeeper.EXPECT().SendTx(ctx, &capability, "connection-0", portID, packetData, uint64(timeoutTimestamp)).Return(uint64(0), nil) resp, err = icak.SubmitTx(goCtx, &submitMsg) require.Equal(t, types.MsgSubmitTxResponse{ diff --git a/x/interchaintxs/types/expected_keepers.go b/x/interchaintxs/types/expected_keepers.go index 75dddf149..f80905c4c 100644 --- a/x/interchaintxs/types/expected_keepers.go +++ b/x/interchaintxs/types/expected_keepers.go @@ -41,9 +41,10 @@ type ICAControllerKeeper interface { } type FeeRefunderKeeper interface { - LockFees(ctx sdk.Context, payer sdk.AccAddress, packetID feerefundertypes.PacketID, fee feerefundertypes.Fee) error + LockFees(ctx sdk.Context, payer feerefundertypes.PayerInfo, packetID feerefundertypes.PacketID, fee feerefundertypes.Fee) error DistributeAcknowledgementFee(ctx sdk.Context, receiver sdk.AccAddress, packetID feerefundertypes.PacketID) DistributeTimeoutFee(ctx sdk.Context, receiver sdk.AccAddress, packetID feerefundertypes.PacketID) + GetPayerInfo(ctx sdk.Context, sender string, payer string) (feerefundertypes.PayerInfo, error) } type ScopedKeeper interface { diff --git a/x/interchaintxs/types/tx.pb.go b/x/interchaintxs/types/tx.pb.go index 7738be28a..dbc65e268 100644 --- a/x/interchaintxs/types/tx.pb.go +++ b/x/interchaintxs/types/tx.pb.go @@ -6,12 +6,12 @@ package types import ( context "context" fmt "fmt" + _ "github.com/cosmos/cosmos-proto" types "github.com/cosmos/cosmos-sdk/codec/types" _ "github.com/gogo/protobuf/gogoproto" grpc1 "github.com/gogo/protobuf/grpc" proto "github.com/gogo/protobuf/proto" types1 "github.com/neutron-org/neutron/x/feerefunder/types" - _ "github.com/regen-network/cosmos-proto" _ "google.golang.org/genproto/googleapis/api/annotations" grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" @@ -161,7 +161,7 @@ var xxx_messageInfo_MsgSubmitTx proto.InternalMessageInfo // MsgSubmitTxResponse defines the response for Msg/SubmitTx type MsgSubmitTxResponse struct { - // channel's sequence_id for outgoing ibc packet. Unique per a channel. + // channel's sequence_id for outgoing ibc packet. Unique per a channel.y SequenceId uint64 `protobuf:"varint,1,opt,name=sequence_id,json=sequenceId,proto3" json:"sequence_id,omitempty"` // channel src channel on neutron side trasaction was submitted from Channel string `protobuf:"bytes,2,opt,name=channel,proto3" json:"channel,omitempty"` diff --git a/x/transfer/keeper/keeper.go b/x/transfer/keeper/keeper.go index 797fefe36..1ddcc17b6 100644 --- a/x/transfer/keeper/keeper.go +++ b/x/transfer/keeper/keeper.go @@ -27,10 +27,9 @@ type KeeperTransferWrapper struct { func (k KeeperTransferWrapper) Transfer(goCtx context.Context, msg *wrappedtypes.MsgTransfer) (*wrappedtypes.MsgTransferResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - senderAddr, err := sdk.AccAddressFromBech32(msg.Sender) + payerInfo, err := k.FeeKeeper.GetPayerInfo(ctx, msg.Sender, msg.Fee.Payer) if err != nil { - k.Logger(ctx).Debug("Transfer: failed to parse sender address", "sender", msg.Sender) - return nil, sdkerrors.Wrapf(sdkerrors.ErrInvalidAddress, "failed to parse address: %s", msg.Sender) + return nil, sdkerrors.Wrapf(err, "failed to get payer info for sender: %s, payer: %s", msg.Sender, msg.Fee.Payer) } sequence, found := k.channelKeeper.GetNextSequenceSend(ctx, msg.SourcePort, msg.SourceChannel) @@ -43,8 +42,8 @@ func (k KeeperTransferWrapper) Transfer(goCtx context.Context, msg *wrappedtypes // if the sender is a contract, lock fees. // Because contracts are required to pay fees for the acknowledgements - if k.ContractManagerKeeper.HasContractInfo(ctx, senderAddr) { - if err := k.FeeKeeper.LockFees(ctx, senderAddr, feetypes.NewPacketID(msg.SourcePort, msg.SourceChannel, sequence), msg.Fee); err != nil { + if k.ContractManagerKeeper.HasContractInfo(ctx, payerInfo.Sender) { + if err := k.FeeKeeper.LockFees(ctx, payerInfo, feetypes.NewPacketID(msg.SourcePort, msg.SourceChannel, sequence), msg.Fee); err != nil { return nil, sdkerrors.Wrapf(err, "failed to lock fees to pay for transfer msg: %v", msg) } } diff --git a/x/transfer/types/expected_keepers.go b/x/transfer/types/expected_keepers.go index 3b6b60914..88e5d2a49 100644 --- a/x/transfer/types/expected_keepers.go +++ b/x/transfer/types/expected_keepers.go @@ -18,9 +18,10 @@ type ContractManagerKeeper interface { } type FeeRefunderKeeper interface { - LockFees(ctx sdk.Context, payer sdk.AccAddress, packetID feerefundertypes.PacketID, fee feerefundertypes.Fee) error + LockFees(ctx sdk.Context, payer feerefundertypes.PayerInfo, packetID feerefundertypes.PacketID, fee feerefundertypes.Fee) error DistributeAcknowledgementFee(ctx sdk.Context, receiver sdk.AccAddress, packetID feerefundertypes.PacketID) DistributeTimeoutFee(ctx sdk.Context, receiver sdk.AccAddress, packetID feerefundertypes.PacketID) + GetPayerInfo(ctx sdk.Context, sender string, payer string) (feerefundertypes.PayerInfo, error) } // ChannelKeeper defines the expected IBC channel keeper