Skip to content

Commit 1d97d6e

Browse files
authored
[webgpu] fix Split operator implementation when input is 1D (microsoft#23376)
### Description [webgpu] fix Split operator implementation when input is 1D
1 parent e51bcfb commit 1d97d6e

File tree

1 file changed

+2
-2
lines changed
  • onnxruntime/core/providers/webgpu/tensor

1 file changed

+2
-2
lines changed

onnxruntime/core/providers/webgpu/tensor/split.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@ Status SplitProgram::GenerateShaderCode(ShaderHelper& shader) const {
6565

6666
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size")
6767
<< " var indices = " << input.OffsetToIndices("global_idx") << ";\n"
68-
<< " var index = indices[" << axis_ << "];\n"
68+
<< " var index = " << input.IndicesGet("indices", axis_) << ";\n"
6969
<< " let output_number = calculate_output_index(index);\n"
7070
<< " if (output_number != 0u) {\n"
7171
<< " index -= uniforms.sizes_in_split_axis[output_number - 1u];\n"
72-
<< " indices[" << axis_ << "] = index;\n"
72+
<< " " << input.IndicesSet("indices", axis_, "index") << "\n"
7373
<< " }\n"
7474
<< " write_buffer_data(output_number, global_idx, indices);\n";
7575

0 commit comments

Comments
 (0)