Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/shield-controller/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- Fix and optimize shield coverage result polling by using a Cockatiel policy from `@metamask/controller-utils`. ([#6847](https://github.com/MetaMask/core/pull/6847))

## [0.4.0]

### Added
Expand Down
4 changes: 3 additions & 1 deletion packages/shield-controller/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
},
"dependencies": {
"@metamask/base-controller": "^8.4.2",
"@metamask/utils": "^11.8.1"
"@metamask/controller-utils": "^11.14.1",
"@metamask/utils": "^11.8.1",
"cockatiel": "^3.1.2"
},
"devDependencies": {
"@babel/runtime": "^7.23.9",
Expand Down
36 changes: 33 additions & 3 deletions packages/shield-controller/src/backend.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ describe('ShieldRemoteBackend', () => {
expect(getAccessToken).toHaveBeenCalledTimes(1);
});

it('should throw on check coverage timeout', async () => {
it('should throw on check coverage timeout with coverage status', async () => {
const { backend, fetchMock } = setup({
getCoverageResultTimeout: 0,
getCoverageResultPollInterval: 0,
Expand All @@ -144,12 +144,42 @@ describe('ShieldRemoteBackend', () => {
// Mock get coverage result: result unavailable.
fetchMock.mockResolvedValue({
status: 404,
json: jest.fn().mockResolvedValue({ status: 'unavailable' }),
} as unknown as Response);

const txMeta = generateMockTxMeta();
await expect(backend.checkCoverage({ txMeta })).rejects.toThrow(
'Timeout waiting for coverage result',
'Failed to get coverage result: 404',
);

// Waiting here ensures coverage of the unexpected error and lets us know
// that the polling loop is exited as expected.
await new Promise((resolve) => setTimeout(resolve, 10));
});

it('should throw on check coverage timeout', async () => {
const { backend, fetchMock } = setup({
getCoverageResultTimeout: 0,
getCoverageResultPollInterval: 0,
});

// Mock init coverage check.
fetchMock.mockResolvedValueOnce({
status: 200,
json: jest.fn().mockResolvedValue({ coverageId: 'coverageId' }),
} as unknown as Response);

// Mock get coverage result: result unavailable.
fetchMock.mockResolvedValue({
status: 412,
json: jest.fn().mockResolvedValue({
message: 'Results are not available yet',
statusCode: 412,
}),
} as unknown as Response);

const txMeta = generateMockTxMeta();
await expect(backend.checkCoverage({ txMeta })).rejects.toThrow(
'Failed to get coverage result: Results are not available yet',
);

// Waiting here ensures coverage of the unexpected error and lets us know
Expand Down
145 changes: 85 additions & 60 deletions packages/shield-controller/src/backend.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import {
ConstantBackoff,
DEFAULT_MAX_RETRIES,
HttpError,
} from '@metamask/controller-utils';
import {
EthMethod,
SignatureRequestType,
Expand All @@ -7,6 +12,7 @@ import type { TransactionMeta } from '@metamask/transaction-controller';
import type { Json } from '@metamask/utils';

import { SignTypedDataVersion } from './constants';
import { PollingWithCockatielPolicy } from './polling-with-policy';
import type {
CheckCoverageRequest,
CheckSignatureCoverageRequest,
Expand Down Expand Up @@ -56,14 +62,12 @@ export type GetCoverageResultResponse = {
export class ShieldRemoteBackend implements ShieldBackend {
readonly #getAccessToken: () => Promise<string>;

readonly #getCoverageResultTimeout: number;

readonly #getCoverageResultPollInterval: number;

readonly #baseUrl: string;

readonly #fetch: typeof globalThis.fetch;

readonly #pollingPolicy: PollingWithCockatielPolicy;

constructor({
getAccessToken,
getCoverageResultTimeout = 5000, // milliseconds
Expand All @@ -78,10 +82,18 @@ export class ShieldRemoteBackend implements ShieldBackend {
fetch: typeof globalThis.fetch;
}) {
this.#getAccessToken = getAccessToken;
this.#getCoverageResultTimeout = getCoverageResultTimeout;
this.#getCoverageResultPollInterval = getCoverageResultPollInterval;
this.#baseUrl = baseUrl;
this.#fetch = fetchFn;

const { backoff, maxRetries } = computePollingIntervalAndRetryCount(
getCoverageResultTimeout,
getCoverageResultPollInterval,
);

this.#pollingPolicy = new PollingWithCockatielPolicy({
backoff,
maxRetries,
});
}

async checkCoverage(req: CheckCoverageRequest): Promise<CoverageResult> {
Expand All @@ -95,9 +107,11 @@ export class ShieldRemoteBackend implements ShieldBackend {
}

const txCoverageResultUrl = `${this.#baseUrl}/v1/transaction/coverage/result`;
const coverageResult = await this.#getCoverageResult(coverageId, {
coverageResultUrl: txCoverageResultUrl,
});
const coverageResult = await this.#getCoverageResult(
req.txMeta.id,
coverageId,
txCoverageResultUrl,
);
return {
coverageId,
message: coverageResult.message,
Expand All @@ -119,9 +133,11 @@ export class ShieldRemoteBackend implements ShieldBackend {
}

const signatureCoverageResultUrl = `${this.#baseUrl}/v1/signature/coverage/result`;
const coverageResult = await this.#getCoverageResult(coverageId, {
coverageResultUrl: signatureCoverageResultUrl,
});
const coverageResult = await this.#getCoverageResult(
req.signatureRequest.id,
coverageId,
signatureCoverageResultUrl,
);
return {
coverageId,
message: coverageResult.message,
Expand All @@ -138,6 +154,9 @@ export class ShieldRemoteBackend implements ShieldBackend {
...initBody,
};

// cancel the pending get coverage result request
this.#pollingPolicy.abortPendingRequest(req.signatureRequest.id);

const res = await this.#fetch(
`${this.#baseUrl}/v1/signature/coverage/log`,
{
Expand All @@ -159,6 +178,9 @@ export class ShieldRemoteBackend implements ShieldBackend {
...initBody,
};

// cancel the pending get coverage result request
this.#pollingPolicy.abortPendingRequest(req.txMeta.id);

const res = await this.#fetch(
`${this.#baseUrl}/v1/transaction/coverage/log`,
{
Expand Down Expand Up @@ -188,51 +210,39 @@ export class ShieldRemoteBackend implements ShieldBackend {
}

async #getCoverageResult(
requestId: string,
coverageId: string,
configs: {
coverageResultUrl: string;
timeout?: number;
pollInterval?: number;
},
coverageResultUrl: string,
): Promise<GetCoverageResultResponse> {
const reqBody: GetCoverageResultRequest = {
coverageId,
};

const timeout = configs?.timeout ?? this.#getCoverageResultTimeout;
const pollInterval =
configs?.pollInterval ?? this.#getCoverageResultPollInterval;

const headers = await this.#createHeaders();
return await new Promise((resolve, reject) => {
let timeoutReached = false;
setTimeout(() => {
timeoutReached = true;
reject(new Error('Timeout waiting for coverage result'));
}, timeout);

const poll = async (): Promise<GetCoverageResultResponse> => {
// The timeoutReached variable is modified in the timeout callback.
// eslint-disable-next-line no-unmodified-loop-condition
while (!timeoutReached) {
const startTime = Date.now();
const res = await this.#fetch(configs.coverageResultUrl, {
method: 'POST',
headers,
body: JSON.stringify(reqBody),
});
if (res.status === 200) {
return (await res.json()) as GetCoverageResultResponse;
}
await sleep(pollInterval - (Date.now() - startTime));
}
// The following line will not have an effect as the upper level promise
// will already be rejected by now.
throw new Error('unexpected error');
};

poll().then(resolve).catch(reject);
});

// Performs one coverage-result request. The abort signal is forwarded to
// fetch so the in-flight request can be cancelled. A 200 response resolves
// with the parsed body; any other status throws an HttpError carrying that
// status code.
const getCoverageResultFn = async (signal: AbortSignal) => {
  const res = await this.#fetch(coverageResultUrl, {
    method: 'POST',
    headers,
    body: JSON.stringify(reqBody),
    signal,
  });
  if (res.status === 200) {
    return (await res.json()) as GetCoverageResultResponse;
  }

  // Describe the failure with the reason reported in the response body when
  // there is one. Default to the HTTP status code so the message never ends
  // in "undefined" when the body has neither a `message` nor a `status`.
  let errorDetail: unknown = res.status;
  try {
    const errorJson = await res.json();
    errorDetail = errorJson?.message || errorJson?.status || res.status;
  } catch {
    // The body is missing or is not valid JSON; keep the HTTP status code.
  }
  throw new HttpError(
    res.status,
    `Failed to get coverage result: ${String(errorDetail)}`,
  );
};

return this.#pollingPolicy.start(requestId, getCoverageResultFn);
}

async #createHeaders() {
Expand All @@ -244,16 +254,6 @@ export class ShieldRemoteBackend implements ShieldBackend {
}
}

/**
* Sleep for a specified amount of time.
*
* @param ms - The number of milliseconds to sleep.
* @returns A promise that resolves after the specified amount of time.
*/
async function sleep(ms: number) {
return new Promise((resolve) => setTimeout(resolve, ms));
}

/**
* Make the body for the init coverage check request.
*
Expand Down Expand Up @@ -324,3 +324,28 @@ export function parseSignatureRequestMethod(

return signatureRequest.type;
}

/**
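* Compute the backoff and the maximum retry count for the Cockatiel policy from the given timeout and poll interval.
* Falls back to `DEFAULT_MAX_RETRIES` when the computed retry count is not a finite number (for example, when `pollInterval` is 0).
*
* @param timeout - The timeout in milliseconds.
* @param pollInterval - The poll interval in milliseconds.
* @returns The constant backoff and the maximum retry count.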
*/
function computePollingIntervalAndRetryCount(
timeout: number,
pollInterval: number,
) {
const backoff = new ConstantBackoff(pollInterval);
const computedMaxRetries = Math.floor(timeout / pollInterval) + 1;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Polling Timeout Exceeded Due to Retry Calculation Error

The maxRetries calculation in computePollingIntervalAndRetryCount is off by one. It currently calculates total attempts (Math.floor(timeout / pollInterval) + 1), but the Cockatiel policy expects the number of retries (attempts - 1). This leads to an extra retry, causing polling to exceed the intended timeout.

Fix in Cursor Fix in Web

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intended.


const maxRetries =
isNaN(computedMaxRetries) || !isFinite(computedMaxRetries)
? DEFAULT_MAX_RETRIES
: computedMaxRetries;

return {
backoff,
maxRetries,
};
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Polling Timeout Exceeded Due to Retry Counting Error

The maxRetries calculation in computePollingIntervalAndRetryCount has an off-by-one error. It effectively counts the initial attempt as a retry, leading to one extra poll and causing the total polling duration to exceed the specified timeout.

Fix in Cursor Fix in Web

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Edge Case Handling in Retry Calculation

The maxRetries calculation in computePollingIntervalAndRetryCount has edge cases when pollInterval is 0 or when both timeout and pollInterval are 0. Division by zero results in NaN or Infinity, causing the function to incorrectly fall back to DEFAULT_MAX_RETRIES. This prevents respecting intended behaviors like immediate polling or immediate failure, leading to unexpected retry counts.

Fix in Cursor Fix in Web

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Polling Interval Miscalculation

The computePollingIntervalAndRetryCount function calculates maxRetries as Math.floor(timeout / pollInterval) + 1. This adds an extra retry, causing the total polling duration to exceed the specified timeout. This can lead to longer-than-expected waits, particularly with small poll intervals.

Fix in Cursor Fix in Web

Loading
Loading