Skip to content

Commit

Permalink
[WebNN EP] Fix bug in Softmax (microsoft#17665)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
For now, WebNN Softmax only support 2D (or implicitly coerce to 2D)
inputs and the last axis.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Fallback some cases to pass the CI.
  • Loading branch information
zesongw authored Sep 26, 2023
1 parent 614af37 commit 93f22aa
Showing 1 changed file with 16 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,18 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto input_size = input_shape.size();
// WebNN Softmax only support 2d input shape, reshape input to 2d.
if (input_size != 2) {
int32_t new_shape_0 = SafeInt<int32_t>(input_shape.data()[0]);
for (size_t i = 1; i < input_size - 1; i++) {
new_shape_0 *= input_shape.data()[i];
}
emscripten::val new_shape = emscripten::val::array();
new_shape.call<void>("push", new_shape_0);
new_shape.call<void>("push", static_cast<int32_t>(input_shape.back()));
NodeAttrHelper helper(node);
int32_t axis = helper.Get("axis", 1);
if (node.SinceVersion() >= 13)
// Opset 13 has default value -1.
axis = helper.Get("axis", -1);
// Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}].
axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_size));
int32_t first_dim = static_cast<int32_t>(std::reduce(input_shape.begin(), input_shape.begin() + axis,
1, std::multiplies<int64_t>()));
int32_t second_dim = static_cast<int32_t>(std::reduce(input_shape.begin() + axis, input_shape.end(),
1, std::multiplies<int64_t>()));
emscripten::val new_shape = emscripten::val::array(std::vector<int32_t>{first_dim, second_dim});
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, new_shape);
}
output = model_builder.GetBuilder().call<emscripten::val>("softmax", input);
Expand Down Expand Up @@ -76,9 +81,10 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
return false;
}
NodeAttrHelper helper(node);
const int32_t axis = helper.Get("axis", 1);
// WebNN softmax only support input axis 1
if (axis != 1 && axis != -1) {
const int64_t axis = helper.Get("axis", 1);
// WebNN softmax only support reshape for the last axis or version before 13.
// TODO: support opset 13 by composing into: Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1).
if (axis != -1 && axis != input_shape.size() - 1 && node.SinceVersion() >= 13) {
LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis;
return false;
}
Expand Down

0 comments on commit 93f22aa

Please sign in to comment.