diff --git a/core/interfaces.go b/core/interfaces.go index c78455c9..741cdfe0 100644 --- a/core/interfaces.go +++ b/core/interfaces.go @@ -5,7 +5,7 @@ import "github.com/ethereum/go-ethereum/rpc" // Func is the interface that wraps the methods for ABI encoding and decoding. type Func interface { - // EncodeArgs ABI-encodes the given args and prepends the Func's four-byte + // EncodeArgs ABI-encodes the given args and prepends the Func's 4-byte // selector. EncodeArgs(args ...interface{}) (input []byte, err error) diff --git a/func.go b/func.go index e36bf695..b95463f1 100644 --- a/func.go +++ b/func.go @@ -6,10 +6,9 @@ import ( "fmt" "reflect" - _abi "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/crypto" - "github.com/lmittmann/w3/core" - "github.com/lmittmann/w3/internal/abi" + _abi "github.com/lmittmann/w3/internal/abi" ) var ( @@ -19,25 +18,28 @@ var ( ErrInvalidType = errors.New("w3: invalid type") ErrEvmRevert = errors.New("w3: evm reverted") - revertSelector = crypto.Keccak256([]byte("Error(string)"))[:4] - approveSelector = crypto.Keccak256([]byte("approve(address,uint256)"))[:4] - transferSelector = crypto.Keccak256([]byte("transfer(address,uint256)"))[:4] - transferFromSelector = crypto.Keccak256([]byte("transferFrom(address,address,uint256)"))[:4] + revertSelector = selector("Error(string)") + approveSelector = selector("approve(address,uint256)") + transferSelector = selector("transfer(address,uint256)") + transferFromSelector = selector("transferFrom(address,address,uint256)") ) -type abiFunc struct { - Signature string - Selector []byte // four-byte selector - Args _abi.Arguments // input - Returns _abi.Arguments // output +// Func represents a Smart Contract function ABI binding. +// +// Func implements the core.Func interface. +type Func struct { + Signature string // Function signature + Selector [4]byte // 4-byte selector + Args abi.Arguments // Input arguments + Returns abi.Arguments // Output returns } // NewFunc returns a new Smart Contract function ABI binding from the given // Solidity function signature and its returns. // // An error is returned if the signature or returns parsing fails. -func NewFunc(signature, returns string) (core.Func, error) { - args, err := abi.Parse(signature) +func NewFunc(signature, returns string) (*Func, error) { + args, err := _abi.Parse(signature) if err != nil { return nil, fmt.Errorf("%w: %v", ErrInvalidABI, err) } @@ -45,7 +47,7 @@ func NewFunc(signature, returns string) (core.Func, error) { return nil, fmt.Errorf("%w: missing function name", ErrInvalidABI) } - returnArgs, err := abi.Parse(returns) + returnArgs, err := _abi.Parse(returns) if err != nil { return nil, fmt.Errorf("%w: %v", ErrInvalidABI, err) } @@ -53,9 +55,9 @@ func NewFunc(signature, returns string) (core.Func, error) { return nil, fmt.Errorf("%w: returns must not have a function name", ErrInvalidABI) } - return &abiFunc{ + return &Func{ Signature: args.Sig, - Selector: crypto.Keccak256([]byte(args.Sig))[:4], + Selector: selector(args.Sig), Args: args.Args, Returns: returnArgs.Args, }, nil @@ -63,7 +65,7 @@ func NewFunc(signature, returns string) (core.Func, error) { // MustNewFunc is like NewFunc but panics if the signature or returns parsing // fails. -func MustNewFunc(signature, returns string) core.Func { +func MustNewFunc(signature, returns string) *Func { fn, err := NewFunc(signature, returns) if err != nil { panic(err) @@ -71,9 +73,9 @@ func MustNewFunc(signature, returns string) core.Func { return fn } -// EncodeArgs ABI-encodes the given args and prepends the Func's four-byte +// EncodeArgs ABI-encodes the given args and prepends the Func's 4-byte // selector. -func (f *abiFunc) EncodeArgs(args ...interface{}) ([]byte, error) { +func (f *Func) EncodeArgs(args ...interface{}) ([]byte, error) { if len(f.Args) != len(args) { return nil, fmt.Errorf("%w: expected %d arguments, got %d", ErrArgumentMismatch, len(f.Args), len(args)) } @@ -83,11 +85,11 @@ func (f *abiFunc) EncodeArgs(args ...interface{}) ([]byte, error) { return nil, err } - return append(f.Selector, input...), nil + return append(f.Selector[:], input...), nil } // DecodeArgs ABI-decodes the given input to the given args. -func (f *abiFunc) DecodeArgs(input []byte, args ...interface{}) error { +func (f *Func) DecodeArgs(input []byte, args ...interface{}) error { if len(f.Args) != len(args) { return fmt.Errorf("%w: expected %d arguments, got %d", ErrArgumentMismatch, len(f.Args), len(args)) } @@ -106,14 +108,14 @@ func (f *abiFunc) DecodeArgs(input []byte, args ...interface{}) error { } // DecodeReturns ABI-decodes the given output to the given returns. -func (f *abiFunc) DecodeReturns(output []byte, returns ...interface{}) error { +func (f *Func) DecodeReturns(output []byte, returns ...interface{}) error { if len(f.Returns) != len(returns) { return fmt.Errorf("%w: expected %d returns, got %d", ErrReturnsMismatch, len(f.Returns), len(returns)) } // check the output for a revert reason - if bytes.HasPrefix(output, revertSelector) { - if reason, err := _abi.UnpackRevert(output); err != nil { + if bytes.HasPrefix(output, revertSelector[:]) { + if reason, err := abi.UnpackRevert(output); err != nil { return err } else { return fmt.Errorf("%w: %s", ErrEvmRevert, reason) @@ -121,13 +123,12 @@ func (f *abiFunc) DecodeReturns(output []byte, returns ...interface{}) error { } // Gracefully handle uncompliant ERC20 returns - if len(output) == 0 && - (bytes.Equal(f.Selector, approveSelector) || - bytes.Equal(f.Selector, transferSelector) || - bytes.Equal(f.Selector, transferFromSelector)) && - len(returns) == 1 { + if len(returns) == 1 && len(output) == 0 && + (f.Selector == approveSelector || + f.Selector == transferSelector || + f.Selector == transferFromSelector) { - if err := copyVal(_abi.BoolTy, returns[0], true); err != nil { + if err := copyVal(abi.BoolTy, returns[0], true); err != nil { return err } return nil @@ -156,7 +157,7 @@ func copyVal(t byte, dst, src interface{}) (err error) { rSrc := reflect.ValueOf(src) switch t { - case _abi.TupleTy: + case abi.TupleTy: err = copyTuple(rDst, rSrc) default: err = copyNonTuple(rDst, rSrc) @@ -199,3 +200,9 @@ func copyTuple(rDst, rSrc reflect.Value) error { } return nil } + +// selector returns the 4-byte selector of the given signature. +func selector(signature string) (selector [4]byte) { + copy(selector[:], crypto.Keccak256([]byte(signature))) + return +} diff --git a/func_test.go b/func_test.go index bd93b335..ab2e4089 100644 --- a/func_test.go +++ b/func_test.go @@ -6,13 +6,53 @@ import ( "strconv" "testing" - _abi "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/lmittmann/w3/core" ) +func TestNewFunc(t *testing.T) { + t.Parallel() + + tests := []struct { + Signature string + Returns string + WantFunc *Func + }{ + { + Signature: "transfer(address,uint256)", + Returns: "bool", + WantFunc: &Func{ + Signature: "transfer(address,uint256)", + Selector: [4]byte{0xa9, 0x05, 0x9c, 0xbb}, + }, + }, + { + Signature: "transfer(address recipient, uint256 amount)", + Returns: "bool success", + WantFunc: &Func{ + Signature: "transfer(address,uint256)", + Selector: [4]byte{0xa9, 0x05, 0x9c, 0xbb}, + }, + }, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + gotFunc, err := NewFunc(test.Signature, test.Returns) + if err != nil { + t.Fatalf("Failed to create new FUnc: %v", err) + } + + if diff := cmp.Diff(test.WantFunc, gotFunc, cmpopts.IgnoreFields(Func{}, "Args", "Returns")); diff != "" { + t.Fatalf("(-want, +got)\n%s", diff) + } + }) + } +} + func TestEncodeArgs(t *testing.T) { t.Parallel() @@ -242,23 +282,23 @@ func TestCopyValue(t *testing.T) { WantErr error }{ { - T: _abi.UintTy, + T: abi.UintTy, Dst: new(big.Int), Src: big.NewInt(42), }, { - T: _abi.UintTy, + T: abi.UintTy, Dst: new(big.Int), Src: big.NewInt(42), }, { - T: _abi.UintTy, + T: abi.UintTy, Dst: new(big.Int), Src: []byte{1, 2, 3}, WantErr: ErrInvalidType, }, { - T: _abi.BytesTy, + T: abi.BytesTy, Dst: &[]byte{}, Src: &[]byte{1, 2, 3}, },