Skip to content

Commit

Permalink
Fix compilation errors:
Browse files Browse the repository at this point in the history
 * replace base::ranges with std::ranges
 * add SupportedRanks for some ops
  • Loading branch information
shiyi9801 committed Feb 11, 2025
1 parent 44a7c7f commit a97dfb2
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 73 deletions.
3 changes: 2 additions & 1 deletion gpu/ipc/service/gpu_init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,8 @@ bool GpuInit::InitializeAndStartSandbox(base::CommandLine* command_line,
}

if (command_line->HasSwitch(switches::kUseRedistributableONNXRuntime)) {
base::LoadNativeLibrary(module_path.Append(L"onnxruntime.dll"), nullptr);
base::LoadNativeLibrary(module_path.Append(L"onnxruntime.dll"),
nullptr);
}
}

Expand Down
160 changes: 94 additions & 66 deletions services/webnn/ort/context_impl_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,85 +48,92 @@ ContextImplOrt::ContextImplOrt(

ContextImplOrt::~ContextImplOrt() = default;

// TODO(https://github.com/shiyi9801/chromium/issues/103): Investigate how to
// set the tensor byte length limit and supported tensor ranks
static constexpr uint64_t kTensorByteLengthLimit =
std::numeric_limits<int32_t>::max();

// static
ContextProperties ContextImplOrt::GetContextProperties() {
// TODO(https://github.com/shiyi9801/chromium/issues/103): Investigate how to
// set the tensor byte length limit and supported tensor ranks
static constexpr uint64_t kTensorByteLengthLimit =
std::numeric_limits<int32_t>::max();

static constexpr SupportedRanks kMaxRank = SupportedRanks::UpTo(8);
static constexpr SupportedRanks kNonScalarMaxRank =
SupportedRanks::NonScalarUpTo(8);

return ContextProperties(
InputOperandLayout::kNchw, Resample2DAxes::kChannelsFirst,
/*tensor_byte_length_limit=*/kTensorByteLengthLimit,
{/*input=*/SupportedDataTypes::All(),
/*constant=*/SupportedDataTypes::All(),
/*arg_min_max_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*arg_min_max_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kNonScalarMaxRank},
/*arg_min_max_output=*/DataTypeConstraint::kInt32To64,
/*batch_normalization_input=*/DataTypeConstraint::kFloat16To32,
/*cast_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*clamp_input=*/DataTypeConstraint::kFloat16To32,
/*cast_input=*/{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*clamp_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*concat_inputs=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*conv2d_input=*/DataTypeConstraint::kFloat16To32,
/*conv_transpose2d_input=*/DataTypeConstraint::kFloat16To32,
/*cumulative_sum_input=*/{},
/*dequantize_linear_input=*/{},
/*dequantize_linear_scale=*/{},
/*add_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*sub_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*mul_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*div_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*max_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*min_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*pow_input=*/
{DataTypeConstraint::kFloat16To32, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kFloat16To32, kMaxRank},
/*equal_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*greater_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*greater_or_equal_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*lesser_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*lesser_or_equal_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*not_equal_input=*/{},
/*logical_and_input=*/
{DataTypeConstraint::kUint8, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kUint8, kMaxRank},
/*logical_or_input=*/
{DataTypeConstraint::kUint8, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kUint8, kMaxRank},
/*logical_xor_input=*/
{DataTypeConstraint::kUint8, SupportedRanks::UpTo(8)},
/*logical_not_input=*/DataTypeConstraint::kUint8,
{DataTypeConstraint::kUint8, kMaxRank},
/*logical_not_input=*/{DataTypeConstraint::kUint8, kMaxRank},
/*logical_output=*/DataTypeConstraint::kUint8,
/*abs_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*ceil_input=*/DataTypeConstraint::kFloat16To32,
/*cos_input=*/DataTypeConstraint::kFloat16To32,
/*erf_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*exp_input=*/DataTypeConstraint::kFloat16To32,
/*floor_input=*/DataTypeConstraint::kFloat16To32,
/*identity_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*log_input=*/DataTypeConstraint::kFloat16To32,
/*neg_input=*/DataTypeConstraint::kFloat16To32Int8To64,
/*reciprocal_input=*/DataTypeConstraint::kFloat16To32,
/*sign_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*sin_input=*/DataTypeConstraint::kFloat16To32,
/*sqrt_input=*/DataTypeConstraint::kFloat16To32,
/*tan_input=*/DataTypeConstraint::kFloat16To32,
/*abs_input=*/{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*ceil_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*cos_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*erf_input=*/{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*exp_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*floor_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*identity_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*log_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*neg_input=*/{DataTypeConstraint::kFloat16To32Int8To64, kMaxRank},
/*reciprocal_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*sign_input=*/{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*sin_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*sqrt_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*tan_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*elu_input=*/{},
/*expand_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*expand_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*gather_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*gather_indices=*/DataTypeConstraint::kInt32To64,
/*gather_elements_input=*/{},
/*gather_elements_indices=*/{},
/*gather_nd_input=*/{},
/*gather_nd_indices=*/{},
/*gelu_input=*/DataTypeConstraint::kFloat16To32,
/*gelu_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*gemm_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*gru_input=*/{},
/*gru_cell_input=*/{},
Expand All @@ -139,46 +146,67 @@ ContextProperties ContextImplOrt::GetContextProperties() {
/*lstm_input=*/{},
/*lstm_cell_input=*/{},
/*matmul_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, SupportedRanks::UpTo(8)},
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
// TODO: Support more data types including int4.
// https://github.com/shiyi9801/chromium/issues/85
/*pad_input=*/DataTypeConstraint::kFloat16To32,
/*average_pool2d_input=*/DataTypeConstraint::kFloat16To32,
/*l2_pool2d_input=*/DataTypeConstraint::kFloat16To32,
/*max_pool2d_input=*/DataTypeConstraint::kFloat16To32,
/*pad_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*average_pool2d_input=*/
{DataTypeConstraint::kFloat16To32, kNonScalarMaxRank},
/*l2_pool2d_input=*/
{DataTypeConstraint::kFloat16To32, kNonScalarMaxRank},
/*max_pool2d_input=*/
{DataTypeConstraint::kFloat16To32, kNonScalarMaxRank},
/*prelu_input=*/{},
/*quantize_linear_input=*/{},
/*quantize_linear_zero_point=*/{},
/*reduce_l1_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*reduce_l2_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*reduce_log_sum_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*reduce_log_sum_exp_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*reduce_max_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*reduce_mean_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*reduce_min_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*reduce_product_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*reduce_sum_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*reduce_sum_square_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
/*relu_input=*/DataTypeConstraint::kFloat16To32Int8To32,
/*resample2d_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*reshape_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*reduce_l1_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*reduce_l2_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*reduce_log_sum_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*reduce_log_sum_exp_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*reduce_max_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*reduce_mean_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*reduce_min_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*reduce_product_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*reduce_sum_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*reduce_sum_square_input=*/
{DataTypeConstraint::kFloat16To32Ints32To64, kMaxRank},
/*relu_input=*/{DataTypeConstraint::kFloat16To32Int8To32, kMaxRank},
/*resample2d_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*reshape_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*reverse_input=*/{},
/*scatter_elements_input=*/{},
/*scatter_elements_indices=*/{},
/*scatter_nd_input=*/{},
/*scatter_nd_indices=*/{},
/*sigmoid_input=*/DataTypeConstraint::kFloat16To32,
/*slice_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*softmax_input=*/DataTypeConstraint::kFloat16To32,
/*scatter_nd_updates=*/{},
/*sigmoid_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*slice_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*softmax_input=*/{DataTypeConstraint::kFloat16To32, kMaxRank},
/*softplus_input=*/{},
/*softsign_input=*/{},
/*split_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*split_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*tanh_input=*/{},
/*tile_input=*/{},
/*transpose_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*triangular_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
/*where_condition=*/DataTypeConstraint::kUint8,
/*where_value=*/DataTypeConstraint::kAllDataTypesAtLeast8bits});
/*transpose_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*triangular_input=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank},
/*where_condition=*/{DataTypeConstraint::kUint8, kMaxRank},
/*where_value=*/
{DataTypeConstraint::kAllDataTypesAtLeast8bits, kMaxRank}});
}

base::WeakPtr<WebNNContextImpl> ContextImplOrt::AsWeakPtr() {
Expand Down
12 changes: 6 additions & 6 deletions services/webnn/ort/graph_builder_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ GraphBuilderOrt::AddExpandOperation(const mojom::Expand& expand) {
std::vector<uint32_t> shape_dims = {
base::checked_cast<uint32_t>(output_shape.size())};
std::vector<int64_t> shape_values;
base::ranges::transform(
std::ranges::transform(
output_shape, std::back_inserter(shape_values),
[](uint32_t dim) { return static_cast<int64_t>(dim); });
ASSIGN_OR_RETURN(const std::string shape_name,
Expand Down Expand Up @@ -1168,7 +1168,7 @@ GraphBuilderOrt::AddLayerNormalizationOperation(
}

// TODO: crbug.com/356905058: Figure out if unordered axes should be allowed.
if (!base::ranges::is_sorted(axes)) {
if (!std::ranges::is_sorted(axes)) {
return NewNotSupportedError("Axes must be ordered for layerNormalization.");
}
const auto axes_size = axes.size();
Expand All @@ -1194,7 +1194,7 @@ GraphBuilderOrt::AddLayerNormalizationOperation(

std::vector<uint32_t> scale_dims;
scale_dims.reserve(axes_size);
base::ranges::transform(
std::ranges::transform(
axes, std::back_inserter(scale_dims),
[&input_shape](uint32_t axis) { return input_shape[axis]; });

Expand Down Expand Up @@ -1273,10 +1273,10 @@ GraphBuilderOrt::AddPadOperation(const mojom::Pad& pad) {
// paddings is an operand with data type int64, not an attribute.
std::vector<int64_t> paddings;
paddings.reserve(padding_length);
base::ranges::transform(
std::ranges::transform(
pad.beginning_padding, std::back_inserter(paddings),
[](uint32_t value) { return base::checked_cast<int64_t>(value); });
base::ranges::transform(
std::ranges::transform(
pad.ending_padding, std::back_inserter(paddings),
[](uint32_t value) { return base::checked_cast<int64_t>(value); });

Expand Down Expand Up @@ -1542,7 +1542,7 @@ GraphBuilderOrt::AddReshapeOperation(const mojom::Reshape& reshape) {
std::vector<uint32_t> shape_dims = {
base::checked_cast<uint32_t>(output_shape.size())};
std::vector<int64_t> shape_values;
base::ranges::transform(
std::ranges::transform(
output_shape, std::back_inserter(shape_values),
[](uint32_t dim) { return static_cast<int64_t>(dim); });
ASSIGN_OR_RETURN(const std::string shape_name,
Expand Down

0 comments on commit a97dfb2

Please sign in to comment.