Skip to content

Commit

Permalink
Refactor selectDeviceOrSkipTest system (#4150)
Browse files Browse the repository at this point in the history
The existing system selected the device immediately which means
code like this

    t.selectDeviceOrSkipTest('timestamp-query');
    t.selectDeviceOrSkipTesT('float32-renderable'); // fail!

would fail on the 2nd line because a device had already been
requested on the first line.

Refactored so that the various requirements are merged and
only at the end is a device requested.

Tested here:
greggman@d4abf41
  • Loading branch information
greggman authored Jan 17, 2025
1 parent ac3cd91 commit 7cdd7c8
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 18 deletions.
10 changes: 8 additions & 2 deletions src/webgpu/api/operation/device/all_limits_and_features.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import { CanonicalDeviceDescriptor, DescriptorModifier } from '../../../util/dev

/**
* Gets the adapter limits as a standard JavaScript object.
* MAINTENANCE_TODO: Remove this and use the same function from gpu_test.ts once minSubgroupSize is removed
* The reason this is separate now is we want this test to fail. `mnSubgroupSize` should never have
* be added and this test exists to see that the same mistake doesn't happen in the future.
*/
function getAdapterLimitsAsDeviceRequiredLimits(adapter: GPUAdapter) {
const requiredLimits: Record<string, GPUSize64> = {};
Expand Down Expand Up @@ -39,7 +42,7 @@ function setAllLimitsToAdapterLimitsAndAddAllFeatures(
* Used to request a device with all the max limits of the adapter.
*/
export class AllLimitsAndFeaturesGPUTestSubcaseBatchState extends GPUTestSubcaseBatchState {
override selectDeviceOrSkipTestCase(
override requestDeviceWithRequiredParametersOrSkip(
descriptor: DeviceSelectionDescriptor,
descriptorModifier?: DescriptorModifier
): void {
Expand All @@ -54,7 +57,10 @@ export class AllLimitsAndFeaturesGPUTestSubcaseBatchState extends GPUTestSubcase
return `${baseKey}:AllLimitsAndFeaturesTest`;
},
};
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), mod);
super.requestDeviceWithRequiredParametersOrSkip(
initUncanonicalizedDeviceDescriptor(descriptor),
mod
);
}
}

Expand Down
52 changes: 44 additions & 8 deletions src/webgpu/gpu_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,32 @@ export type DeviceSelectionDescriptor =

export function initUncanonicalizedDeviceDescriptor(
descriptor: DeviceSelectionDescriptor
): UncanonicalizedDeviceDescriptor | undefined {
): UncanonicalizedDeviceDescriptor {
if (typeof descriptor === 'string') {
return { requiredFeatures: [descriptor] };
} else if (descriptor instanceof Array) {
return {
requiredFeatures: descriptor.filter(f => f !== undefined) as GPUFeatureName[],
};
} else {
return descriptor;
return descriptor ?? {};
}
}

type DeviceDescriptorSimplified = {
requiredFeatures: GPUFeatureName[];
requiredLimits: Record<string, number>;
defaultQueue: GPUQueueDescriptor;
};

function mergeDeviceSelectionDescriptorIntoDeviceDescriptor(
src: DeviceSelectionDescriptor,
dst: DeviceDescriptorSimplified
) {
const srcFixed = initUncanonicalizedDeviceDescriptor(src);
if (srcFixed) {
dst.requiredFeatures.push(...(srcFixed.requiredFeatures ?? []));
Object.assign(dst.requiredLimits, srcFixed.requiredLimits ?? {});
}
}

Expand All @@ -109,6 +126,12 @@ export class GPUTestSubcaseBatchState extends SubcaseBatchState {
private provider: Promise<DeviceProvider> | undefined;
/** Provider for mismatched device. */
private mismatchedProvider: Promise<DeviceProvider> | undefined;
/** The accumulated skip-if requirements for this subcase */
private skipIfRequirements: DeviceDescriptorSimplified = {
requiredFeatures: [],
requiredLimits: {},
defaultQueue: {},
};

override async postInit(): Promise<void> {
// Skip all subcases if there's no device.
Expand All @@ -128,7 +151,7 @@ export class GPUTestSubcaseBatchState extends SubcaseBatchState {
/** @internal MAINTENANCE_TODO: Make this not visible to test code? */
acquireProvider(): Promise<DeviceProvider> {
if (this.provider === undefined) {
this.selectDeviceOrSkipTestCase(undefined);
this.requestDeviceWithRequiredParametersOrSkip(this.skipIfRequirements);
}
assert(this.provider !== undefined);
return this.provider;
Expand All @@ -145,7 +168,7 @@ export class GPUTestSubcaseBatchState extends SubcaseBatchState {
*
* If the request isn't supported, throws a SkipTestCase exception to skip the entire test case.
*/
selectDeviceOrSkipTestCase(
requestDeviceWithRequiredParametersOrSkip(
descriptor: DeviceSelectionDescriptor,
descriptorModifier?: DescriptorModifier
): void {
Expand All @@ -159,6 +182,16 @@ export class GPUTestSubcaseBatchState extends SubcaseBatchState {
this.provider.catch(() => {});
}

/**
* Some tests or cases need particular feature flags or limits to be enabled.
* Call this function with a descriptor or feature name (or `undefined`) to add
* features or limits required by the subcase. If the features or limits are not
* available a SkipTestCase exception will be thrown to skip the entire test case.
*/
selectDeviceOrSkipTestCase(descriptor: DeviceSelectionDescriptor): void {
mergeDeviceSelectionDescriptorIntoDeviceDescriptor(descriptor, this.skipIfRequirements);
}

/**
* Convenience function for {@link selectDeviceOrSkipTestCase}.
* Select a device with the features required by these texture format(s).
Expand Down Expand Up @@ -1310,7 +1343,7 @@ function getAdapterLimitsAsDeviceRequiredLimits(adapter: GPUAdapter) {
* t.skipIf(!(limit >= 2)); // Good. Skips if limits is not >= 2. undefined is not >= 2.
* ```
*/
function removeNonExistantLimits(adapter: GPUAdapter, limits: Record<string, GPUSize64>) {
function removeNonExistentLimits(adapter: GPUAdapter, limits: Record<string, GPUSize64>) {
const filteredLimits: Record<string, GPUSize64> = {};
const adapterLimits = adapter.limits as unknown as Record<string, GPUSize64>;
for (const [limit, value] of Object.entries(limits)) {
Expand All @@ -1330,7 +1363,7 @@ function applyLimitsToDescriptor(
requiredFeatures: [],
defaultQueue: {},
...desc,
requiredLimits: removeNonExistantLimits(adapter, getRequiredLimits(adapter)),
requiredLimits: removeNonExistentLimits(adapter, getRequiredLimits(adapter)),
};
return descWithMaxLimits;
}
Expand Down Expand Up @@ -1382,7 +1415,7 @@ export class RequiredLimitsGPUTestSubcaseBatchState extends GPUTestSubcaseBatchS
super(recorder, params);
this.requiredLimitsHelper = requiredLimitsHelper;
}
override selectDeviceOrSkipTestCase(
override requestDeviceWithRequiredParametersOrSkip(
descriptor: DeviceSelectionDescriptor,
descriptorModifier?: DescriptorModifier
): void {
Expand All @@ -1398,7 +1431,10 @@ export class RequiredLimitsGPUTestSubcaseBatchState extends GPUTestSubcaseBatchS
return `${baseKey}:${requiredLimitsHelper.key()}`;
},
};
super.selectDeviceOrSkipTestCase(initUncanonicalizedDeviceDescriptor(descriptor), mod);
super.requestDeviceWithRequiredParametersOrSkip(
initUncanonicalizedDeviceDescriptor(descriptor),
mod
);
}
}

Expand Down
8 changes: 0 additions & 8 deletions src/webgpu/util/device_pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,6 @@ class DescriptorToHolderMap {
export type UncanonicalizedDeviceDescriptor = {
requiredFeatures?: Iterable<GPUFeatureName>;
requiredLimits?: Record<string, GPUSize32>;
/** @deprecated this field cannot be used */
nonGuaranteedFeatures?: undefined;
/** @deprecated this field cannot be used */
nonGuaranteedLimits?: undefined;
/** @deprecated this field cannot be used */
extensions?: undefined;
/** @deprecated this field cannot be used */
features?: undefined;
};
export type CanonicalDeviceDescriptor = Omit<
Required<GPUDeviceDescriptor>,
Expand Down

0 comments on commit 7cdd7c8

Please sign in to comment.