Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use AllowSharedBufferSource for MLGraphBuilder.constant() #790

Merged
merged 3 commits into from
Nov 28, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 11 additions & 23 deletions index.bs
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,8 @@ interface MLGraphBuilder {
MLOperand input(USVString name, MLOperandDescriptor descriptor);

// Create an operand for a graph constant.
MLOperand constant(MLOperandDescriptor descriptor, ArrayBufferView bufferView);
MLOperand constant(MLOperandDescriptor descriptor,
[AllowShared] ArrayBufferView bufferView);

// Create a scalar operand from the specified number of the specified type.
MLOperand constant(MLOperandDataType type, MLNumber value);
Expand Down Expand Up @@ -2012,8 +2013,7 @@ partial dictionary MLOpSupportLimits {
input, builder.constant(input.dataType, options.minValue));
} else {
return builder.min(
builder.max(
input, builder.constant(input.dataType, options.minValue)),
builder.max(input, builder.constant(input.dataType, options.minValue)),
builder.constant(input.dataType, options.maxValue));
}
}
Expand Down Expand Up @@ -3421,8 +3421,8 @@ partial dictionary MLOpSupportLimits {
{shape: [4, 3]},
new Float32Array([0, 1, 2, 10, 11, 12, 20, 21, 22, 30, 31, 32]));

const indices1 = builder.constant(
{dataType: 'uint32', shape: [2]}, new Uint32Array([3, 1]));
const indices1 =
builder.constant({dataType: 'uint32', shape: [2]}, new Uint32Array([3, 1]));

const indices2 = builder.constant(
{dataType: 'uint32', shape: [3]}, new Uint32Array([2, 1, 1]));
Expand Down Expand Up @@ -3937,10 +3937,7 @@ partial dictionary MLOpSupportLimits {
let hiddenState = options.initialHiddenState;

if (!hiddenState) {
const desc = {
dataType: 'float32',
shape: [numDirections, 1, hiddenSize]
};
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
const totalSize = numDirections * hiddenSize;
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
}
Expand Down Expand Up @@ -4619,8 +4616,7 @@ partial dictionary MLOpSupportLimits {
const reduceOptions = {axes: [2, 3], keepDimensions: true};
const mean = builder.reduceMean(input, reduceOptions);
const variance = builder.reduceMean(
builder.pow(
builder.sub(input, mean), builder.constant(input.dataType, 2)),
builder.pow(builder.sub(input, mean), builder.constant(input.dataType, 2)),
reduceOptions);

// The scale and bias values are applied per input feature
Expand Down Expand Up @@ -4765,8 +4761,7 @@ partial dictionary MLOpSupportLimits {
const reduceOptions = {axes: [1, 2, 3], keepDimensions: true};
const mean = builder.reduceMean(input, reduceOptions);
const variance = builder.reduceMean(
builder.pow(
builder.sub(input, mean), builder.constant(input.dataType, 2)),
builder.pow(builder.sub(input, mean), builder.constant(input.dataType, 2)),
reduceOptions);

// The scale and bias tensors are of the shape of the input
Expand Down Expand Up @@ -5222,19 +5217,13 @@ partial dictionary MLOpSupportLimits {
let cellState = options.initialCellState;

if (!hiddenState) {
const desc = {
dataType: 'float32',
shape: [numDirections, 1, hiddenSize]
};
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
const totalSize = numDirections * hiddenSize;
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
}

if (!cellState) {
const desc = {
dataType: 'float32',
shape: [numDirections, 1, hiddenSize]
};
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
const totalSize = numDirections * hiddenSize;
cellState = builder.constant(desc, new Float32Array(totalSize).fill(0));
}
Expand Down Expand Up @@ -5878,8 +5867,7 @@ partial dictionary MLOpSupportLimits {
<pre highlight="js">
// input: [[1,2,3], [4,5,6]]
const input = builder.constant(
{dataType: 'float32', shape: [2, 3]},
new Float32Array([1, 2, 3, 4, 5, 6]));
{dataType: 'float32', shape: [2, 3]}, new Float32Array([1, 2, 3, 4, 5, 6]));

const beginningPadding = [1, 2];
const endingPadding = [1, 2];
Expand Down
Loading