diff --git a/src/webgpu/util/check_contents.ts b/src/webgpu/util/check_contents.ts index 5e0d2dfcebdf..298e7ae4a9e9 100644 --- a/src/webgpu/util/check_contents.ts +++ b/src/webgpu/util/check_contents.ts @@ -44,7 +44,28 @@ export function checkElementsEqual( ): ErrorWithExtra | undefined { assert(actual.constructor === expected.constructor, 'TypedArray type mismatch'); assert(actual.length === expected.length, 'size mismatch'); - return checkElementsEqualGenerated(actual, i => expected[i]); + + let failedElementsFirstMaybe: number | undefined = undefined; + /** Sparse array with `true` for elements that failed. */ + const failedElements: (true | undefined)[] = []; + for (let i = 0; i < actual.length; ++i) { + if (actual[i] !== expected[i]) { + failedElementsFirstMaybe ??= i; + failedElements[i] = true; + } + } + + if (failedElementsFirstMaybe === undefined) { + return undefined; + } + + const failedElementsFirst = failedElementsFirstMaybe; + return failCheckElements({ + actual, + failedElements, + failedElementsFirst, + predicatePrinter: [{ leftHeader: 'expected ==', getValueForCell: index => expected[index] }], + }); } /** @@ -117,11 +138,29 @@ export function checkElementsEqualGenerated( actual: TypedArrayBufferView, generator: CheckElementsGenerator ): ErrorWithExtra | undefined { - const error = checkElementsPassPredicate(actual, (index, value) => value === generator(index), { + let failedElementsFirstMaybe: number | undefined = undefined; + /** Sparse array with `true` for elements that failed. */ + const failedElements: (true | undefined)[] = []; + for (let i = 0; i < actual.length; ++i) { + if (actual[i] !== generator(i)) { + failedElementsFirstMaybe ??= i; + failedElements[i] = true; + } + } + + if (failedElementsFirstMaybe === undefined) { + return undefined; + } + + const failedElementsFirst = failedElementsFirstMaybe; + const error = failCheckElements({ + actual, + failedElements, + failedElementsFirst, predicatePrinter: [{ leftHeader: 'expected ==', getValueForCell: index => generator(index) }], }); - // If there was an error, extend it with additional extras. - return error ? new ErrorWithExtra(error, () => ({ generator })) : undefined; + // Add more extras to the error. + return new ErrorWithExtra(error, () => ({ generator })); } /** @@ -133,14 +172,10 @@ export function checkElementsPassPredicate( predicate: CheckElementsPredicate, { predicatePrinter }: { predicatePrinter?: CheckElementsSupplementalTableRows } ): ErrorWithExtra | undefined { - const size = actual.length; - const ctor = actual.constructor as TypedArrayBufferViewConstructor; - const printAsFloat = ctor === Float16Array || ctor === Float32Array || ctor === Float64Array; - let failedElementsFirstMaybe: number | undefined = undefined; /** Sparse array with `true` for elements that failed. */ const failedElements: (true | undefined)[] = []; - for (let i = 0; i < size; ++i) { + for (let i = 0; i < actual.length; ++i) { if (!predicate(i, actual[i])) { failedElementsFirstMaybe ??= i; failedElements[i] = true; @@ -150,7 +185,35 @@ export function checkElementsPassPredicate( if (failedElementsFirstMaybe === undefined) { return undefined; } + const failedElementsFirst = failedElementsFirstMaybe; + return failCheckElements({ actual, failedElements, failedElementsFirst, predicatePrinter }); +} + +interface CheckElementsFailOpts { + actual: TypedArrayBufferView; + failedElements: (true | undefined)[]; + failedElementsFirst: number; + predicatePrinter?: CheckElementsSupplementalTableRows; +} + +/** + * Implements the failure case of some checkElementsX helpers above. This allows those functions to + * implement their checks directly without too many function indirections in between. + * + * Note: Separating this into its own function significantly speeds up the non-error case in + * Chromium (though this may be V8-specific behavior). + */ +function failCheckElements({ + actual, + failedElements, + failedElementsFirst, + predicatePrinter, +}: CheckElementsFailOpts): ErrorWithExtra { + const size = actual.length; + const ctor = actual.constructor as TypedArrayBufferViewConstructor; + const printAsFloat = ctor === Float16Array || ctor === Float32Array || ctor === Float64Array; + const failedElementsLast = failedElements.length - 1; // Include one extra non-failed element at the beginning and end (if they exist), for context.