diff --git a/src/webgpu/api/operation/device/all_limits_and_features.spec.ts b/src/webgpu/api/operation/device/all_limits_and_features.spec.ts index 82585090b499..729928631348 100644 --- a/src/webgpu/api/operation/device/all_limits_and_features.spec.ts +++ b/src/webgpu/api/operation/device/all_limits_and_features.spec.ts @@ -8,7 +8,7 @@ import { GPUTestSubcaseBatchState, initUncanonicalizedDeviceDescriptor, } from '../../../gpu_test.js'; -import { CanonicalDeviceDescriptor, DescriptorModifierFn } from '../../../util/device_pool.js'; +import { CanonicalDeviceDescriptor, DescriptorModifier } from '../../../util/device_pool.js'; /** * Gets the adapter limits as a standard JavaScript object. @@ -36,18 +36,25 @@ function setAllLimitsToAdapterLimitsAndAddAllFeatures( } /** - * Used by MaxLimitsTest to request a device with all the max limits of the adapter. + * Used to request a device with all the max limits of the adapter. */ export class AllLimitsAndFeaturesGPUTestSubcaseBatchState extends GPUTestSubcaseBatchState { override selectDeviceOrSkipTestCase( descriptor: DeviceSelectionDescriptor, - descriptorModifierFn?: DescriptorModifierFn + descriptorModifier?: DescriptorModifier ): void { - const wrapper = (adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) => { - desc = descriptorModifierFn ? descriptorModifierFn(adapter, desc) : desc; - return setAllLimitsToAdapterLimitsAndAddAllFeatures(adapter, desc); + const mod: DescriptorModifier = { + descriptorModifier(adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) { + desc = descriptorModifier?.descriptorModifier + ? descriptorModifier.descriptorModifier(adapter, desc) + : desc; + return setAllLimitsToAdapterLimitsAndAddAllFeatures(adapter, desc); + }, + keyModifier(baseKey: string) { + return `${baseKey}:AllLimitsAndFeaturesTest`; + }, }; - super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), wrapper); + super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), mod); } } diff --git a/src/webgpu/gpu_test.ts b/src/webgpu/gpu_test.ts index 53e4a1481791..2719679b512e 100644 --- a/src/webgpu/gpu_test.ts +++ b/src/webgpu/gpu_test.ts @@ -43,7 +43,7 @@ import { CommandBufferMaker, EncoderType } from './util/command_buffer_maker.js' import { ScalarType } from './util/conversion.js'; import { CanonicalDeviceDescriptor, - DescriptorModifierFn, + DescriptorModifier, DevicePool, DeviceProvider, UncanonicalizedDeviceDescriptor, @@ -156,13 +156,13 @@ export class GPUTestSubcaseBatchState extends SubcaseBatchState { */ selectDeviceOrSkipTestCase( descriptor: DeviceSelectionDescriptor, - descriptorModifierFn?: DescriptorModifierFn + descriptorModifier?: DescriptorModifier ): void { assert(this.provider === undefined, "Can't selectDeviceOrSkipTestCase() multiple times"); this.provider = devicePool.acquire( this.recorder, initUncanonicalizedDeviceDescriptor(descriptor), - descriptorModifierFn + descriptorModifier ); // Suppress uncaught promise rejection (we'll catch it later). this.provider.catch(() => {}); @@ -1334,13 +1334,20 @@ function setAllLimitsToAdapterLimits( export class MaxLimitsGPUTestSubcaseBatchState extends GPUTestSubcaseBatchState { override selectDeviceOrSkipTestCase( descriptor: DeviceSelectionDescriptor, - descriptorModifierFn?: DescriptorModifierFn + descriptorModifier?: DescriptorModifier ): void { - const wrapper = (adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) => { - desc = descriptorModifierFn ? descriptorModifierFn(adapter, desc) : desc; - return setAllLimitsToAdapterLimits(adapter, desc); + const mod: DescriptorModifier = { + descriptorModifier(adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) { + desc = descriptorModifier?.descriptorModifier + ? descriptorModifier.descriptorModifier(adapter, desc) + : desc; + return setAllLimitsToAdapterLimits(adapter, desc); + }, + keyModifier(baseKey: string) { + return `${baseKey}:MaxLimits`; + }, }; - super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), wrapper); + super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), mod); } } diff --git a/src/webgpu/util/device_pool.ts b/src/webgpu/util/device_pool.ts index 20db6afd1b2e..262a567d7a2c 100644 --- a/src/webgpu/util/device_pool.ts +++ b/src/webgpu/util/device_pool.ts @@ -23,10 +23,35 @@ class TestFailedButDeviceReusable extends Error {} class FeaturesNotSupported extends Error {} export class TestOOMedShouldAttemptGC extends Error {} -export type DescriptorModifierFn = ( - adapter: GPUAdapter, - desc: CanonicalDeviceDescriptor | undefined -) => CanonicalDeviceDescriptor; +/** + * DescriptorModifier lets you supply a function to select a device + * based on the limits/features available from the adapter. + * Devices pooled based on a key and that key is derived before + * an adapter is requested. That means you select key without + * knowledge of what the adapter will provide. You do this by + * providing a keyModifier function that appends a suffix. + * + * For example: If your modifier adds all the limits you might + * choose 'maxLimits' are your suffix + * + * ```js + * keyModifier(s: string) { return `${s}:maxLimits`; }, + * ``` + * + * If your modifier selects only `maxBindGroups` and `maxColorAttachments` + * then your suffix might be `maxBindGroups&maxColorAttachments` + * + * ```js + * keyModifier(s: string) { return `${s}:maxBindGroups&maxColorAttachments`; }, + * ``` + */ +export type DescriptorModifier = { + keyModifier(baseKey: string): string; + descriptorModifier( + adapter: GPUAdapter, + desc: CanonicalDeviceDescriptor | undefined + ): CanonicalDeviceDescriptor; +}; export class DevicePool { private holders: 'uninitialized' | 'failed' | DescriptorToHolderMap = 'uninitialized'; @@ -35,13 +60,13 @@ export class DevicePool { async acquire( recorder: TestCaseRecorder, descriptor: UncanonicalizedDeviceDescriptor | undefined, - descriptorModifierFn: DescriptorModifierFn | undefined + descriptorModifier: DescriptorModifier | undefined ): Promise { let errorMessage = ''; if (this.holders === 'uninitialized') { this.holders = new DescriptorToHolderMap(); try { - await this.holders.getOrCreate(recorder, undefined, descriptorModifierFn); + await this.holders.getOrCreate(recorder, undefined, descriptorModifier); } catch (ex) { this.holders = 'failed'; if (ex instanceof Error) { @@ -55,7 +80,7 @@ export class DevicePool { `WebGPU device failed to initialize${errorMessage}; not retrying` ); - const holder = await this.holders.getOrCreate(recorder, descriptor, descriptorModifierFn); + const holder = await this.holders.getOrCreate(recorder, descriptor, descriptorModifier); assert(holder.state === 'free', 'Device was in use on DevicePool.acquire'); holder.state = 'acquired'; @@ -150,9 +175,10 @@ class DescriptorToHolderMap { async getOrCreate( recorder: TestCaseRecorder, uncanonicalizedDescriptor: UncanonicalizedDeviceDescriptor | undefined, - descriptorModifierFn: DescriptorModifierFn | undefined + descriptorModifier: DescriptorModifier | undefined ): Promise { - const [descriptor, key] = canonicalizeDescriptor(uncanonicalizedDescriptor); + const [descriptor, baseKey] = canonicalizeDescriptor(uncanonicalizedDescriptor); + const key = descriptorModifier?.keyModifier(baseKey) || baseKey; // Quick-reject descriptors that are known to be unsupported already. if (this.unsupported.has(key)) { throw new SkipTestCase( @@ -174,7 +200,7 @@ class DescriptorToHolderMap { // No existing item was found; add a new one. let value; try { - value = await DeviceHolder.create(recorder, descriptor, descriptorModifierFn); + value = await DeviceHolder.create(recorder, descriptor, descriptorModifier); } catch (ex) { if (ex instanceof FeaturesNotSupported) { this.unsupported.add(key); @@ -313,13 +339,13 @@ class DeviceHolder implements DeviceProvider { static async create( recorder: TestCaseRecorder, descriptor: CanonicalDeviceDescriptor | undefined, - descriptorModifierFn: DescriptorModifierFn | undefined + descriptorModifier: DescriptorModifier | undefined ): Promise { const gpu = getGPU(recorder); const adapter = await gpu.requestAdapter(); assert(adapter !== null, 'requestAdapter returned null'); - if (descriptorModifierFn) { - descriptor = descriptorModifierFn(adapter, descriptor); + if (descriptorModifier) { + descriptor = descriptorModifier.descriptorModifier(adapter, descriptor); } if (!supportsFeature(adapter, descriptor)) { throw new FeaturesNotSupported('One or more features are not supported');