Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebNN: Support AllowSharedBufferSource for constant #49470

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions webnn/conformance_tests/shared_arraybuffer_constant.https.any.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// META: title=test WebNN API constant with shared array buffer
// META: global=window,dedicatedworker
// META: variant=?cpu
// META: variant=?gpu
// META: variant=?npu
// META: script=../resources/utils_validation.js
// META: script=../resources/utils.js
// META: timeout=long

'use strict';

// Skip tests if WebNN is unimplemented.
promise_setup(async () => {
assert_implements(navigator.ml, 'missing navigator.ml');
});

// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-constant-buffer

const testContents = Int32Array.from([0, 1, 2, 3, 4, 5, 6, 7]);
const sharedArrayBuffer = new SharedArrayBuffer(testContents.byteLength);
const typedArray = new Int32Array(sharedArrayBuffer);
typedArray.set(testContents);

let mlContext;
let mlGraph;
let outputTensor;
promise_setup(async () => {
try {
mlContext = await navigator.ml.createContext(contextOptions);
} catch (e) {
throw new AssertionError(
`Unable to create mlContext for ${variant} variant. ${e}`);
}

try {
outputTensor = await mlContext.createTensor({
dataType: 'int32',
shape: [8],
readable: true,
});
} catch (e) {
throw new AssertionError(
`Unable to create tensor for ${variant} variant. ${e}`);
}
});

promise_test(async () => {
const builder = new MLGraphBuilder(mlContext);
const constant =
builder.constant({dataType: 'int32', shape: [8]}, sharedArrayBuffer);
const output = builder.identity(constant);
const mlGraph = await builder.build({output});

mlContext.dispatch(mlGraph, {}, {output: outputTensor});
const results = new Int32Array(await mlContext.readTensor(outputTensor));

assert_array_equals(results, testContents);
}, `constant() with a SharedArrayBuffer`);

promise_test(async () => {
const builder = new MLGraphBuilder(mlContext);
const constant =
builder.constant({dataType: 'int32', shape: [8]}, typedArray);
const output = builder.identity(constant);
const mlGraph = await builder.build({output});

mlContext.dispatch(mlGraph, {}, {output: outputTensor});
const results = new Int32Array(await mlContext.readTensor(outputTensor));

assert_array_equals(results, testContents);
}, `constant() with a typeArray from a SharedArrayBuffer`);
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Cross-Origin-Embedder-Policy: require-corp
Cross-Origin-Opener-Policy: same-origin
98 changes: 68 additions & 30 deletions webnn/validation_tests/constant.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,125 +8,152 @@
'use strict';

const tests = [
// Tests for constant(descriptor, bufferView)
// Tests for constant(descriptor, buffer)
{
name:
'[constant] Test building a 0-D scalar constant with empty dimensions',
descriptor: {dataType: 'float32', shape: []},
bufferView: {type: Float32Array, byteLength: 1 * 4},
buffer: {type: Float32Array, byteLength: 1 * 4},
output: {dataType: 'float32', shape: []}
},
{
name: '[constant] Test building a constant with float32 data type',
descriptor: {dataType: 'float32', shape: [2, 3]},
bufferView: {type: Float32Array, byteLength: 6 * 4},
buffer: {type: Float32Array, byteLength: 6 * 4},
output: {dataType: 'float32', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for float32 doesn\'t match the given dimensions',
'[constant] Throw if byte length of float32 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'float32', shape: [2, 3]},
bufferView: {
buffer: {
type: Float32Array,
byteLength: 6 * 4 - 4 // The bufferView's byte length is less than the
byteLength: 6 * 4 - 4 // The buffer's byte length is less than the
// one by given dimensions
}
},
// TODO (crbug.com/329702838): Test building a constant with float16 data type
{
name: '[constant] Test building a constant with int32 data type',
descriptor: {dataType: 'int32', shape: [2, 3]},
bufferView: {type: Int32Array, byteLength: 6 * 4},
buffer: {type: Int32Array, byteLength: 6 * 4},
output: {dataType: 'int32', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for int32 doesn\'t match the given dimensions',
'[constant] Throw if byte length of int32 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'int32', shape: [2, 3]},
bufferView: {
buffer: {
type: Int32Array,
byteLength: 6 * 4 + 4 // The bufferView's byte length is greater than the
byteLength: 6 * 4 + 4 // The buffer's byte length is greater than the
// one by given dimensions
}
},
{
name: '[constant] Test building a constant with uint32 data type',
descriptor: {dataType: 'uint32', shape: [2, 3]},
bufferView: {type: Uint32Array, byteLength: 6 * 4},
buffer: {type: Uint32Array, byteLength: 6 * 4},
output: {dataType: 'uint32', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for uint32 doesn\'t match the given dimensions',
'[constant] Throw if byte length of uint32 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'uint32', shape: [2, 3]},
bufferView: {type: Uint32Array, byteLength: 6 * 4 + 4}
buffer: {type: Uint32Array, byteLength: 6 * 4 + 4}
},
{
name: '[constant] Test building a constant with int64 data type',
descriptor: {dataType: 'int64', shape: [2, 3]},
bufferView: {type: BigInt64Array, byteLength: 6 * 8},
buffer: {type: BigInt64Array, byteLength: 6 * 8},
output: {dataType: 'int64', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for int64 doesn\'t match the given dimensions',
'[constant] Throw if byte length of int64 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'int64', shape: [2, 3]},
bufferView: {type: BigInt64Array, byteLength: 6 * 8 + 8}
buffer: {type: BigInt64Array, byteLength: 6 * 8 + 8}
},
{
name: '[constant] Test building a constant with uint64 data type',
descriptor: {dataType: 'uint64', shape: [2, 3]},
bufferView: {type: BigUint64Array, byteLength: 6 * 8},
buffer: {type: BigUint64Array, byteLength: 6 * 8},
output: {dataType: 'uint64', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for uint64 doesn\'t match the given dimensions',
'[constant] Throw if byte length of uint64 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'uint64', shape: [2, 3]},
bufferView: {type: BigUint64Array, byteLength: 6 * 8 + 8}
buffer: {type: BigUint64Array, byteLength: 6 * 8 + 8}
},
{
name: '[constant] Test building a constant with int8 data type',
descriptor: {dataType: 'int8', shape: [2, 3]},
bufferView: {type: Int8Array, byteLength: 6 * 1},
buffer: {type: Int8Array, byteLength: 6 * 1},
output: {dataType: 'int8', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for int8 doesn\'t match the given dimensions',
'[constant] Throw if byte length of int8 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'int8', shape: [2, 3]},
bufferView: {type: Int8Array, byteLength: 6 * 4 - 4}
buffer: {type: Int8Array, byteLength: 6 * 4 - 4}
},
{
name: '[constant] Test building a constant with uint8 data type',
descriptor: {dataType: 'uint8', shape: [2, 3]},
bufferView: {type: Uint8Array, byteLength: 6 * 1},
buffer: {type: Uint8Array, byteLength: 6 * 1},
output: {dataType: 'uint8', shape: [2, 3]}
},
{
name:
'[constant] Throw if byte length of bufferView for uint8 doesn\'t match the given dimensions',
'[constant] Throw if byte length of uint8 buffer doesn\'t match the given dimensions',
descriptor: {dataType: 'uint8', shape: [2, 3]},
bufferView: {type: Uint8Array, byteLength: 6 * 4 - 4}
buffer: {type: Uint8Array, byteLength: 6 * 4 - 4}
},
{
name: '[constant] Throw if a dimension is 0',
descriptor: {dataType: 'float32', shape: [2, 0]},
bufferView: {type: Float32Array, byteLength: 2 * 4}
buffer: {type: Float32Array, byteLength: 2 * 4}
},
{
name:
'[constant] Throw if bufferView type doesn\'t match the operand data type',
'[constant] Throw if buffer view\'s type doesn\'t match the operand data type',
descriptor: {dataType: 'float32', shape: [2, 3]},
bufferView: {type: Int32Array, byteLength: 6 * 4}
buffer: {type: Int32Array, byteLength: 6 * 4},
viewTestOnly: true
}
];

tests.forEach(
test => promise_test(async t => {
const builder = new MLGraphBuilder(context);
const buffer = new ArrayBuffer(test.bufferView.byteLength);
const bufferView = new test.bufferView.type(buffer);
const buffer = new ArrayBuffer(test.buffer.byteLength);
const bufferView = new test.buffer.type(buffer);
const sharedBuffer = new SharedArrayBuffer(test.buffer.byteLength);
const sharedBufferView = new test.buffer.type(sharedBuffer);

if (test.viewTestOnly === undefined || test.viewTestOnly === false) {
// Test building constant from ArrayBuffer.
if (test.output) {
const constantOperand = builder.constant(test.descriptor, buffer);
assert_equals(constantOperand.dataType, test.output.dataType);
assert_array_equals(constantOperand.shape, test.output.shape);
} else {
assert_throws_js(
TypeError, () => builder.constant(test.descriptor, buffer));
}
// Test building constant from SharedArrayBuffer.
if (test.output) {
const constantOperand =
builder.constant(test.descriptor, sharedBuffer);
assert_equals(constantOperand.dataType, test.output.dataType);
assert_array_equals(constantOperand.shape, test.output.shape);
} else {
assert_throws_js(
TypeError, () => builder.constant(test.descriptor, sharedBuffer));
}
}

// Test building constant from ArrayBufferView.
if (test.output) {
const constantOperand = builder.constant(test.descriptor, bufferView);
assert_equals(constantOperand.dataType, test.output.dataType);
Expand All @@ -135,4 +162,15 @@ tests.forEach(
assert_throws_js(
TypeError, () => builder.constant(test.descriptor, bufferView));
}
// Test building constant from shared ArrayBufferView.
if (test.output) {
const constantOperand =
builder.constant(test.descriptor, sharedBufferView);
assert_equals(constantOperand.dataType, test.output.dataType);
assert_array_equals(constantOperand.shape, test.output.shape);
} else {
assert_throws_js(
TypeError,
() => builder.constant(test.descriptor, sharedBufferView));
}
}, test.name));
2 changes: 2 additions & 0 deletions webnn/validation_tests/constant.https.any.js.headers
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Cross-Origin-Embedder-Policy: require-corp
Cross-Origin-Opener-Policy: same-origin