Skip to content

Commit

Permalink
Rename the operands to "label_id" to ensure the names are unique (#111)
Browse files Browse the repository at this point in the history
Fix #63

This PR renames the operands to make sure each name is unique, refactors
`ComputeResources` for using new operands names.
  • Loading branch information
shiyi9801 authored Jan 22, 2025
1 parent a75cbe6 commit 17d8c1d
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 210 deletions.
37 changes: 12 additions & 25 deletions services/webnn/ort/graph_builder_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ constexpr char kOpTypeTranspose[] = "Transpose";
constexpr char kOpTypeTriangular[] = "Trilu";
constexpr char kOpTypeWhere[] = "Where";

constexpr char kInserted[] = "Inserted";
constexpr char kUnderscore[] = "_";

base::unexpected<mojom::ErrorPtr> NewNotSupportedError(std::string message) {
return base::unexpected(mojom::Error::New(
mojom::Error::Code::kNotSupportedError, std::move(message)));
Expand Down Expand Up @@ -182,6 +185,10 @@ struct TensorTypeMap<int64_t> {

} // namespace

std::string GetOperandName(std::string_view label, uint64_t id) {
return base::JoinString({label, base::NumberToString(id)}, kUnderscore);
}

// static
base::expected<std::unique_ptr<OrtModelBuilder::ModelInfo>, mojom::ErrorPtr>
GraphBuilderOrt::CreateAndBuild(
Expand Down Expand Up @@ -215,40 +222,20 @@ const mojom::Operand& GraphBuilderOrt::GetOperand(uint64_t operand_id) {
return *graph_info_->id_to_operand_map.at(operand_id);
}

// TODO(https://github.com/shiyi9801/chromium/issues/63): Make name generation
// more robust.
std::string GraphBuilderOrt::GetOperandNameById(uint64_t operand_id) {
const mojom::Operand& operand = GetOperand(operand_id);
switch (operand.kind) {
case mojom::Operand::Kind::kInput: {
CHECK(operand.name.has_value());
// Add a prefix to avoid possible name collision.
return operand.name.value();
// return base::JoinString({"input", operand.name.value()}, "_");
}
case mojom::Operand::Kind::kConstant: {
// It's okay to use operand id as name directly since operand id is
// guaranteed to be unique.
return base::NumberToString(operand_id);
}
case mojom::Operand::Kind::kOutput: {
if (operand.name.has_value()) {
return operand.name.value();
// return base::JoinString({"output", operand.name.value()}, "_");
} else {
return base::NumberToString(operand_id);
}
}
}
std::string operand_label =
operand.name.has_value() ? operand.name.value() : "";
return GetOperandName(operand_label, operand_id);
}

std::string GraphBuilderOrt::GenerateNextOperandName() {
return base::NumberToString(next_operand_id_++);
return GetOperandName(kInserted, next_operand_id_++);
}

std::string GraphBuilderOrt::GenerateNextOperationName(std::string_view label) {
return base::JoinString({label, base::NumberToString(next_operation_id_++)},
"_");
kUnderscore);
}

template <typename DataType>
Expand Down
13 changes: 6 additions & 7 deletions services/webnn/ort/graph_builder_ort.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ concept IsSupportedTensorType = IsAnyOf<T, float, uint16_t, int64_t>;

} // namespace internal

// The returned operand name has a format of "label_id". Adding operand id at
// the end ensures that the name is unique.
std::string GetOperandName(std::string_view label, uint64_t id);

class GraphBuilderOrt {
STACK_ALLOCATED();

Expand Down Expand Up @@ -76,13 +80,8 @@ class GraphBuilderOrt {
// Get the unique name of an existing operand by its id.
std::string GetOperandNameById(uint64_t operand_id);

// Generate the unique name of a newly created operand using the
// `next_operand_id_`, and then increase the `next_operand_id_`.
// TODO(https://github.com/shiyi9801/chromium/issues/63): Make name generation
// more robust. The newly created operands should also have a unique id, so
// here they're named by their ids for now. However, it is still possible to
// have names that are the same as the graph's inputs/outputs provided by
// users. ONNX doesn't allow duplicate operand names.
// Generate the unique name of a newly created operand by combining a prefix
// "inserted" and `next_operand_id_`, and then increase `next_operand_id_`.
std::string GenerateNextOperandName();

// Generate the unique name of a newly created operation by combining the
Expand Down
Loading

0 comments on commit 17d8c1d

Please sign in to comment.