diff --git a/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js b/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js index 18aebf98b4c1..9fac08f7beb5 100644 --- a/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js +++ b/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js @@ -46,6 +46,40 @@ const scatterElementsTests = [ } } }, + { + 'name': 'Scatter elements along axis 0 and constant indices', + 'graph': { + 'inputs': { + 'input': { + 'data': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + 'descriptor': {shape: [3, 3], dataType: 'float32'} + }, + 'indices': { + 'data': [1, 0, 2, 0, 2, 1], + 'descriptor': {shape: [2, 3], dataType: 'int32'}, + 'constant': true + }, + 'updates': { + 'data': [1.0, 1.1, 1.2, 2.0, 2.1, 2.2], + 'descriptor': {shape: [2, 3], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'scatterElements', + 'arguments': [ + {'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, + {'options': {'axis': 0}} + ], + 'outputs': 'output' + }], + 'expectedOutputs': { + 'output': { + 'data': [2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2], + 'descriptor': {shape: [3, 3], dataType: 'float32'} + } + } + } + }, { 'name': 'Scatter elements along axis 1', 'graph': { @@ -78,6 +112,40 @@ const scatterElementsTests = [ } } } + }, + { + 'name': 'Scatter elements along axis 1 and constant indices', + 'graph': { + 'inputs': { + 'input': { + 'data': [1.0, 2.0, 3.0, 4.0, 5.0], + 'descriptor': {shape: [1, 5], dataType: 'float32'} + }, + 'indices': { + 'data': [1, 3], + 'descriptor': {shape: [1, 2], dataType: 'int32'}, + 'constant': true + }, + 'updates': { + 'data': [1.1, 2.1], + 'descriptor': {shape: [1, 2], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'scatterElements', + 'arguments': [ + {'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, + {'options': {'axis': 1}} + ], + 'outputs': 'output' + }], + 'expectedOutputs': { + 'output': { + 'data': [1.0, 1.1, 3.0, 2.1, 5.0], + 'descriptor': {shape: [1, 5], dataType: 'float32'} + } + } + } } ];