Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DeviceModifier #4103

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/webgpu/api/operation/device/all_limits_and_features.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
GPUTestSubcaseBatchState,
initUncanonicalizedDeviceDescriptor,
} from '../../../gpu_test.js';
import { CanonicalDeviceDescriptor, DescriptorModifierFn } from '../../../util/device_pool.js';
import { CanonicalDeviceDescriptor, DescriptorModifier } from '../../../util/device_pool.js';

/**
* Gets the adapter limits as a standard JavaScript object.
Expand Down Expand Up @@ -36,18 +36,25 @@ function setAllLimitsToAdapterLimitsAndAddAllFeatures(
}

/**
* Used by MaxLimitsTest to request a device with all the max limits of the adapter.
* Used to request a device with all the max limits of the adapter.
*/
export class AllLimitsAndFeaturesGPUTestSubcaseBatchState extends GPUTestSubcaseBatchState {
override selectDeviceOrSkipTestCase(
descriptor: DeviceSelectionDescriptor,
descriptorModifierFn?: DescriptorModifierFn
descriptorModifier?: DescriptorModifier
): void {
const wrapper = (adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) => {
desc = descriptorModifierFn ? descriptorModifierFn(adapter, desc) : desc;
return setAllLimitsToAdapterLimitsAndAddAllFeatures(adapter, desc);
const mod: DescriptorModifier = {
descriptorModifier(adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) {
desc = descriptorModifier?.descriptorModifier
? descriptorModifier.descriptorModifier(adapter, desc)
: desc;
return setAllLimitsToAdapterLimitsAndAddAllFeatures(adapter, desc);
},
keyModifier(baseKey: string) {
return `${baseKey}:AllLimitsAndFeaturesTest`;
},
};
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), wrapper);
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), mod);
}
}

Expand Down
23 changes: 15 additions & 8 deletions src/webgpu/gpu_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ import { CommandBufferMaker, EncoderType } from './util/command_buffer_maker.js'
import { ScalarType } from './util/conversion.js';
import {
CanonicalDeviceDescriptor,
DescriptorModifierFn,
DescriptorModifier,
DevicePool,
DeviceProvider,
UncanonicalizedDeviceDescriptor,
Expand Down Expand Up @@ -156,13 +156,13 @@ export class GPUTestSubcaseBatchState extends SubcaseBatchState {
*/
selectDeviceOrSkipTestCase(
descriptor: DeviceSelectionDescriptor,
descriptorModifierFn?: DescriptorModifierFn
descriptorModifier?: DescriptorModifier
): void {
assert(this.provider === undefined, "Can't selectDeviceOrSkipTestCase() multiple times");
this.provider = devicePool.acquire(
this.recorder,
initUncanonicalizedDeviceDescriptor(descriptor),
descriptorModifierFn
descriptorModifier
);
// Suppress uncaught promise rejection (we'll catch it later).
this.provider.catch(() => {});
Expand Down Expand Up @@ -1334,13 +1334,20 @@ function setAllLimitsToAdapterLimits(
export class MaxLimitsGPUTestSubcaseBatchState extends GPUTestSubcaseBatchState {
override selectDeviceOrSkipTestCase(
descriptor: DeviceSelectionDescriptor,
descriptorModifierFn?: DescriptorModifierFn
descriptorModifier?: DescriptorModifier
): void {
const wrapper = (adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) => {
desc = descriptorModifierFn ? descriptorModifierFn(adapter, desc) : desc;
return setAllLimitsToAdapterLimits(adapter, desc);
const mod: DescriptorModifier = {
descriptorModifier(adapter: GPUAdapter, desc: CanonicalDeviceDescriptor | undefined) {
desc = descriptorModifier?.descriptorModifier
? descriptorModifier.descriptorModifier(adapter, desc)
: desc;
return setAllLimitsToAdapterLimits(adapter, desc);
},
keyModifier(baseKey: string) {
return `${baseKey}:MaxLimits`;
},
};
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), wrapper);
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), mod);
}
}

Expand Down
52 changes: 39 additions & 13 deletions src/webgpu/util/device_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,35 @@ class TestFailedButDeviceReusable extends Error {}
class FeaturesNotSupported extends Error {}
export class TestOOMedShouldAttemptGC extends Error {}

export type DescriptorModifierFn = (
adapter: GPUAdapter,
desc: CanonicalDeviceDescriptor | undefined
) => CanonicalDeviceDescriptor;
/**
* DescriptorModifier lets you supply a function to select a device
* based on the limits/features available from the adapter.
* Devices pooled based on a key and that key is derived before
* an adapter is requested. That means you select key without
* knowledge of what the adapter will provide. You do this by
* providing a keyModifier function that appends a suffix.
*
* For example: If your modifier adds all the limits you might
* choose 'maxLimits' are your suffix
*
* ```js
* keyModifier(s: string) { return `${s}:maxLimits`; },
* ```
*
* If your modifier selects only `maxBindGroups` and `maxColorAttachments`
* then your suffix might be `maxBindGroups&maxColorAttachments`
*
* ```js
* keyModifier(s: string) { return `${s}:maxBindGroups&maxColorAttachments`; },
* ```
*/
export type DescriptorModifier = {
keyModifier(baseKey: string): string;
descriptorModifier(
adapter: GPUAdapter,
desc: CanonicalDeviceDescriptor | undefined
): CanonicalDeviceDescriptor;
};

export class DevicePool {
private holders: 'uninitialized' | 'failed' | DescriptorToHolderMap = 'uninitialized';
Expand All @@ -35,13 +60,13 @@ export class DevicePool {
async acquire(
recorder: TestCaseRecorder,
descriptor: UncanonicalizedDeviceDescriptor | undefined,
descriptorModifierFn: DescriptorModifierFn | undefined
descriptorModifier: DescriptorModifier | undefined
): Promise<DeviceProvider> {
let errorMessage = '';
if (this.holders === 'uninitialized') {
this.holders = new DescriptorToHolderMap();
try {
await this.holders.getOrCreate(recorder, undefined, descriptorModifierFn);
await this.holders.getOrCreate(recorder, undefined, descriptorModifier);
} catch (ex) {
this.holders = 'failed';
if (ex instanceof Error) {
Expand All @@ -55,7 +80,7 @@ export class DevicePool {
`WebGPU device failed to initialize${errorMessage}; not retrying`
);

const holder = await this.holders.getOrCreate(recorder, descriptor, descriptorModifierFn);
const holder = await this.holders.getOrCreate(recorder, descriptor, descriptorModifier);

assert(holder.state === 'free', 'Device was in use on DevicePool.acquire');
holder.state = 'acquired';
Expand Down Expand Up @@ -150,9 +175,10 @@ class DescriptorToHolderMap {
async getOrCreate(
recorder: TestCaseRecorder,
uncanonicalizedDescriptor: UncanonicalizedDeviceDescriptor | undefined,
descriptorModifierFn: DescriptorModifierFn | undefined
descriptorModifier: DescriptorModifier | undefined
): Promise<DeviceHolder> {
const [descriptor, key] = canonicalizeDescriptor(uncanonicalizedDescriptor);
const [descriptor, baseKey] = canonicalizeDescriptor(uncanonicalizedDescriptor);
const key = descriptorModifier?.keyModifier(baseKey) || baseKey;
// Quick-reject descriptors that are known to be unsupported already.
if (this.unsupported.has(key)) {
throw new SkipTestCase(
Expand All @@ -174,7 +200,7 @@ class DescriptorToHolderMap {
// No existing item was found; add a new one.
let value;
try {
value = await DeviceHolder.create(recorder, descriptor, descriptorModifierFn);
value = await DeviceHolder.create(recorder, descriptor, descriptorModifier);
} catch (ex) {
if (ex instanceof FeaturesNotSupported) {
this.unsupported.add(key);
Expand Down Expand Up @@ -313,13 +339,13 @@ class DeviceHolder implements DeviceProvider {
static async create(
recorder: TestCaseRecorder,
descriptor: CanonicalDeviceDescriptor | undefined,
descriptorModifierFn: DescriptorModifierFn | undefined
descriptorModifier: DescriptorModifier | undefined
): Promise<DeviceHolder> {
const gpu = getGPU(recorder);
const adapter = await gpu.requestAdapter();
assert(adapter !== null, 'requestAdapter returned null');
if (descriptorModifierFn) {
descriptor = descriptorModifierFn(adapter, descriptor);
if (descriptorModifier) {
descriptor = descriptorModifier.descriptorModifier(adapter, descriptor);
}
if (!supportsFeature(adapter, descriptor)) {
throw new FeaturesNotSupported('One or more features are not supported');
Expand Down
Loading