Skip to content

Commit dda7198

Browse files
committed
Fix DeviceModifier
The issue is the key is selected before the modification happens which means the devices in the pool will not match the key. I can't think of a way to fix this using the adaptor so maybe this hacky solution works by letting you modify the key. Generally you just append a string but I made it function because MaxLimitsTestMixin wants to be able to chain to other modifiers. Maybe it should all be re-designed This came up because I notice a test passing that shouldn't have passed. It passsed because a `MaxLimitsTestMixin` test ran first and it's key was `''` which meant that other tests got a max limits device from the pool.
1 parent 7eedd5a commit dda7198

File tree

3 files changed

+69
-29
lines changed

3 files changed

+69
-29
lines changed

src/webgpu/api/operation/device/all_limits_and_features.spec.ts

+14-7
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import {
88
GPUTestSubcaseBatchState,
99
initUncanonicalizedDeviceDescriptor,
1010
} from '../../../gpu_test.js';
11-
import { CanonicalDeviceDescriptor, DescriptorModifierFn } from '../../../util/device_pool.js';
11+
import { CanonicalDeviceDescriptor, DescriptorModifier } from '../../../util/device_pool.js';
1212

1313
/**
1414
* Gets the adapter limits as a standard JavaScript object.
@@ -36,18 +36,25 @@ function setAllLimitsToAdapterLimitsAndAddAllFeatures(
3636
}
3737

3838
/**
39-
* Used by MaxLimitsTest to request a device with all the max limits of the adapter.
39+
* Used to request a device with all the max limits of the adapter.
4040
*/
4141
export class AllLimitsAndFeaturesGPUTestSubcaseBatchState extends GPUTestSubcaseBatchState {
4242
override selectDeviceOrSkipTestCase(
4343
descriptor: DeviceSelectionDescriptor,
44-
descriptorModifierFn?: DescriptorModifierFn
44+
descriptorModifier?: DescriptorModifier
4545
): void {
46-
const wrapper = (adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) => {
47-
desc = descriptorModifierFn ? descriptorModifierFn(adapter, desc) : desc;
48-
return setAllLimitsToAdapterLimitsAndAddAllFeatures(adapter, desc);
46+
const mod: DescriptorModifier = {
47+
descriptorModifier(adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) {
48+
desc = descriptorModifier?.descriptorModifier
49+
? descriptorModifier.descriptorModifier(adapter, desc)
50+
: desc;
51+
return setAllLimitsToAdapterLimitsAndAddAllFeatures(adapter, desc);
52+
},
53+
keyModifier(baseKey: string) {
54+
return `${baseKey}:AllLimitsAndFeaturesTest`;
55+
},
4956
};
50-
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), wrapper);
57+
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), mod);
5158
}
5259
}
5360

src/webgpu/gpu_test.ts

+16-9
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ import { CommandBufferMaker, EncoderType } from './util/command_buffer_maker.js'
4343
import { ScalarType } from './util/conversion.js';
4444
import {
4545
CanonicalDeviceDescriptor,
46-
DescriptorModifierFn,
46+
DescriptorModifier,
4747
DevicePool,
4848
DeviceProvider,
4949
UncanonicalizedDeviceDescriptor,
@@ -156,13 +156,13 @@ export class GPUTestSubcaseBatchState extends SubcaseBatchState {
156156
*/
157157
selectDeviceOrSkipTestCase(
158158
descriptor: DeviceSelectionDescriptor,
159-
descriptorModifierFn?: DescriptorModifierFn
159+
descriptorModifier?: DescriptorModifier
160160
): void {
161161
assert(this.provider === undefined, "Can't selectDeviceOrSkipTestCase() multiple times");
162162
this.provider = devicePool.acquire(
163163
this.recorder,
164164
initUncanonicalizedDeviceDescriptor(descriptor),
165-
descriptorModifierFn
165+
descriptorModifier
166166
);
167167
// Suppress uncaught promise rejection (we'll catch it later).
168168
this.provider.catch(() => {});
@@ -1334,13 +1334,20 @@ function setAllLimitsToAdapterLimits(
13341334
export class MaxLimitsGPUTestSubcaseBatchState extends GPUTestSubcaseBatchState {
13351335
override selectDeviceOrSkipTestCase(
13361336
descriptor: DeviceSelectionDescriptor,
1337-
descriptorModifierFn?: DescriptorModifierFn
1337+
descriptorModifier?: DescriptorModifier
13381338
): void {
1339-
const wrapper = (adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) => {
1340-
desc = descriptorModifierFn ? descriptorModifierFn(adapter, desc) : desc;
1341-
return setAllLimitsToAdapterLimits(adapter, desc);
1339+
const mod: DescriptorModifier = {
1340+
descriptorModifier(adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) {
1341+
desc = descriptorModifier?.descriptorModifier
1342+
? descriptorModifier.descriptorModifier(adapter, desc)
1343+
: desc;
1344+
return setAllLimitsToAdapterLimits(adapter, desc);
1345+
},
1346+
keyModifier(baseKey: string) {
1347+
return `${baseKey}:MaxLimits`;
1348+
},
13421349
};
1343-
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), wrapper);
1350+
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), mod);
13441351
}
13451352
}
13461353

@@ -1466,7 +1473,7 @@ export interface TextureTestMixinType {
14661473
* Effectively it's a Uint8Array to Uint8Array copy that
14671474
* does the same thing as `writeTexture` but because the
14681475
* destination is a buffer you have to provide the parameters
1469-
* of the destination buffer similarly to how you'd provide them
1476+
* of the destination buffer similarly to how you'esc provide them
14701477
* to `copyTextureToBuffer`
14711478
*/
14721479
updateLinearTextureDataSubBox(

src/webgpu/util/device_pool.ts

+39-13
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,35 @@ class TestFailedButDeviceReusable extends Error {}
2323
class FeaturesNotSupported extends Error {}
2424
export class TestOOMedShouldAttemptGC extends Error {}
2525

26-
export type DescriptorModifierFn = (
27-
adapter: GPUAdapter,
28-
desc: CanonicalDeviceDescriptor | undefined
29-
) => CanonicalDeviceDescriptor;
26+
/**
27+
* DescriptorModifier lets you supply a function to select a device
28+
* based on the limits/features available from the adapter.
29+
* Devices pooled based on a key and that key is derived before
30+
* an adapter is requested. That means you select key without
31+
* knowledge of what the adapter will provide. You do this by
32+
* providing a keyModifier function that appends a suffix.
33+
*
34+
* For example: If your modifier adds all the limits you might
35+
* choose 'maxLimits' are your suffix
36+
*
37+
* ```js
38+
* keyModifier(s: string) { return `${s}:maxLimits`; },
39+
* ```
40+
*
41+
* If your modifier selects only `maxBindGroups` and `maxColorAttachments`
42+
* then your suffix might be `maxBindGroups&maxColorAttachments`
43+
*
44+
* ```js
45+
* keyModifier(s: string) { return `${s}:maxBindGroups&maxColorAttachments`; },
46+
* ```
47+
*/
48+
export type DescriptorModifier = {
49+
keyModifier(baseKey: string): string;
50+
descriptorModifier(
51+
adapter: GPUAdapter,
52+
desc: CanonicalDeviceDescriptor | undefined
53+
): CanonicalDeviceDescriptor;
54+
};
3055

3156
export class DevicePool {
3257
private holders: 'uninitialized' | 'failed' | DescriptorToHolderMap = 'uninitialized';
@@ -35,13 +60,13 @@ export class DevicePool {
3560
async acquire(
3661
recorder: TestCaseRecorder,
3762
descriptor: UncanonicalizedDeviceDescriptor | undefined,
38-
descriptorModifierFn: DescriptorModifierFn | undefined
63+
descriptorModifier: DescriptorModifier | undefined
3964
): Promise<DeviceProvider> {
4065
let errorMessage = '';
4166
if (this.holders === 'uninitialized') {
4267
this.holders = new DescriptorToHolderMap();
4368
try {
44-
await this.holders.getOrCreate(recorder, undefined, descriptorModifierFn);
69+
await this.holders.getOrCreate(recorder, undefined, descriptorModifier);
4570
} catch (ex) {
4671
this.holders = 'failed';
4772
if (ex instanceof Error) {
@@ -55,7 +80,7 @@ export class DevicePool {
5580
`WebGPU device failed to initialize${errorMessage}; not retrying`
5681
);
5782

58-
const holder = await this.holders.getOrCreate(recorder, descriptor, descriptorModifierFn);
83+
const holder = await this.holders.getOrCreate(recorder, descriptor, descriptorModifier);
5984

6085
assert(holder.state === 'free', 'Device was in use on DevicePool.acquire');
6186
holder.state = 'acquired';
@@ -150,9 +175,10 @@ class DescriptorToHolderMap {
150175
async getOrCreate(
151176
recorder: TestCaseRecorder,
152177
uncanonicalizedDescriptor: UncanonicalizedDeviceDescriptor | undefined,
153-
descriptorModifierFn: DescriptorModifierFn | undefined
178+
descriptorModifier: DescriptorModifier | undefined
154179
): Promise<DeviceHolder> {
155-
const [descriptor, key] = canonicalizeDescriptor(uncanonicalizedDescriptor);
180+
const [descriptor, baseKey] = canonicalizeDescriptor(uncanonicalizedDescriptor);
181+
const key = descriptorModifier?.keyModifier(baseKey) || baseKey;
156182
// Quick-reject descriptors that are known to be unsupported already.
157183
if (this.unsupported.has(key)) {
158184
throw new SkipTestCase(
@@ -174,7 +200,7 @@ class DescriptorToHolderMap {
174200
// No existing item was found; add a new one.
175201
let value;
176202
try {
177-
value = await DeviceHolder.create(recorder, descriptor, descriptorModifierFn);
203+
value = await DeviceHolder.create(recorder, descriptor, descriptorModifier);
178204
} catch (ex) {
179205
if (ex instanceof FeaturesNotSupported) {
180206
this.unsupported.add(key);
@@ -313,13 +339,13 @@ class DeviceHolder implements DeviceProvider {
313339
static async create(
314340
recorder: TestCaseRecorder,
315341
descriptor: CanonicalDeviceDescriptor | undefined,
316-
descriptorModifierFn: DescriptorModifierFn | undefined
342+
descriptorModifier: DescriptorModifier | undefined
317343
): Promise<DeviceHolder> {
318344
const gpu = getGPU(recorder);
319345
const adapter = await gpu.requestAdapter();
320346
assert(adapter !== null, 'requestAdapter returned null');
321-
if (descriptorModifierFn) {
322-
descriptor = descriptorModifierFn(adapter, descriptor);
347+
if (descriptorModifier) {
348+
descriptor = descriptorModifier.descriptorModifier(adapter, descriptor);
323349
}
324350
if (!supportsFeature(adapter, descriptor)) {
325351
throw new FeaturesNotSupported('One or more features are not supported');

0 commit comments

Comments
 (0)