Skip to content

Commit

Permalink
webnn: Enforce input data type constraints for gemm and matmul
Browse files Browse the repository at this point in the history
As specified in webmachinelearning/webnn#646

Bug: 328567884
Change-Id: Ia55a214e7ad281ec3c8911e9116f388fac209d05
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5495161
Auto-Submit: Shiyi Zou <[email protected]>
Commit-Queue: Shiyi Zou <[email protected]>
Reviewed-by: Austin Sullivan <[email protected]>
Reviewed-by: ningxin hu <[email protected]>
Cr-Commit-Position: refs/heads/main@{#1294123}
  • Loading branch information
shiyi9801 authored and chromium-wpt-export-bot committed Apr 30, 2024
1 parent ce9bd55 commit cc4e41f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
33 changes: 19 additions & 14 deletions webnn/validation_tests/gemm.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const tests = [
b: {dataType: 'float32', dimensions: [2, 4]},
},
{
name: 'Test building gemm with aTranspose=true.',
name: '[gemm] Test building gemm with aTranspose=true.',
a: {dataType: 'float32', dimensions: [2, 3]},
b: {dataType: 'float32', dimensions: [2, 4]},
options: {
Expand All @@ -44,15 +44,15 @@ const tests = [
},
{
name:
'Throw if inputShapeA[0] is not equal to inputShapeB[0] with aTranspose=true.',
'[gemm] Throw if inputShapeA[0] is not equal to inputShapeB[0] with aTranspose=true.',
a: {dataType: 'float32', dimensions: [2, 3]},
b: {dataType: 'float32', dimensions: [3, 4]},
options: {
aTranspose: true,
},
},
{
name: 'Test building gemm with bTranspose=true.',
name: '[gemm] Test building gemm with bTranspose=true.',
a: {dataType: 'float32', dimensions: [2, 3]},
b: {dataType: 'float32', dimensions: [4, 3]},
options: {
Expand All @@ -62,30 +62,30 @@ const tests = [
},
{
name:
'Throw if inputShapeA[0] is not equal to inputShapeB[0] with bTranspose=true.',
'[gemm] Throw if inputShapeA[0] is not equal to inputShapeB[0] with bTranspose=true.',
a: {dataType: 'float32', dimensions: [2, 3]},
b: {dataType: 'float32', dimensions: [3, 4]},
options: {
bTranspose: true,
},
},
{
name: 'Throw if the rank of inputA is not 2.',
name: '[gemm] Throw if the rank of inputA is not 2.',
a: {dataType: 'float32', dimensions: [2, 3, 1]},
b: {dataType: 'float32', dimensions: [2, 4]},
},
{
name: 'Throw if the rank of inputB is not 2.',
name: '[gemm] Throw if the rank of inputB is not 2.',
a: {dataType: 'float32', dimensions: [2, 4]},
b: {dataType: 'float32', dimensions: [2, 3, 1]},
},
{
name: 'Throw if data types of two inputs do not match.',
name: '[gemm] Throw if data types of two inputs do not match.',
a: {dataType: 'float32', dimensions: [2, 3]},
b: {dataType: 'int32', dimensions: [3, 4]},
b: {dataType: 'float16', dimensions: [3, 4]},
},
{
name: 'Test building gemm with inputC.',
name: '[gemm] Test building gemm with inputC.',
a: {dataType: 'float32', dimensions: [2, 3]},
b: {dataType: 'float32', dimensions: [3, 4]},
options: {
Expand All @@ -94,7 +94,7 @@ const tests = [
output: {dataType: 'float32', dimensions: [2, 4]}
},
{
name: 'Test building gemm with scalar inputC.',
name: '[gemm] Test building gemm with scalar inputC.',
a: {dataType: 'float32', dimensions: [2, 3]},
b: {dataType: 'float32', dimensions: [3, 4]},
options: {
Expand All @@ -104,26 +104,31 @@ const tests = [
},
{
name:
'Throw if inputShapeC is not unidirectionally broadcastable to the output shape [inputShapeA[0], inputShapeB[1]].',
'[gemm] Throw if inputShapeC is not unidirectionally broadcastable to the output shape [inputShapeA[0], inputShapeB[1]].',
a: {dataType: 'float32', dimensions: [2, 3]},
b: {dataType: 'float32', dimensions: [3, 4]},
options: {
c: {dataType: 'float32', dimensions: [2, 3]},
},
},
{
name: '[gemm] Throw if the input data type is not floating point.',
a: {dataType: 'int32', dimensions: [2, 3]},
b: {dataType: 'int32', dimensions: [3, 4]}
},
{
name:
'Throw if data type of inputC does not match ones of inputA and inputB.',
'[gemm] Throw if data type of inputC does not match ones of inputA and inputB.',
a: {dataType: 'float32', dimensions: [3, 2]},
b: {dataType: 'float32', dimensions: [4, 3]},
options: {
c: {dataType: 'int32', dimensions: [2, 4]},
c: {dataType: 'float16', dimensions: [2, 4]},
aTranspose: true,
bTranspose: true,
},
},
{
name: 'Throw if the rank of inputC is 3.',
name: '[gemm] Throw if the rank of inputC is 3.',
a: {dataType: 'float32', dimensions: [3, 2]},
b: {dataType: 'float32', dimensions: [4, 3]},
options: {
Expand Down
9 changes: 8 additions & 1 deletion webnn/validation_tests/matmul.https.any.js
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,18 @@ const tests = [
},
output: {dataType: 'float32', dimensions: [2, 3, 5]}
},
{
name: '[matmul] Throw if the input data type is not floating point',
inputs: {
a: {dataType: 'uint32', dimensions: [2, 3, 4]},
b: {dataType: 'uint32', dimensions: [2, 4, 5]}
}
},
{
name: '[matmul] Throw if data type of two inputs don\'t match',
inputs: {
a: {dataType: 'float32', dimensions: [2, 3, 4]},
b: {dataType: 'int32', dimensions: [2, 4, 5]}
b: {dataType: 'float16', dimensions: [2, 4, 5]}
}
},
{
Expand Down

0 comments on commit cc4e41f

Please sign in to comment.