From 255a301a715119a9ca94735a4bc8f7c1becd9645 Mon Sep 17 00:00:00 2001 From: Lyova Potyomkin Date: Mon, 13 Jan 2025 11:13:59 +0200 Subject: [PATCH] feat: allow approval based paymasters (#251) * feat: allow approval-based paymasters with proper checks * fix: calldata -> memory * fix: use load instead of slice * test: add fee limit & paymaster tests * fix: minor test tweaks * fix: avoid code repetition --- hardhat.config.ts | 2 +- src/libraries/SessionLib.sol | 96 ++++++++++---- src/test/TestPaymaster.sol | 57 ++++++++ src/validators/SessionKeyValidator.sol | 1 - test/SessionKeyTest.ts | 177 ++++++++++++++++++++++--- test/utils.ts | 30 ++++- 6 files changed, 314 insertions(+), 49 deletions(-) create mode 100644 src/test/TestPaymaster.sol diff --git a/hardhat.config.ts b/hardhat.config.ts index 72650e9a..8d0dec68 100644 --- a/hardhat.config.ts +++ b/hardhat.config.ts @@ -45,7 +45,7 @@ const config: HardhatUserConfig = { }, }, zksolc: { - version: "1.5.7", + version: "1.5.9", settings: { // https://era.zksync.io/docs/tools/hardhat/hardhat-zksync-solc.html#configuration // Native AA calls an internal system contract, so it needs extra permissions diff --git a/src/libraries/SessionLib.sol b/src/libraries/SessionLib.sol index a9995bc2..d766cccc 100644 --- a/src/libraries/SessionLib.sol +++ b/src/libraries/SessionLib.sol @@ -4,10 +4,13 @@ pragma solidity ^0.8.24; import { Transaction } from "@matterlabs/zksync-contracts/l2/system-contracts/libraries/TransactionHelper.sol"; import { IPaymasterFlow } from "@matterlabs/zksync-contracts/l2/system-contracts/interfaces/IPaymasterFlow.sol"; import { TimestampAsserterLocator } from "../helpers/TimestampAsserterLocator.sol"; +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { LibBytes } from "solady/src/utils/LibBytes.sol"; library SessionLib { using SessionLib for SessionLib.Constraint; using SessionLib for SessionLib.UsageLimit; + using LibBytes for bytes; // We do not permit session keys to be reused to open multiple sessions // (after one expires or is closed, e.g.). @@ -136,11 +139,11 @@ library SessionLib { function checkAndUpdate( Constraint memory constraint, UsageTracker storage tracker, - bytes calldata data, + bytes memory data, uint64 period ) internal { - uint256 index = 4 + constraint.index * 32; - bytes32 param = bytes32(data[index:index + 32]); + require(data.length >= 4 + constraint.index * 32 + 32, "Invalid data length"); + bytes32 param = data.load(4 + constraint.index * 32); Condition condition = constraint.condition; bytes32 refValue = constraint.refValue; @@ -161,6 +164,35 @@ library SessionLib { constraint.limit.checkAndUpdate(tracker, uint256(param), period); } + function checkCallPolicy( + SessionStorage storage state, + bytes memory data, + address target, + bytes4 selector, + CallSpec[] memory callPolicies, + uint64[] memory periodIds, + uint256 periodIdsOffset + ) internal returns (CallSpec memory) { + CallSpec memory callPolicy; + bool found = false; + + for (uint256 i = 0; i < callPolicies.length; i++) { + if (callPolicies[i].target == target && callPolicies[i].selector == selector) { + callPolicy = callPolicies[i]; + found = true; + break; + } + } + + require(found, "Call to this contract is not allowed"); + + for (uint256 i = 0; i < callPolicy.constraints.length; i++) { + callPolicy.constraints[i].checkAndUpdate(state.params[target][selector][i], data, periodIds[periodIdsOffset + i]); + } + + return callPolicy; + } + function validateFeeLimit( SessionStorage storage state, Transaction calldata transaction, @@ -193,9 +225,11 @@ library SessionLib { ) internal { // Here we additionally pass uint64[] periodId to check allowance limits // periodId is defined as block.timestamp / limit.period if limitType == Allowance, and 0 otherwise (which will be ignored). - // periodIds[0] is for fee limit, + // periodIds[0] is for fee limit (not used in this function), // periodIds[1] is for value limit, - // periodIds[2:] are for call constraints, if there are any. + // peroidIds[2:2+n] are for `ERC20.approve()` constraints, if an approval-based paymaster is used + // where `n` is the number of constraints in the `ERC20.approve()` policy if an approval-based paymaster is used, 0 otherwise. + // periodIds[2+n:] are for call constraints, if there are any. // It is required to pass them in (instead of computing via block.timestamp) since during validation // we can only assert the range of the timestamp, but not access its value. @@ -205,34 +239,44 @@ library SessionLib { require(transaction.to <= type(uint160).max, "Overflow"); address target = address(uint160(transaction.to)); + // Validate paymaster input + uint256 periodIdsOffset = 2; if (transaction.paymasterInput.length >= 4) { - bytes4 paymasterInputSelector = bytes4(transaction.paymasterInput[0:4]); - require( - paymasterInputSelector != IPaymasterFlow.approvalBased.selector, - "Approval based paymaster flow not allowed" - ); + bytes4 paymasterInputSelector = bytes4(transaction.paymasterInput[:4]); + // SsoAccount will automatically `approve()` a token for an approval-based paymaster in `prepareForPaymaster()` call. + // We need to make sure that the session spec allows this. + if (paymasterInputSelector == IPaymasterFlow.approvalBased.selector) { + require(transaction.paymasterInput.length >= 68, "Invalid paymaster input length"); + (address token, uint256 amount, ) = abi.decode(transaction.paymasterInput[4:], (address, uint256, bytes)); + bytes memory data = abi.encodeWithSelector(IERC20.approve.selector, transaction.paymaster, amount); + + // check that session allows .approve() for this token + CallSpec memory approvePolicy = checkCallPolicy( + state, + data, + token, + IERC20.approve.selector, + spec.callPolicies, + periodIds, + periodIdsOffset + ); + periodIdsOffset += approvePolicy.constraints.length; + } } if (transaction.data.length >= 4) { bytes4 selector = bytes4(transaction.data[:4]); - CallSpec memory callPolicy; - bool found = false; - - for (uint256 i = 0; i < spec.callPolicies.length; i++) { - if (spec.callPolicies[i].target == target && spec.callPolicies[i].selector == selector) { - callPolicy = spec.callPolicies[i]; - found = true; - break; - } - } - - require(found, "Call to this contract is not allowed"); + CallSpec memory callPolicy = checkCallPolicy( + state, + transaction.data, + target, + selector, + spec.callPolicies, + periodIds, + periodIdsOffset + ); require(transaction.value <= callPolicy.maxValuePerUse, "Value exceeds limit"); callPolicy.valueLimit.checkAndUpdate(state.callValue[target][selector], transaction.value, periodIds[1]); - - for (uint256 i = 0; i < callPolicy.constraints.length; i++) { - callPolicy.constraints[i].checkAndUpdate(state.params[target][selector][i], transaction.data, periodIds[i + 2]); - } } else { TransferSpec memory transferPolicy; bool found = false; diff --git a/src/test/TestPaymaster.sol b/src/test/TestPaymaster.sol new file mode 100644 index 00000000..7b091d7e --- /dev/null +++ b/src/test/TestPaymaster.sol @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import { IERC20 } from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import { IPaymaster, ExecutionResult, PAYMASTER_VALIDATION_SUCCESS_MAGIC } from "@matterlabs/zksync-contracts/l2/system-contracts/interfaces/IPaymaster.sol"; +import { IPaymasterFlow } from "@matterlabs/zksync-contracts/l2/system-contracts/interfaces/IPaymasterFlow.sol"; +import { TransactionHelper, Transaction } from "@matterlabs/zksync-contracts/l2/system-contracts/libraries/TransactionHelper.sol"; +import "@matterlabs/zksync-contracts/l2/system-contracts/Constants.sol"; + +contract TestPaymaster is IPaymaster { + modifier onlyBootloader() { + require(msg.sender == BOOTLOADER_FORMAL_ADDRESS, "Only bootloader can call this method"); + // Continue execution if called from the bootloader. + _; + } + + function validateAndPayForPaymasterTransaction( + bytes32, + bytes32, + Transaction calldata transaction + ) external payable onlyBootloader returns (bytes4 magic, bytes memory) { + magic = PAYMASTER_VALIDATION_SUCCESS_MAGIC; + + bytes4 paymasterInputSelector = bytes4(transaction.paymasterInput[:4]); + if (paymasterInputSelector == IPaymasterFlow.approvalBased.selector) { + (address token, uint256 amount, bytes memory data) = abi.decode( + transaction.paymasterInput[4:], + (address, uint256, bytes) + ); + + uint256 providedAllowance = IERC20(token).allowance(address(uint160(transaction.from)), address(this)); + + // For testing purposes any non-zero allowance of any token is enough + require(providedAllowance > 0, "Min allowance too low"); + IERC20(token).transferFrom(address(uint160(transaction.from)), address(this), amount); + } else if (paymasterInputSelector == IPaymasterFlow.general.selector) { + // For testing purposes any transaction is valid + } else { + revert("Unsupported paymaster flow"); + } + + uint256 requiredETH = transaction.gasLimit * transaction.maxFeePerGas; + (bool success, ) = payable(BOOTLOADER_FORMAL_ADDRESS).call{ value: requiredETH }(""); + require(success, "Paymaster out of funds"); + } + + function postTransaction( + bytes calldata _context, + Transaction calldata transaction, + bytes32, + bytes32, + ExecutionResult _txResult, + uint256 _maxRefundedGas + ) external payable override onlyBootloader {} + + receive() external payable {} +} diff --git a/src/validators/SessionKeyValidator.sol b/src/validators/SessionKeyValidator.sol index 8d3120dd..c3403933 100644 --- a/src/validators/SessionKeyValidator.sol +++ b/src/validators/SessionKeyValidator.sol @@ -93,7 +93,6 @@ contract SessionKeyValidator is IModuleValidator { interfaceId == type(IModule).interfaceId; } - // TODO: make the session owner able revoke its own key, in case it was leaked, to prevent further misuse? function revokeKey(bytes32 sessionHash) public { require(sessions[sessionHash].status[msg.sender] == SessionLib.Status.Active, "Nothing to revoke"); sessions[sessionHash].status[msg.sender] = SessionLib.Status.Closed; diff --git a/test/SessionKeyTest.ts b/test/SessionKeyTest.ts index 8e1ba5dc..b8331055 100644 --- a/test/SessionKeyTest.ts +++ b/test/SessionKeyTest.ts @@ -6,8 +6,8 @@ import { it } from "mocha"; import { SmartAccount, utils } from "zksync-ethers"; import type { ERC20 } from "../typechain-types"; -import { SsoBeacon__factory, SsoAccount__factory, SessionKeyValidator__factory } from "../typechain-types"; -import type { SsoBeacon } from "../typechain-types/src/SsoBeacon"; +import { SsoBeacon__factory, SsoAccount__factory, SessionKeyValidator__factory, TestPaymaster__factory } from "../typechain-types"; +import type { IPaymasterFlow, SsoBeacon, TestPaymaster } from "../typechain-types" import type { SessionLib } from "../typechain-types/src/validators/SessionKeyValidator"; import { ContractFixtures, getProvider, logInfo } from "./utils"; @@ -59,6 +59,21 @@ type PartialSession = { }[]; }; +type PaymasterParams = { + paymaster: string; + paymasterInput: string; +}; + +interface TransactionLike extends ethers.TransactionLike { + periodIds?: number[]; + paymasterParams?: PaymasterParams; + customData?: { + gasPerPubdata: number; + customSignature?: string; + paymasterParams?: PaymasterParams; + }; +} + async function getTimestamp() { if (hre.network.name == "inMemoryNode") { return Math.floor(await provider.send("config_getCurrentTimestamp", [])); @@ -92,7 +107,7 @@ class SessionTester { public session: SessionLib.SessionSpecStruct; public sessionAccount: SmartAccount; // having this is a bit hacky, but it's so we can provide correct period ids in the signature - aaTransaction: ethers.TransactionLike; + aaTransaction: TransactionLike; constructor(public proxyAccountAddress: string, sessionKeyModuleAddress: string) { this.sessionOwner = new Wallet(Wallet.createRandom().privateKey, provider); @@ -104,7 +119,12 @@ class SessionTester { sessionKeyModuleAddress, abiCoder.encode( [sessionSpecAbi, "uint64[]"], - [this.session, await this.periodIds(this.aaTransaction.to!, this.aaTransaction.data?.slice(0, 10))], + [ + this.session, + this.aaTransaction.periodIds + ? this.aaTransaction.periodIds + : await this.periodIds(this.aaTransaction.to!, this.aaTransaction.data?.slice(0, 10)) + ], ), ], ), @@ -185,26 +205,30 @@ class SessionTester { logInfo(`transaction gas used: ${receipt.gasUsed.toString()}`); } - async sessionTxSuccess(txRequest: ethers.TransactionLike = {}) { + async sessionTxSuccess(tx: TransactionLike = {}) { + const periodIds = tx.periodIds ?? await this.periodIds(tx.to!, tx.data?.slice(0, 10)); this.aaTransaction = { - ...await this.aaTxTemplate(await this.periodIds(txRequest.to!, txRequest.data?.slice(0, 10))), - ...txRequest, + ...await this.aaTxTemplate(periodIds), + ...tx, }; + this.aaTransaction.customData!.paymasterParams ??= tx.paymasterParams; this.aaTransaction.gasLimit = await provider.estimateGas(this.aaTransaction); logInfo(`\`sessionTx\` gas estimated: ${this.aaTransaction.gasLimit}`); const signedTransaction = await this.sessionAccount.signTransaction(this.aaTransaction); - const tx = await provider.broadcastTransaction(signedTransaction); - const receipt = await tx.wait(); + const sentTx = await provider.broadcastTransaction(signedTransaction); + const receipt = await sentTx.wait(); logInfo(`\`sessionTx\` gas used: ${receipt.gasUsed}`); } - async sessionTxFail(tx: ethers.TransactionLike = {}) { + async sessionTxFail(tx: TransactionLike = {}) { + const periodIds = tx.periodIds ?? await this.periodIds(tx.to!, tx.data?.slice(0, 10)); this.aaTransaction = { - ...await this.aaTxTemplate(await this.periodIds(tx.to!, tx.data?.slice(0, 10))), + ...await this.aaTxTemplate(periodIds), gasLimit: 100_000_000n, ...tx, }; + this.aaTransaction.customData!.paymasterParams ??= tx.paymasterParams; const signedTransaction = await this.sessionAccount.signTransaction(this.aaTransaction); await expect(provider.broadcastTransaction(signedTransaction)).to.be.reverted; @@ -384,13 +408,13 @@ describe("SessionKeyModule tests", function () { let erc20: ERC20; const sessionTarget = Wallet.createRandom().address; - it("should deploy and mint an ERC20 token", async () => { + before("should deploy and mint an ERC20 token", async () => { erc20 = await fixtures.deployERC20(proxyAccountAddress); expect(await erc20.balanceOf(proxyAccountAddress)).to.equal(10n ** 18n, "should have some tokens"); + tester = new SessionTester(proxyAccountAddress, await fixtures.getSessionKeyModuleAddress()); }); it("should create a session", async () => { - tester = new SessionTester(proxyAccountAddress, await fixtures.getSessionKeyModuleAddress()); await tester.createSession({ callPolicies: [{ target: await erc20.getAddress(), @@ -546,12 +570,133 @@ describe("SessionKeyModule tests", function () { }); }); + describe("Fee limit & Paymaster tests", function () { + let tester: SessionTester; + let erc20: ERC20; + let paymaster: TestPaymaster; + const sessionTarget = Wallet.createRandom().address; + let paymasterFlow: IPaymasterFlow; + + before("should deploy ERC20 token and test paymaster", async () => { + erc20 = await fixtures.deployERC20(proxyAccountAddress); + expect(await erc20.balanceOf(proxyAccountAddress)).to.gt(10n ** 15n, "should have some tokens"); + paymaster = await fixtures.deployTestPaymaster(); + paymasterFlow = await hre.ethers.getContractAt("IPaymasterFlow", ethers.ZeroAddress); + tester = new SessionTester(proxyAccountAddress, await fixtures.getSessionKeyModuleAddress()); + // fund paymaster + const tx = await fixtures.wallet.sendTransaction({ to: await paymaster.getAddress(), value: parseEther("1") }); + await tx.wait(); + }); + + it("should create a session with a fee limit", async () => { + await tester.createSession({ + feeLimit: { + limit: parseEther("0.01"), + }, + transferPolicies: [{ + target: sessionTarget, + maxValuePerUse: parseEther("0.01"), + }], + }); + }); + + it("should update fee limit after sending a transaction", async () => { + await tester.sessionTxSuccess({ + to: sessionTarget, + value: parseEther("0.01"), + }); + // @ts-ignore + const gas = tester.aaTransaction.gasLimit * tester.aaTransaction.gasPrice; + const sessionKeyModuleContract = await fixtures.getSessionKeyContract(); + const state = await sessionKeyModuleContract.sessionState(proxyAccountAddress, tester.session); + expect(state.feesRemaining).to.equal(parseEther("0.01") - gas, "should have deducted gas fees"); + expect(await provider.getBalance(sessionTarget)).to.equal(parseEther("0.01"), "session target should have received the funds"); + }); + + it("should send a transaction using general paymaster and ignore fee limit", async () => { + const sessionKeyModuleContract = await fixtures.getSessionKeyContract(); + const oldState = await sessionKeyModuleContract.sessionState(proxyAccountAddress, tester.session); + await tester.sessionTxSuccess({ + to: sessionTarget, + value: parseEther("0.01"), + paymasterParams: { + paymaster: await paymaster.getAddress(), + paymasterInput: paymasterFlow.interface.encodeFunctionData("general", ["0x"]), + }, + }); + const newState = await sessionKeyModuleContract.sessionState(proxyAccountAddress, tester.session); + expect(newState.feesRemaining).to.equal(oldState.feesRemaining, "should not have deducted fees"); + expect(await provider.getBalance(sessionTarget)).to.equal(parseEther("0.02"), "session target should have received the funds"); + }); + + it("should fail sending a transaction using approval-based paymaster", async () => { + await tester.sessionTxFail({ + to: sessionTarget, + value: parseEther("0.01"), + paymasterParams: { + paymaster: await paymaster.getAddress(), + paymasterInput: paymasterFlow.interface.encodeFunctionData("approvalBased", [await erc20.getAddress(), 1000, "0x"]), + }, + periodIds: [0, 0] + }); + }); + + it("should create a different session that allows paying fees with ERC20", async () => { + await tester.createSession({ + feeLimit: { limit: 0, }, + transferPolicies: [{ + target: sessionTarget, + maxValuePerUse: parseEther("0.01"), + }], + callPolicies: [{ + target: await erc20.getAddress(), + selector: erc20.interface.getFunction("approve").selector, + constraints: [ + // // spender is paymaster + { + index: 0, + refValue: ethers.zeroPadValue(await paymaster.getAddress(), 32), + condition: Condition.Equal, + }, + // // amount is 1000 tokens (lifetime limit) + { + index: 1, + limit: { limit: 1000 }, + }, + ] + }], + }); + }); + + it("should send a transaction using approval-based paymaster", async () => { + const sessionKeyModuleContract = await fixtures.getSessionKeyContract(); + let state = await sessionKeyModuleContract.sessionState(proxyAccountAddress, tester.session); + expect(state.callParams[0].remaining).to.equal(1000, "should have 1000 tokens remaining to approve"); + const oldPaymasterBalance = await erc20.balanceOf(await paymaster.getAddress()); + await tester.sessionTxSuccess({ + to: sessionTarget, + value: parseEther("0.01"), + paymasterParams: { + paymaster: await paymaster.getAddress(), + paymasterInput: paymasterFlow.interface.encodeFunctionData("approvalBased", [await erc20.getAddress(), 1000, "0x"]), + }, + periodIds: [0, 0, 0, 0] + }); + const newPaymasterBalance = await erc20.balanceOf(await paymaster.getAddress()); + expect(newPaymasterBalance).to.equal(oldPaymasterBalance + 1000n, "paymaster should have received the approved amount"); + state = await sessionKeyModuleContract.sessionState(proxyAccountAddress, tester.session); + expect(state.callParams[0].remaining).to.equal(0, "should have deducted the approved amount"); + expect(await provider.getBalance(sessionTarget)).to.equal(parseEther("0.03"), "session target should have received the funds"); + }); + }); + describe("Module install/uninstall tests", function () { const ssoAbi = SsoAccount__factory.createInterface(); + let sessionModuleAddress: string; let tester: SessionTester; before(async () => { - const sessionModuleAddress = await fixtures.getSessionKeyModuleAddress(); + sessionModuleAddress = await fixtures.getSessionKeyModuleAddress(); tester = new SessionTester(proxyAccountAddress, sessionModuleAddress); }); @@ -569,7 +714,6 @@ describe("SessionKeyModule tests", function () { }); it("should uninstall the module", async () => { - const sessionModuleAddress = await fixtures.getSessionKeyModuleAddress(); await tester.sendAaTx(proxyAccountAddress, ssoAbi.encodeFunctionData("removeModuleValidator", [ sessionModuleAddress, abiCoder.encode(["bytes32[]"], [[]]) @@ -577,16 +721,13 @@ describe("SessionKeyModule tests", function () { }); it("should reinstall the module", async () => { - const sessionModuleAddress = await fixtures.getSessionKeyModuleAddress(); await tester.sendAaTx(proxyAccountAddress, ssoAbi.encodeFunctionData("addModuleValidator", [sessionModuleAddress, "0x"])); }); it("should unlink the module ignoring reverts", async () => { - const sessionModuleAddress = await fixtures.getSessionKeyModuleAddress(); // passing "0x" as the second argument would revert normally await tester.sendAaTx(proxyAccountAddress, ssoAbi.encodeFunctionData("unlinkModuleValidator", [sessionModuleAddress, "0x"])); }); }); - // TODO: session fee limit tests }); diff --git a/test/utils.ts b/test/utils.ts index 66e172db..479d2b85 100644 --- a/test/utils.ts +++ b/test/utils.ts @@ -8,11 +8,30 @@ import { promises } from "fs"; import * as hre from "hardhat"; import { ContractFactory, Provider, utils, Wallet } from "zksync-ethers"; import { base64UrlToUint8Array, getPublicKeyBytesFromPasskeySignature, unwrapEC2Signature } from "zksync-sso/utils"; - -import { AAFactory, ERC20, ExampleAuthServerPaymaster, SessionKeyValidator, SsoAccount, WebAuthValidator, SsoBeacon, AccountProxy__factory, AccountProxy } from "../typechain-types"; -import { AAFactory__factory, ERC20__factory, ExampleAuthServerPaymaster__factory, SessionKeyValidator__factory, SsoAccount__factory, WebAuthValidator__factory, SsoBeacon__factory } from "../typechain-types"; import { Address, isHex, toHex } from "viem"; +import type { + AAFactory, + ERC20, + ExampleAuthServerPaymaster, + SessionKeyValidator, + SsoAccount, + WebAuthValidator, + SsoBeacon, + AccountProxy +} from "../typechain-types"; +import { + AAFactory__factory, + AccountProxy__factory, + ERC20__factory, + ExampleAuthServerPaymaster__factory, + SessionKeyValidator__factory, + SsoAccount__factory, + WebAuthValidator__factory, + SsoBeacon__factory, + TestPaymaster__factory +} from "../typechain-types"; + export const ethersStaticSalt = new Uint8Array([ 205, 241, 161, 186, 101, 105, 79, 248, 98, 64, 50, 124, 168, 204, @@ -107,6 +126,11 @@ export class ContractFixtures { return ERC20__factory.connect(await contract.getAddress(), this.wallet); } + async deployTestPaymaster() { + const contract = await create2("TestPaymaster", this.wallet, ethersStaticSalt); + return TestPaymaster__factory.connect(await contract.getAddress(), this.wallet); + } + async deployExampleAuthServerPaymaster( aaFactoryAddress: string, sessionKeyValidatorAddress: string,