diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index 043dd92549b2..31981ec2a04a 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -135,6 +135,8 @@ ScalarType promote_skip_undefined(ScalarType a, ScalarType b);
 //===----------------------------------------------------------------------===//
 enum Reduction { None, Mean, Sum, END };
 
+Reduction get_loss_reduction_enum(const llvm::StringRef &reduce);
+
 //===----------------------------------------------------------------------===//
 // Possible values for `memory_format` argument in PyTorch ops that support it.
 // Source:
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index cf14fc0268d3..5a9ed2acdb86 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -444,6 +444,45 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                   binder.op, resultType, lhs, rhs);
               return success();
             });
+  patterns.onOp(
+      "NegativeLogLikelihoodLoss", 13,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value self, target, weight, reduction, ignore_index;
+        int64_t ignore_index_int;
+        std::string reduction_str;
+
+        if (binder.tensorOperandAtIndex(self, 0) ||
+            binder.tensorOperandAtIndex(target, 1) ||
+            binder.s64IntegerAttr(ignore_index_int, "ignore_index", -100) ||
+            binder.customOpNameStringAttr(reduction_str, "reduction", "mean") ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        // optional third tensor argument: fall back to `none` when absent
+        if (binder.tensorOperandAtIndex(weight, 2)) {
+          weight = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        }
+
+        ignore_index = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(ignore_index_int));
+
+        // convert string reduction attr to standardized integer enum value
+        int reduction_value =
+            torch_upstream::get_loss_reduction_enum(reduction_str);
+        reduction = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(reduction_value));
+
+        Value nllLoss = rewriter
+                            .create<Torch::AtenNllLossForwardOp>(
+                                binder.getLoc(), resultType, resultType, self,
+                                target, weight, reduction, ignore_index)
+                            ->getResult(0);
+
+        rewriter.replaceOp(binder.op, nllLoss);
+        return success();
+      });
   patterns.onOp("NonZero", 13,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
index 2dce14ef964c..37b84066bcbe 100644
--- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp
+++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp
@@ -128,6 +128,21 @@ ScalarType result_type(const ResultTypeState &in_state) {
       combine_categories(in_state.zeroResult, in_state.wrappedResult));
 }
 
+Reduction get_loss_reduction_enum(const llvm::StringRef &reduce) {
+  if (reduce == "none") {
+    return torch_upstream::Reduction::None;
+  } else if (reduce == "mean") {
+    return torch_upstream::Reduction::Mean;
+  } else if (reduce == "sum") {
+    return torch_upstream::Reduction::Sum;
+  } else if (reduce == "end") {
+    return torch_upstream::Reduction::END;
+  } else {
+    llvm_unreachable(
+        "'reduction' argument must be either none, mean, sum or end");
+  }
+}
+
 ReductionType get_reduction_enum(const llvm::StringRef &reduce) {
   if (reduce == "max" || reduce == "amax") {
     return torch_upstream::ReductionType::MAX;
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 54311fdbc805..0b5933209b14 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -929,6 +929,51 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
 
 // -----
 
+// CHECK-LABEL: func.func @test_nllloss_ii
+func.func
@test_nllloss_ii(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_3:.*]] = torch.constant.none
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
+  // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
+  %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.ignore_index = 1 : si64, torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// CHECK-LABEL: func.func @test_nllloss_ii_ignore_default
+func.func @test_nllloss_ii_ignore_default(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_3:.*]] = torch.constant.none
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int -100
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
+  // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
+  %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// CHECK-LABEL: func.func @test_nllloss_ii_reduction_sum
+func.func @test_nllloss_ii_reduction_sum(%arg0: !torch.vtensor<[3,5,6,6],f32>, %arg1: !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_3:.*]] = torch.constant.none
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int -100
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 2
+  // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
+  // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
+  %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "sum"} : (!torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// CHECK-LABEL: func.func @test_nllloss_iii_reduction_none_ignore_negative
+func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor<[3,5,6],f32>, %arg1: !torch.vtensor<[3,6],si64>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[VAL_4:.*]] = torch.constant.int -1
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %arg2, %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
+  // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
+  %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1, %arg2) {torch.onnx.ignore_index = -1 : si64, torch.onnx.reduction = "none"} : (!torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32>
+  return %0 : !torch.vtensor<[],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_nonzero
 func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>