
Commit 09c9880

[ONNX] Add OnnxToTorch lowering for Onnx.NegativeLogLikelihoodLoss Op (#3380)
This implements the Onnx.NegativeLogLikelihoodLoss op, following the signature documented [here](https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html), by lowering it to a `NLLLossForward` op. Additionally, I included a helper function `get_loss_reduction_enum` that converts the string `reduction` parameter to the corresponding integer enum value, since this conversion will be reused by any loss-function lowering. It differs from `get_reduction_enum` in `TorchUpstream.cpp`, which handles the `reduce` parameter of `scatter_reduce`-type operations.
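For intuition, here is a minimal standalone sketch of the mapping the new helper performs (illustrative only: the name `lossReductionEnum` and the `std::string` signature are simplifications; the real helper takes an `llvm::StringRef` and lives in `TorchUpstream.cpp`):

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Loss-style `reduction` values; as an unscoped enum these compare equal
// to the integers the lowering materializes: None = 0, Mean = 1, Sum = 2.
enum Reduction { None, Mean, Sum, END };

Reduction lossReductionEnum(const std::string &reduce) {
  if (reduce == "none") return None;
  if (reduce == "mean") return Mean;
  if (reduce == "sum") return Sum;
  if (reduce == "end") return END;
  throw std::invalid_argument("'reduction' must be none, mean, sum or end");
}

int main() {
  // ONNX defaults `reduction` to "mean"; the lowering turns the mapped
  // integer into a torch.constant.int operand of nll_loss_forward.
  assert(lossReductionEnum("mean") == 1);
  assert(lossReductionEnum("sum") == 2);
  assert(lossReductionEnum("none") == 0);
  return 0;
}
```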
1 parent 2ea2bc3 commit 09c9880

File tree (4 files changed: +101 -0 lines changed)

include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
lib/Dialect/Torch/Utils/TorchUpstream.cpp
test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h

Lines changed: 2 additions & 0 deletions
@@ -145,6 +145,8 @@ ScalarType promote_skip_undefined(ScalarType a, ScalarType b);
 //===----------------------------------------------------------------------===//
 enum Reduction { None, Mean, Sum, END };
 
+Reduction get_loss_reduction_enum(const llvm::StringRef &reduce);
+
 //===----------------------------------------------------------------------===//
 // Possible values for `memory_format` argument in PyTorch ops that support it.
 // Source:

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 39 additions & 0 deletions
@@ -435,6 +435,45 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, lhs, rhs);
         return success();
       });
+  patterns.onOp(
+      "NegativeLogLikelihoodLoss", 13,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value self, target, weight, reduction, ignore_index;
+        int64_t ignore_index_int;
+        std::string reduction_str;
+
+        if (binder.tensorOperandAtIndex(self, 0) ||
+            binder.tensorOperandAtIndex(target, 1) ||
+            binder.s64IntegerAttr(ignore_index_int, "ignore_index", -100) ||
+            binder.customOpNameStringAttr(reduction_str, "reduction", "mean") ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        // Optional third tensor operand (weight); default to none when absent.
+        if (binder.tensorOperandAtIndex(weight, 2)) {
+          weight = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        }
+
+        ignore_index = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(ignore_index_int));
+
+        // Convert the string reduction attr to its standardized integer enum value.
+        int reduction_value =
+            torch_upstream::get_loss_reduction_enum(reduction_str);
+        reduction = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(reduction_value));
+
+        Value nllLoss = rewriter
+                            .create<Torch::AtenNllLossForwardOp>(
+                                binder.getLoc(), resultType, resultType, self,
+                                target, weight, reduction, ignore_index)
+                            ->getResult(0);
+
+        rewriter.replaceOp(binder.op, nllLoss);
+        return success();
+      });
   patterns.onOp("NonZero", 13,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
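Note that `torch.aten.nll_loss_forward` has two results, the reduced loss and the total weight of the non-ignored targets, which is why the pattern supplies `resultType` twice and keeps only `getResult(0)`. As a rough, self-contained sketch of the op's semantics (plain C++ with arbitrary values, not torch-mlir code), `mean` reduction over a 2-D input behaves like:

```cpp
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// input: N x C log-probabilities; target: N class indices.
// Returns {output, total_weight}, mirroring the op's two results.
std::pair<double, double>
nllLossForwardMean(const std::vector<std::vector<double>> &input,
                   const std::vector<long> &target,
                   const std::vector<double> &weight, long ignoreIndex) {
  double lossSum = 0.0, totalWeight = 0.0;
  for (std::size_t i = 0; i < target.size(); ++i) {
    long t = target[i];
    if (t == ignoreIndex)
      continue; // ignored targets contribute neither loss nor weight
    double w = weight.empty() ? 1.0 : weight[t];
    lossSum += -w * input[i][t];
    totalWeight += w;
  }
  // Mean reduction divides by the total weight of the kept targets.
  return {totalWeight > 0 ? lossSum / totalWeight : 0.0, totalWeight};
}

int main() {
  std::vector<std::vector<double>> logProbs = {
      {-0.2, -1.7}, {-2.3, -0.1}, {-0.9, -0.5}};
  std::vector<long> target = {0, 1, -100}; // last sample is ignored
  auto [loss, tw] = nllLossForwardMean(logProbs, target, {}, -100);
  std::printf("loss=%f total_weight=%f\n", loss, tw); // loss=0.15, tw=2
  return 0;
}
```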

lib/Dialect/Torch/Utils/TorchUpstream.cpp

Lines changed: 15 additions & 0 deletions
@@ -128,6 +128,21 @@ ScalarType result_type(const ResultTypeState &in_state) {
       combine_categories(in_state.zeroResult, in_state.wrappedResult));
 }
 
+Reduction get_loss_reduction_enum(const llvm::StringRef &reduce) {
+  if (reduce == "none") {
+    return torch_upstream::Reduction::None;
+  } else if (reduce == "mean") {
+    return torch_upstream::Reduction::Mean;
+  } else if (reduce == "sum") {
+    return torch_upstream::Reduction::Sum;
+  } else if (reduce == "end") {
+    return torch_upstream::Reduction::END;
+  } else {
+    llvm_unreachable(
+        "'reduction' argument must be either none, mean, sum or end");
+  }
+}
+
 ReductionType get_reduction_enum(const llvm::StringRef &reduce) {
   if (reduce == "max" || reduce == "amax") {
     return torch_upstream::ReductionType::MAX;

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 45 additions & 0 deletions
@@ -1095,6 +1095,51 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
 
 // -----
 
+// CHECK-LABEL: func.func @test_nllloss_ii
+func.func @test_nllloss_ii(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_3:.*]] = torch.constant.none
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
+  // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
+  %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.ignore_index = 1 : si64, torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// CHECK-LABEL: func.func @test_nllloss_ii_ignore_default
+func.func @test_nllloss_ii_ignore_default(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_3:.*]] = torch.constant.none
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int -100
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
+  // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
+  %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// CHECK-LABEL: func.func @test_nllloss_ii_reduction_sum
+func.func @test_nllloss_ii_reduction_sum(%arg0: !torch.vtensor<[3,5,6,6],f32>, %arg1: !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_3:.*]] = torch.constant.none
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int -100
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
+  // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
+  %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "sum"} : (!torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// CHECK-LABEL: func.func @test_nllloss_iii_reduction_none_ignore_negative
+func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor<[3,5,6],f32>, %arg1: !torch.vtensor<[3,6],si64>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int -1
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %arg2, %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
+  // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
+  %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1, %arg2) {torch.onnx.ignore_index = -1 : si64, torch.onnx.reduction = "none"} : (!torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_nonzero
 func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
