Skip to content

Commit 753a881

Browse files
committed
feat: updated shield-backend with CockatielPollingPolicy
1 parent a7f6535 commit 753a881

File tree

3 files changed

+68
-56
lines changed

3 files changed

+68
-56
lines changed

packages/shield-controller/src/backend.test.ts

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ describe('ShieldRemoteBackend', () => {
129129
expect(getAccessToken).toHaveBeenCalledTimes(1);
130130
});
131131

132-
it('should throw on check coverage timeout', async () => {
132+
it('should throw on check coverage timeout with coverage status', async () => {
133133
const { backend, fetchMock } = setup({
134134
getCoverageResultTimeout: 0,
135135
getCoverageResultPollInterval: 0,
@@ -147,6 +147,33 @@ describe('ShieldRemoteBackend', () => {
147147
json: jest.fn().mockResolvedValue({ status: 'unavailable' }),
148148
} as unknown as Response);
149149

150+
const txMeta = generateMockTxMeta();
151+
await expect(backend.checkCoverage({ txMeta })).rejects.toThrow(
152+
'Timeout waiting for coverage result: unavailable',
153+
);
154+
155+
// Waiting here ensures coverage of the unexpected error and lets us know
156+
// that the polling loop is exited as expected.
157+
await new Promise((resolve) => setTimeout(resolve, 10));
158+
});
159+
160+
it('should throw on check coverage timeout', async () => {
161+
const { backend, fetchMock } = setup({
162+
getCoverageResultTimeout: 0,
163+
getCoverageResultPollInterval: 0,
164+
});
165+
166+
// Mock init coverage check.
167+
fetchMock.mockResolvedValueOnce({
168+
status: 200,
169+
json: jest.fn().mockResolvedValue({ coverageId: 'coverageId' }),
170+
} as unknown as Response);
171+
172+
// Mock get coverage result: result unavailable.
173+
fetchMock.mockResolvedValue({
174+
status: 404,
175+
} as unknown as Response);
176+
150177
const txMeta = generateMockTxMeta();
151178
await expect(backend.checkCoverage({ txMeta })).rejects.toThrow(
152179
'Timeout waiting for coverage result',

packages/shield-controller/src/backend.ts

Lines changed: 39 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import type {
1616
LogTransactionRequest,
1717
ShieldBackend,
1818
} from './types';
19+
import { PollingWithCockatielPolicy } from './polling-with-policy';
20+
import { HttpError } from '@metamask/controller-utils';
1921

2022
export type InitCoverageCheckRequest = {
2123
txParams: [
@@ -64,6 +66,8 @@ export class ShieldRemoteBackend implements ShieldBackend {
6466

6567
readonly #fetch: typeof globalThis.fetch;
6668

69+
readonly #pollingPolicy: PollingWithCockatielPolicy;
70+
6771
constructor({
6872
getAccessToken,
6973
getCoverageResultTimeout = 5000, // milliseconds
@@ -82,6 +86,7 @@ export class ShieldRemoteBackend implements ShieldBackend {
8286
this.#getCoverageResultPollInterval = getCoverageResultPollInterval;
8387
this.#baseUrl = baseUrl;
8488
this.#fetch = fetchFn;
89+
this.#pollingPolicy = new PollingWithCockatielPolicy();
8590
}
8691

8792
async checkCoverage(req: CheckCoverageRequest): Promise<CoverageResult> {
@@ -95,9 +100,7 @@ export class ShieldRemoteBackend implements ShieldBackend {
95100
}
96101

97102
const txCoverageResultUrl = `${this.#baseUrl}/v1/transaction/coverage/result`;
98-
const coverageResult = await this.#getCoverageResult(coverageId, {
99-
coverageResultUrl: txCoverageResultUrl,
100-
});
103+
const coverageResult = await this.#getCoverageResult(req.txMeta.id, coverageId, txCoverageResultUrl);
101104
return {
102105
coverageId,
103106
message: coverageResult.message,
@@ -119,9 +122,7 @@ export class ShieldRemoteBackend implements ShieldBackend {
119122
}
120123

121124
const signatureCoverageResultUrl = `${this.#baseUrl}/v1/signature/coverage/result`;
122-
const coverageResult = await this.#getCoverageResult(coverageId, {
123-
coverageResultUrl: signatureCoverageResultUrl,
124-
});
125+
const coverageResult = await this.#getCoverageResult(req.signatureRequest.id, coverageId, signatureCoverageResultUrl);
125126
return {
126127
coverageId,
127128
message: coverageResult.message,
@@ -138,6 +139,9 @@ export class ShieldRemoteBackend implements ShieldBackend {
138139
...initBody,
139140
};
140141

142+
// cancel the pending get coverage result request
143+
this.#pollingPolicy.abortPendingRequest(req.signatureRequest.id);
144+
141145
const res = await this.#fetch(
142146
`${this.#baseUrl}/v1/signature/coverage/log`,
143147
{
@@ -159,6 +163,9 @@ export class ShieldRemoteBackend implements ShieldBackend {
159163
...initBody,
160164
};
161165

166+
// cancel the pending get coverage result request
167+
this.#pollingPolicy.abortPendingRequest(req.txMeta.id);
168+
162169
const res = await this.#fetch(
163170
`${this.#baseUrl}/v1/transaction/coverage/log`,
164171
{
@@ -188,51 +195,39 @@ export class ShieldRemoteBackend implements ShieldBackend {
188195
}
189196

190197
async #getCoverageResult(
198+
requestId: string,
191199
coverageId: string,
192-
configs: {
193-
coverageResultUrl: string;
194-
timeout?: number;
195-
pollInterval?: number;
196-
},
200+
coverageResultUrl: string,
197201
): Promise<GetCoverageResultResponse> {
198202
const reqBody: GetCoverageResultRequest = {
199203
coverageId,
200204
};
201205

202-
const timeout = configs?.timeout ?? this.#getCoverageResultTimeout;
203-
const pollInterval =
204-
configs?.pollInterval ?? this.#getCoverageResultPollInterval;
205-
206206
const headers = await this.#createHeaders();
207-
return await new Promise((resolve, reject) => {
208-
let timeoutReached = false;
209-
setTimeout(() => {
210-
timeoutReached = true;
211-
reject(new Error('Timeout waiting for coverage result'));
212-
}, timeout);
213-
214-
const poll = async (): Promise<GetCoverageResultResponse> => {
215-
// The timeoutReached variable is modified in the timeout callback.
216-
// eslint-disable-next-line no-unmodified-loop-condition
217-
while (!timeoutReached) {
218-
const startTime = Date.now();
219-
const res = await this.#fetch(configs.coverageResultUrl, {
220-
method: 'POST',
221-
headers,
222-
body: JSON.stringify(reqBody),
223-
});
224-
if (res.status === 200) {
225-
return (await res.json()) as GetCoverageResultResponse;
226-
}
227-
await sleep(pollInterval - (Date.now() - startTime));
228-
}
229-
// The following line will not have an effect as the upper level promise
230-
// will already be rejected by now.
231-
throw new Error('unexpected error');
232-
};
233-
234-
poll().then(resolve).catch(reject);
235-
});
207+
208+
const getCoverageResultFn = async (signal: AbortSignal) => {
209+
const res = await this.#fetch(coverageResultUrl, {
210+
method: 'POST',
211+
headers,
212+
body: JSON.stringify(reqBody),
213+
signal,
214+
});
215+
if (res.status === 200) {
216+
return (await res.json()) as GetCoverageResultResponse;
217+
}
218+
219+
// parse the error message from the response body
220+
let errorMessage = 'Timeout waiting for coverage result';
221+
try {
222+
const errorJson = await res.json();
223+
errorMessage = `Timeout waiting for coverage result: ${errorJson.status}`;
224+
} catch (error) {
225+
errorMessage = 'Timeout waiting for coverage result';
226+
}
227+
throw new HttpError(res.status, errorMessage);
228+
}
229+
230+
return this.#pollingPolicy.start(requestId, getCoverageResultFn);
236231
}
237232

238233
async #createHeaders() {
@@ -244,16 +239,6 @@ export class ShieldRemoteBackend implements ShieldBackend {
244239
}
245240
}
246241

247-
/**
248-
* Sleep for a specified amount of time.
249-
*
250-
* @param ms - The number of milliseconds to sleep.
251-
* @returns A promise that resolves after the specified amount of time.
252-
*/
253-
async function sleep(ms: number) {
254-
return new Promise((resolve) => setTimeout(resolve, ms));
255-
}
256-
257242
/**
258243
* Make the body for the init coverage check request.
259244
*

packages/shield-controller/src/polling-with-policy.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ export class PollingWithCockatielPolicy {
5656

5757
#shouldRetry(error: Error): boolean {
5858
if (error instanceof HttpError) {
59-
// Note: we don't retry on 4xx errors, only on 5xx errors.
59+
// Note: we don't retry on 5xx errors, only on 4xx errors.
6060
// but we won't retry on 400 coz it means that the request body is invalid.
6161
return error.httpStatus > 400 && error.httpStatus < 500;
6262
}

0 commit comments

Comments
 (0)