diff --git a/packages/shield-controller/CHANGELOG.md b/packages/shield-controller/CHANGELOG.md index 325e8f0f216..da6198508ea 100644 --- a/packages/shield-controller/CHANGELOG.md +++ b/packages/shield-controller/CHANGELOG.md @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Bump `@metamask/transaction-controller` from `^60.7.0` to `^60.8.0` ([#6883](https://github.com/MetaMask/core/pull/6883)) +- Updated internal coverage result polling and log logic. ([#6847](https://github.com/MetaMask/core/pull/6847)) + - Added cancellation logic to the polling. + - Updated implementation of timeout. + - Cancel any pending requests before starting new polling or logging. +- Updated TransactionMeta comparison in `TransactionController:stateChange` subscriber to avoid triggering multiple check coverage result unnecessarily. ([#6847](https://github.com/MetaMask/core/pull/6847)) +- Removed `Personal Sign` check from the check signature coverage result. ([#6847](https://github.com/MetaMask/core/pull/6847)) ## [0.3.2] diff --git a/packages/shield-controller/package.json b/packages/shield-controller/package.json index 4cc53300e2a..28e3c1f311a 100644 --- a/packages/shield-controller/package.json +++ b/packages/shield-controller/package.json @@ -48,7 +48,9 @@ }, "dependencies": { "@metamask/base-controller": "^8.4.1", - "@metamask/utils": "^11.8.1" + "@metamask/controller-utils": "^11.14.1", + "@metamask/utils": "^11.8.1", + "cockatiel": "^3.1.2" }, "devDependencies": { "@babel/runtime": "^7.23.9", diff --git a/packages/shield-controller/src/ShieldController.test.ts b/packages/shield-controller/src/ShieldController.test.ts index b741aaa5dfc..f223645421f 100644 --- a/packages/shield-controller/src/ShieldController.test.ts +++ b/packages/shield-controller/src/ShieldController.test.ts @@ -11,6 +11,7 @@ import { } from '@metamask/transaction-controller'; import { ShieldController } from './ShieldController'; +import { TX_META_SIMULATION_DATA_MOCKS } from '../tests/data'; import { createMockBackend, MOCK_COVERAGE_ID } from '../tests/mocks/backend'; import { createMockMessenger } from '../tests/mocks/messenger'; import { @@ -169,6 +170,47 @@ describe('ShieldController', () => { }); }); + TX_META_SIMULATION_DATA_MOCKS.forEach( + ({ description, previousSimulationData, newSimulationData }) => { + it(`should check coverage when ${description}`, async () => { + const { baseMessenger, backend } = setup(); + const previousTxMeta = { + ...generateMockTxMeta(), + simulationData: previousSimulationData, + }; + const coverageResultReceived = + setupCoverageResultReceived(baseMessenger); + + // Add transaction. + baseMessenger.publish( + 'TransactionController:stateChange', + { transactions: [previousTxMeta] } as TransactionControllerState, + undefined as never, + ); + expect(await coverageResultReceived).toBeUndefined(); + expect(backend.checkCoverage).toHaveBeenCalledWith({ + txMeta: previousTxMeta, + }); + + // Simulate transaction. + const txMeta2 = { ...previousTxMeta }; + txMeta2.simulationData = newSimulationData; + const coverageResultReceived2 = + setupCoverageResultReceived(baseMessenger); + baseMessenger.publish( + 'TransactionController:stateChange', + { transactions: [txMeta2] } as TransactionControllerState, + undefined as never, + ); + expect(await coverageResultReceived2).toBeUndefined(); + expect(backend.checkCoverage).toHaveBeenCalledWith({ + coverageId: MOCK_COVERAGE_ID, + txMeta: txMeta2, + }); + }); + }, + ); + it('throws an error when the coverage ID has changed', async () => { const { controller, backend } = setup(); backend.checkCoverage.mockResolvedValueOnce({ diff --git a/packages/shield-controller/src/ShieldController.ts b/packages/shield-controller/src/ShieldController.ts index 470f2982709..1b04e73fe67 100644 --- a/packages/shield-controller/src/ShieldController.ts +++ b/packages/shield-controller/src/ShieldController.ts @@ -5,7 +5,6 @@ import type { } from '@metamask/base-controller'; import { SignatureRequestStatus, - SignatureRequestType, type SignatureRequest, type SignatureStateChange, } from '@metamask/signature-controller'; @@ -236,10 +235,7 @@ export class ShieldController extends BaseController< // Check coverage if the signature request is new and has type // `personal_sign`. - if ( - !previousSignatureRequest && - signatureRequest.type === SignatureRequestType.PersonalSign - ) { + if (!previousSignatureRequest) { this.checkSignatureCoverage(signatureRequest).catch( // istanbul ignore next (error) => log('Error checking coverage:', error), @@ -268,15 +264,15 @@ export class ShieldController extends BaseController< ); for (const transaction of transactions) { const previousTransaction = previousTransactionsById.get(transaction.id); + // Check if the simulation data has changed. + const simulationDataChanged = this.#compareTransactionSimulationData( + transaction.simulationData, + previousTransaction?.simulationData, + ); // Check coverage if the transaction is new or if the simulation data has // changed. - if ( - !previousTransaction || - // Checking reference equality is sufficient because this object is - // replaced if the simulation data has changed. - previousTransaction.simulationData !== transaction.simulationData - ) { + if (!previousTransaction || simulationDataChanged) { this.checkCoverage(transaction).catch( // istanbul ignore next (error) => log('Error checking coverage:', error), @@ -443,4 +439,61 @@ export class ShieldController extends BaseController< #getLatestCoverageId(itemId: string): string | undefined { return this.state.coverageResults[itemId]?.results[0]?.coverageId; } + + /** + * Compares the simulation data of a transaction. + * + * @param simulationData - The simulation data of the transaction. + * @param previousSimulationData - The previous simulation data of the transaction. + * @returns Whether the simulation data has changed. + */ + #compareTransactionSimulationData( + simulationData?: TransactionMeta['simulationData'], + previousSimulationData?: TransactionMeta['simulationData'], + ) { + if (!simulationData && !previousSimulationData) { + return false; + } + + // check the simulation error + if ( + simulationData?.error?.code !== previousSimulationData?.error?.code || + simulationData?.error?.message !== previousSimulationData?.error?.message + ) { + return true; + } + + // check the native balance change + if ( + simulationData?.nativeBalanceChange?.difference !== + previousSimulationData?.nativeBalanceChange?.difference || + simulationData?.nativeBalanceChange?.newBalance !== + previousSimulationData?.nativeBalanceChange?.newBalance || + simulationData?.nativeBalanceChange?.previousBalance !== + previousSimulationData?.nativeBalanceChange?.previousBalance || + simulationData?.nativeBalanceChange?.isDecrease !== + previousSimulationData?.nativeBalanceChange?.isDecrease + ) { + return true; + } + + // check the token balance changes + if ( + simulationData?.tokenBalanceChanges?.length !== + previousSimulationData?.tokenBalanceChanges?.length || + simulationData?.tokenBalanceChanges?.some( + (tokenBalanceChange, index) => + tokenBalanceChange.difference !== + previousSimulationData?.tokenBalanceChanges?.[index]?.difference, + ) + ) { + return true; + } + + // check the isUpdatedAfterSecurityCheck + return ( + simulationData?.isUpdatedAfterSecurityCheck !== + previousSimulationData?.isUpdatedAfterSecurityCheck + ); + } } diff --git a/packages/shield-controller/src/backend.test.ts b/packages/shield-controller/src/backend.test.ts index b176059b61e..41d69550ec6 100644 --- a/packages/shield-controller/src/backend.test.ts +++ b/packages/shield-controller/src/backend.test.ts @@ -45,6 +45,11 @@ function setup({ } describe('ShieldRemoteBackend', () => { + afterEach(() => { + // Clean up mocks after each test + jest.clearAllMocks(); + }); + it('should check coverage', async () => { const { backend, fetchMock, getAccessToken } = setup(); @@ -143,7 +148,7 @@ describe('ShieldRemoteBackend', () => { const txMeta = generateMockTxMeta(); await expect(backend.checkCoverage({ txMeta })).rejects.toThrow( - 'Timeout waiting for coverage result', + 'getCoverageResult: Request timed out', ); // Waiting here ensures coverage of the unexpected error and lets us know diff --git a/packages/shield-controller/src/backend.ts b/packages/shield-controller/src/backend.ts index dcc863850de..2e50b778913 100644 --- a/packages/shield-controller/src/backend.ts +++ b/packages/shield-controller/src/backend.ts @@ -1,6 +1,7 @@ import type { SignatureRequest } from '@metamask/signature-controller'; import type { TransactionMeta } from '@metamask/transaction-controller'; +import { PollingWithTimeoutAndAbort } from './polling-with-timeout-abort'; import type { CheckCoverageRequest, CheckSignatureCoverageRequest, @@ -58,6 +59,8 @@ export class ShieldRemoteBackend implements ShieldBackend { readonly #fetch: typeof globalThis.fetch; + readonly #pollingWithTimeout: PollingWithTimeoutAndAbort; + constructor({ getAccessToken, getCoverageResultTimeout = 5000, // milliseconds @@ -76,6 +79,10 @@ export class ShieldRemoteBackend implements ShieldBackend { this.#getCoverageResultPollInterval = getCoverageResultPollInterval; this.#baseUrl = baseUrl; this.#fetch = fetchFn; + this.#pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: getCoverageResultTimeout, + pollInterval: getCoverageResultPollInterval, + }); } async checkCoverage(req: CheckCoverageRequest): Promise { @@ -90,6 +97,7 @@ export class ShieldRemoteBackend implements ShieldBackend { const txCoverageResultUrl = `${this.#baseUrl}/v1/transaction/coverage/result`; const coverageResult = await this.#getCoverageResult(coverageId, { + requestId: req.txMeta.id, coverageResultUrl: txCoverageResultUrl, }); return { @@ -114,6 +122,7 @@ export class ShieldRemoteBackend implements ShieldBackend { const signatureCoverageResultUrl = `${this.#baseUrl}/v1/signature/coverage/result`; const coverageResult = await this.#getCoverageResult(coverageId, { + requestId: req.signatureRequest.id, coverageResultUrl: signatureCoverageResultUrl, }); return { @@ -132,6 +141,9 @@ export class ShieldRemoteBackend implements ShieldBackend { ...initBody, }; + // clean up the pending coverage result polling + await this.#pollingWithTimeout.abortPendingRequest(req.signatureRequest.id); + const res = await this.#fetch( `${this.#baseUrl}/v1/signature/coverage/log`, { @@ -153,6 +165,9 @@ export class ShieldRemoteBackend implements ShieldBackend { ...initBody, }; + // clean up the pending coverage result polling + await this.#pollingWithTimeout.abortPendingRequest(req.txMeta.id); + const res = await this.#fetch( `${this.#baseUrl}/v1/transaction/coverage/log`, { @@ -183,7 +198,8 @@ export class ShieldRemoteBackend implements ShieldBackend { async #getCoverageResult( coverageId: string, - configs: { + config: { + requestId: string; coverageResultUrl: string; timeout?: number; pollInterval?: number; @@ -192,41 +208,34 @@ export class ShieldRemoteBackend implements ShieldBackend { const reqBody: GetCoverageResultRequest = { coverageId, }; - - const timeout = configs?.timeout ?? this.#getCoverageResultTimeout; - const pollInterval = - configs?.pollInterval ?? this.#getCoverageResultPollInterval; - + const pollingOptions = { + timeout: config.timeout ?? this.#getCoverageResultTimeout, + pollInterval: config.pollInterval ?? this.#getCoverageResultPollInterval, + fnName: 'getCoverageResult', + }; const headers = await this.#createHeaders(); - return await new Promise((resolve, reject) => { - let timeoutReached = false; - setTimeout(() => { - timeoutReached = true; - reject(new Error('Timeout waiting for coverage result')); - }, timeout); - const poll = async (): Promise => { - // The timeoutReached variable is modified in the timeout callback. - // eslint-disable-next-line no-unmodified-loop-condition - while (!timeoutReached) { - const startTime = Date.now(); - const res = await this.#fetch(configs.coverageResultUrl, { - method: 'POST', - headers, - body: JSON.stringify(reqBody), - }); - if (res.status === 200) { - return (await res.json()) as GetCoverageResultResponse; - } - await sleep(pollInterval - (Date.now() - startTime)); - } - // The following line will not have an effect as the upper level promise - // will already be rejected by now. - throw new Error('unexpected error'); - }; + const requestCoverageFn = async ( + signal: AbortSignal, + ): Promise => { + const res = await this.#fetch(config.coverageResultUrl, { + method: 'POST', + headers, + body: JSON.stringify(reqBody), + signal, + }); + if (res.status === 200) { + return (await res.json()) as GetCoverageResultResponse; + } + throw new Error(`Failed to get coverage result: ${res.status}`); + }; - poll().then(resolve).catch(reject); - }); + const coverageResult = await this.#pollingWithTimeout.pollRequest( + config.requestId, + requestCoverageFn, + pollingOptions, + ); + return coverageResult; } async #createHeaders() { @@ -238,16 +247,6 @@ export class ShieldRemoteBackend implements ShieldBackend { } } -/** - * Sleep for a specified amount of time. - * - * @param ms - The number of milliseconds to sleep. - * @returns A promise that resolves after the specified amount of time. - */ -async function sleep(ms: number) { - return new Promise((resolve) => setTimeout(resolve, ms)); -} - /** * Make the body for the init coverage check request. * diff --git a/packages/shield-controller/src/polling-with-policy.test.ts b/packages/shield-controller/src/polling-with-policy.test.ts new file mode 100644 index 00000000000..f6ff30e8483 --- /dev/null +++ b/packages/shield-controller/src/polling-with-policy.test.ts @@ -0,0 +1,145 @@ +import { PollingWithCockatielPolicy } from './polling-with-policy'; +import { delay } from '../tests/utils'; + +describe('PollingWithCockatielPolicy', () => { + it('should return the success result', async () => { + const policy = new PollingWithCockatielPolicy(); + const result = await policy.start('test', async () => { + return 'test'; + }); + expect(result).toBe('test'); + }); + + it('should retry the request and complete successfully', async () => { + const policy = new PollingWithCockatielPolicy(); + let invocationCount = 0; + const mockRequestFn = jest + .fn() + .mockImplementation(async (_abortSignal: AbortSignal) => { + invocationCount += 1; + return new Promise((resolve, reject) => { + setTimeout(() => { + // eslint-disable-next-line jest/no-conditional-in-test + if (invocationCount < 3) { + reject(new Error('test error')); + } + resolve('test'); + }, 100); + }); + }); + const result = await policy.start('test', mockRequestFn); + expect(result).toBe('test'); + expect(mockRequestFn).toHaveBeenCalledTimes(3); + }); + + it('should not retry when the error is not retryable', async () => { + const policy = new PollingWithCockatielPolicy(); + const mockRequestFn = jest + .fn() + .mockImplementation(async (_abortSignal: AbortSignal) => { + return new Promise((_resolve, reject) => { + const error = new Error('Not retryable error') as { + shouldRetry?: boolean; + }; + error.shouldRetry = false; + // eslint-disable-next-line @typescript-eslint/prefer-promise-reject-errors + reject(error); + }); + }); + await expect(policy.start('test', mockRequestFn)).rejects.toThrow( + 'Not retryable error', + ); + expect(mockRequestFn).toHaveBeenCalledTimes(1); + }); + + it('should throw an error when the retry exceeds the max retries', async () => { + const policy = new PollingWithCockatielPolicy({ + maxRetries: 3, + }); + + const requestFn = jest + .fn() + .mockImplementation(async (_abortSignal: AbortSignal) => { + return new Promise((_resolve, reject) => { + setTimeout(() => { + reject(new Error('test error')); + }, 100); + }); + }); + + const result = policy.start('test', requestFn); + await expect(result).rejects.toThrow('test error'); + expect(requestFn).toHaveBeenCalledTimes(4); + }); + + it('should throw a `Request Cancelled` error when the request is aborted', async () => { + const policy = new PollingWithCockatielPolicy({ + maxRetries: 3, + }); + + const requestFn = jest + .fn() + .mockImplementation(async (abortSignal: AbortSignal) => { + return new Promise((_resolve, reject) => { + setTimeout(() => { + // eslint-disable-next-line jest/no-conditional-in-test + if (abortSignal.aborted) { + reject(new Error('test error')); + } + reject(new Error('test error')); + }, 100); + }); + }); + + const result = policy.start('test', requestFn); + await delay(10); + policy.abortPendingRequest('test'); + await expect(result).rejects.toThrow('Request cancelled'); + }); + + it('should throw a `Request Cancelled` error when a new request is started with the same request id', async () => { + const policy = new PollingWithCockatielPolicy(); + + const requestFn = jest + .fn() + .mockImplementation(async (abortSignal: AbortSignal) => { + return new Promise((resolve, reject) => { + setTimeout(() => { + // eslint-disable-next-line jest/no-conditional-in-test + if (abortSignal.aborted) { + reject(new Error('test error')); + } + resolve('test'); + }, 100); + }); + }); + + const result = policy.start('test', requestFn); + await delay(10); + const secondResult = policy.start('test', requestFn); + await expect(result).rejects.toThrow('Request cancelled'); + expect(await secondResult).toBe('test'); + }); + + it('should resolve the result when two requests are started with the different request ids', async () => { + const policy = new PollingWithCockatielPolicy(); + + const requestFn = (result: string) => + jest.fn().mockImplementation(async (abortSignal: AbortSignal) => { + return new Promise((resolve, reject) => { + // eslint-disable-next-line jest/no-conditional-in-test + if (abortSignal.aborted) { + reject(new Error('test error')); + } + setTimeout(() => { + resolve(result); + }, 100); + }); + }); + + const result = policy.start('test', requestFn('test')); + const secondResult = policy.start('test2', requestFn('test2')); + expect(await result).toBe('test'); + expect(await secondResult).toBe('test2'); + }); +}); diff --git a/packages/shield-controller/src/polling-with-policy.ts b/packages/shield-controller/src/polling-with-policy.ts new file mode 100644 index 00000000000..252e3ef41de --- /dev/null +++ b/packages/shield-controller/src/polling-with-policy.ts @@ -0,0 +1,84 @@ +import { + createServicePolicy, + type CreateServicePolicyOptions, + type ServicePolicy, +} from '@metamask/controller-utils'; +import { handleWhen } from 'cockatiel'; + +export type RequestFn = ( + signal: AbortSignal, +) => Promise; + +export class PollingWithCockatielPolicy { + readonly #policy: ServicePolicy; + + readonly #requestEntry = new Map(); + + constructor(policyOptions: CreateServicePolicyOptions = {}) { + const retryFilterPolicy = handleWhen(this.#shouldRetry); + this.#policy = createServicePolicy({ + ...policyOptions, + retryFilterPolicy, + }); + } + + async start(requestId: string, requestFn: RequestFn) { + this.abortPendingRequest(requestId); + const abortController = this.addNewRequestEntry(requestId); + const disposableListeners = this.#registerListeners(); + + try { + const result = await this.#policy.execute( + async ({ signal: abortSignal }) => { + return requestFn(abortSignal); + }, + abortController.signal, + ); + return result; + } catch (error) { + if (abortController.signal.aborted) { + throw new Error('Request cancelled'); + } + throw error; + } finally { + this.#unregisterListeners(disposableListeners); + } + } + + addNewRequestEntry(requestId: string) { + const abortController = new AbortController(); + this.#requestEntry.set(requestId, abortController); + return abortController; + } + + abortPendingRequest(requestId: string) { + const abortController = this.#requestEntry.get(requestId); + abortController?.abort(); + } + + // TODO: remove + #registerListeners(): { dispose: () => void }[] { + const disposableListeners = []; + disposableListeners.push( + this.#policy.onBreak((data) => { + console.log('onBreak', data); + }), + ); + disposableListeners.push( + this.#policy.circuitBreakerPolicy.onStateChange((data) => { + console.log('onStateChange', data); + }), + ); + return disposableListeners; + } + + // TODO: remove + #unregisterListeners(disposableListeners: { dispose: () => void }[]) { + disposableListeners.forEach((disposable) => disposable.dispose()); + } + + #shouldRetry(error: unknown): boolean { + const errorWithRetryStatus = error as { shouldRetry?: boolean }; + return errorWithRetryStatus.shouldRetry !== false; + } +} diff --git a/packages/shield-controller/src/polling-with-timeout-abort.test.ts b/packages/shield-controller/src/polling-with-timeout-abort.test.ts new file mode 100644 index 00000000000..59a9295d37b --- /dev/null +++ b/packages/shield-controller/src/polling-with-timeout-abort.test.ts @@ -0,0 +1,128 @@ +import { PollingWithTimeoutAndAbort } from './polling-with-timeout-abort'; +import { delay } from '../tests/utils'; + +describe('PollingWithTimeoutAndAbort', () => { + it('should timeout when the request does not resolve within the timeout period', async () => { + const pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: 100, + pollInterval: 10, + }); + + const requestFn = jest + .fn() + .mockImplementation(async (_signal: AbortSignal) => { + return new Promise((_resolve, reject) => { + setTimeout(() => { + reject(new Error('test error')); + }, 10); + }); + }); + + await expect( + pollingWithTimeout.pollRequest('test', requestFn, { + fnName: 'test', + timeout: 100, + }), + ).rejects.toThrow('test: Request timed out'); + }); + + it('should timeout with default polling options', async () => { + const pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: 100, + pollInterval: 10, + }); + const requestFn = jest + .fn() + .mockImplementation(async (_signal: AbortSignal) => { + return new Promise((_resolve, reject) => { + setTimeout(() => { + reject(new Error('test error')); + }, 10); + }); + }); + + await expect( + pollingWithTimeout.pollRequest('test', requestFn), + ).rejects.toThrow('Request timed out'); + }); + + it('should abort pending requests when new request is made', async () => { + const pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: 1000, + pollInterval: 20, + }); + + const requestFn = jest + .fn() + .mockImplementation(async (signal: AbortSignal) => { + return new Promise((resolve, reject) => { + setTimeout(() => { + // eslint-disable-next-line jest/no-conditional-in-test -- we want to simulate the abort signal being triggered during the request + if (signal.aborted) { + reject(new Error('test error')); + } + resolve('test result'); + }, 100); + }); + }); + + const firstAttempt = pollingWithTimeout.pollRequest('test', requestFn, { + fnName: 'test', + }); + await delay(15); // small delay to let the first request start + const secondAttempt = pollingWithTimeout.pollRequest('test', requestFn, { + fnName: 'test', + }); + + await expect(firstAttempt).rejects.toThrow('test: Request cancelled'); // first request should be aborted by the second request + const result = await secondAttempt; + expect(result).toBe('test result'); // second request should succeed + }); + + it('should abort pending requests when abortPendingRequest is called', async () => { + const pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: 1000, + pollInterval: 20, + }); + + const requestFn = jest + .fn() + .mockImplementation(async (_signal: AbortSignal) => { + return new Promise((_resolve, reject) => { + setTimeout(() => { + reject(new Error('test error')); + }, 100); + }); + }); + + const request = pollingWithTimeout.pollRequest('test', requestFn, { + fnName: 'test', + }); + await delay(15); // small delay to let the request start + await pollingWithTimeout.abortPendingRequest('test'); + await expect(request).rejects.toThrow('test: Request cancelled'); + }); + + it('should get the result when the request succeeds', async () => { + const pollingWithTimeout = new PollingWithTimeoutAndAbort({ + timeout: 1000, + pollInterval: 20, + }); + + const requestFn = jest + .fn() + .mockImplementation(async (_signal: AbortSignal) => { + return new Promise((resolve) => { + setTimeout(() => { + resolve('test result'); + }, 100); + }); + }); + + const request = pollingWithTimeout.pollRequest('test', requestFn, { + fnName: 'test', + }); + const result = await request; + expect(result).toBe('test result'); + }); +}); diff --git a/packages/shield-controller/src/polling-with-timeout-abort.ts b/packages/shield-controller/src/polling-with-timeout-abort.ts new file mode 100644 index 00000000000..74ef22bef8b --- /dev/null +++ b/packages/shield-controller/src/polling-with-timeout-abort.ts @@ -0,0 +1,204 @@ +export type RequestEntry = { + abortController: AbortController; // The abort controller for the request + abortHandler: (ev: Event) => void; // The abort handler for the request + timerId: NodeJS.Timeout; // The timer ID for the request timeout +}; + +export type RequestFn = ( + signal: AbortSignal, +) => Promise; + +export class PollingWithTimeoutAndAbort { + readonly ABORT_REASON_TIMEOUT = 'Request timed out'; + + readonly ABORT_REASON_CANCELLED = 'Request cancelled'; + + // Map of request ID to request entry + readonly #requestEntries: Map = new Map(); + + readonly #timeout: number; + + readonly #pollInterval: number; + + constructor(config: { timeout: number; pollInterval: number }) { + this.#timeout = config.timeout; + this.#pollInterval = config.pollInterval; + } + + /** + * Poll a request with a timeout and abort. + * This will poll the request until it succeeds or fails due to the timeout or the abort signal being triggered. + * + * @param requestId - The ID of the request to poll. + * @param requestFn - The function to poll the request. + * @param pollingOptions - The options for the polling. + * @param pollingOptions.timeout - The timeout for the request. Defaults to the constructor's timeout. + * @param pollingOptions.pollInterval - The interval for the polling. Defaults to the constructor's pollInterval. + * @param pollingOptions.fnName - The name of the function to poll the request. Defaults to an empty string. + * @returns The result of the request. + */ + async pollRequest( + requestId: string, + requestFn: RequestFn, + pollingOptions: { + timeout?: number; + pollInterval?: number; + fnName?: string; + } = {}, + ) { + const timeout = pollingOptions.timeout ?? this.#timeout; + const pollInterval = pollingOptions.pollInterval ?? this.#pollInterval; + + // clean up the request entry if it exists + await this.abortPendingRequest(requestId); + + // insert the request entry for the next polling cycle + const { abortController } = this.#insertRequestEntry(requestId, timeout); + + while (!abortController.signal.aborted) { + try { + const result = await requestFn(abortController.signal); + // polling success, we just need to abort the request and return the result + // note: this will trigger the abort handler, which will clean up the request entry + abortController.abort(); + return result; + } catch { + // otherwise, we will wait for the next polling cycle + // and continue the polling loop + await this.#delayWithAbortSignal( + pollInterval, + abortController.signal, + ).catch(() => { + // delayWithAbortSignal rejects, which means the abort signal was triggered + // so we will break out of the polling loop + }); + } + } + // At this point, the polling loop has exited and abortController is aborted + const abortReason = abortController.signal.reason; + const errorMessage = pollingOptions.fnName + ? `${pollingOptions.fnName}: ${abortReason}` + : abortReason; + throw new Error(errorMessage); + } + + /** + * Abort the pending request. + * To make sure that the request is actually aborted, we will listen to the abort event and resolve the promise when the abort event is triggered. + * + * @param requestId - The ID of the request to abort. + * @returns A promise that resolves when the request is aborted and cleaned up. + */ + abortPendingRequest(requestId: string): Promise { + return new Promise((resolve) => { + const entry = this.#requestEntries.get(requestId); + if (!entry) { + resolve(); + return; + } + + // Listen for the abort event to ensure cleanup completes before resolving + const cleanupCompleteHandler = () => { + resolve(); + }; + entry.abortController.signal.addEventListener( + 'abort', + cleanupCompleteHandler, + { once: true }, + ); + + entry.abortController.abort(this.ABORT_REASON_CANCELLED); + }); + } + + /** + * Insert a new request entry. + * This will create a new abort controller, set a timeout to abort the request if it takes too long, and set the abort handler. + * + * @param requestId - The ID of the request to insert the entry for. + * @param timeout - The timeout for the request. + * @returns The request entry that was inserted. + */ + #insertRequestEntry(requestId: string, timeout: number) { + const abortController = new AbortController(); + + // Set a timeout to abort the request if it takes too long + const timerId = setTimeout(() => { + abortController.abort(this.ABORT_REASON_TIMEOUT); + }, timeout); + + // Set the abort handler and listen to the `abort` event + const abortHandler = () => { + this.#cleanUp(requestId); + }; + abortController.signal.addEventListener('abort', abortHandler); + + const requestEntry: RequestEntry = { + abortController, + abortHandler, + timerId, + }; + + // Insert the request entry + this.#requestEntries.set(requestId, requestEntry); + + return requestEntry; + } + + /** + * Clean up the request entry. + * This will clear the timeout, remove the abort handler event listener, and remove the request entry from the map. + * + * @param requestId - The ID of the request to clean up. + * @returns The request entry that was cleaned up. + */ + #cleanUp(requestId: string) { + const requestEntry = this.#requestEntries.get(requestId); + if (requestEntry) { + clearTimeout(requestEntry.timerId); // clear the timeout + requestEntry.abortController.signal.removeEventListener( + 'abort', + requestEntry.abortHandler, + ); // clean up the abort handler event listener + this.#requestEntries.delete(requestId); // remove the request entry + } + return requestEntry; + } + + /** + * Delay with an abort signal. + * This will delay the execution of the code until the abort signal is triggered. + * + * @param ms - The number of milliseconds to delay. + * @param abortSignal - The abort signal to listen to. + * @returns A promise that resolves when the delay is complete or rejects if the abort signal is triggered. + */ + async #delayWithAbortSignal(ms: number, abortSignal: AbortSignal) { + return new Promise((resolve, reject) => { + let timer: NodeJS.Timeout | null = null; + + if (abortSignal.aborted) { + reject(new Error(this.ABORT_REASON_CANCELLED)); + } + + const abortHandlerForDelay = () => { + // clear the timeout and resolve the promise + // Note: we don't reject the promise as this is only a dummy delay + if (timer) { + clearTimeout(timer); + } + reject(new Error(this.ABORT_REASON_CANCELLED)); + }; + + timer = setTimeout(() => { + abortSignal.removeEventListener('abort', abortHandlerForDelay); + resolve(undefined); + }, ms); + + // set the abort handler to clear the timeout and resolve the promise + abortSignal.addEventListener('abort', abortHandlerForDelay, { + once: true, // only listen to the abort event once + }); + }); + } +} diff --git a/packages/shield-controller/tests/data.ts b/packages/shield-controller/tests/data.ts new file mode 100644 index 00000000000..f1c7959a3e7 --- /dev/null +++ b/packages/shield-controller/tests/data.ts @@ -0,0 +1,70 @@ +import type { SimulationData } from '@metamask/transaction-controller'; +import { SimulationTokenStandard } from '@metamask/transaction-controller'; + +export const TX_META_SIMULATION_DATA_MOCKS: { + description: string; + previousSimulationData: SimulationData | undefined; + newSimulationData: SimulationData; +}[] = [ + { + description: '`SimulationData.nativeBalanceChange` has changed', + previousSimulationData: undefined, + newSimulationData: { + nativeBalanceChange: { + difference: '0x1', + previousBalance: '0x1', + newBalance: '0x2', + isDecrease: true, + }, + tokenBalanceChanges: [], + }, + }, + { + description: '`SimulationData.tokenBalanceChanges` has changed', + previousSimulationData: { + tokenBalanceChanges: [ + { + difference: '0x1', + previousBalance: '0x1', + standard: SimulationTokenStandard.erc20, + address: '0x1', + newBalance: '0x2', + isDecrease: true, + }, + ], + }, + newSimulationData: { + tokenBalanceChanges: [ + { + difference: '0x2', + previousBalance: '0x1', + standard: SimulationTokenStandard.erc20, + address: '0x1', + newBalance: '0x3', + isDecrease: true, + }, + ], + }, + }, + { + description: '`SimulationData.error` has changed', + previousSimulationData: undefined, + newSimulationData: { + error: { + code: '-123', + message: 'Reverted', + }, + tokenBalanceChanges: [], + }, + }, + { + description: '`SimulationData.isUpdatedAfterSecurityCheck` has changed', + previousSimulationData: { + tokenBalanceChanges: [], + }, + newSimulationData: { + isUpdatedAfterSecurityCheck: true, + tokenBalanceChanges: [], + }, + }, +]; diff --git a/packages/shield-controller/tests/utils.ts b/packages/shield-controller/tests/utils.ts index 8f40bfe94f1..d26c1ed0448 100644 --- a/packages/shield-controller/tests/utils.ts +++ b/packages/shield-controller/tests/utils.ts @@ -99,3 +99,15 @@ export function setupCoverageResultReceived( baseMessenger.subscribe('ShieldController:coverageResultReceived', handler); }); } + +/** + * Delay for a specified amount of time. + * + * @param ms - The number of milliseconds to delay. + * @returns A promise that resolves after the specified amount of time. + */ +export function delay(ms: number): Promise { + return new Promise((resolve) => { + setTimeout(resolve, ms); + }); +} diff --git a/yarn.lock b/yarn.lock index d157d4ade4b..6f95c6dd967 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4725,11 +4725,13 @@ __metadata: "@lavamoat/preinstall-always-fail": "npm:^2.1.0" "@metamask/auto-changelog": "npm:^3.4.4" "@metamask/base-controller": "npm:^8.4.1" + "@metamask/controller-utils": "npm:^11.14.1" "@metamask/signature-controller": "npm:^34.0.1" "@metamask/transaction-controller": "npm:^60.9.0" "@metamask/utils": "npm:^11.8.1" "@ts-bridge/cli": "npm:^0.6.1" "@types/jest": "npm:^27.4.1" + cockatiel: "npm:^3.1.2" deepmerge: "npm:^4.2.2" jest: "npm:^27.5.1" ts-jest: "npm:^27.1.4"