diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index e7e3cee21c956..b207b804416aa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -37,13 +37,18 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto input_size = input_shape.size(); // WebNN Softmax only support 2d input shape, reshape input to 2d. if (input_size != 2) { - int32_t new_shape_0 = SafeInt(input_shape.data()[0]); - for (size_t i = 1; i < input_size - 1; i++) { - new_shape_0 *= input_shape.data()[i]; - } - emscripten::val new_shape = emscripten::val::array(); - new_shape.call("push", new_shape_0); - new_shape.call("push", static_cast(input_shape.back())); + NodeAttrHelper helper(node); + int32_t axis = helper.Get("axis", 1); + if (node.SinceVersion() >= 13) + // Opset 13 has default value -1. + axis = helper.Get("axis", -1); + // Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. + axis = static_cast(HandleNegativeAxis(axis, input_size)); + int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), + 1, std::multiplies())); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); input = model_builder.GetBuilder().call("reshape", input, new_shape); } output = model_builder.GetBuilder().call("softmax", input); @@ -76,9 +81,10 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return false; } NodeAttrHelper helper(node); - const int32_t axis = helper.Get("axis", 1); - // WebNN softmax only support input axis 1 - if (axis != 1 && axis != -1) { + const int64_t axis = helper.Get("axis", 1); + // WebNN softmax only support reshape for the last axis or version before 13. + // TODO: support opset 13 by composing into: Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1). + if (axis != -1 && axis != input_shape.size() - 1 && node.SinceVersion() >= 13) { LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis; return false; }