[ONNX] Add OnnxToTorch lowering for Onnx.NegativeLogLikelihoodLoss Op (#3380)

This implements the Onnx.NegativeLogLikelihoodLoss op using the
signature provided
[here](https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html)
by replacing it with a `NLLLossForward` op.

Additionally, I included a helper function `get_loss_reduction_enum` to
convert from a string `reduction` parameter to the corresponding
intended integer value since this is an operation that will be reused
for any loss function module. This differs from `get_reduction_enum` in
`TorchUpstream.cpp` which handles the `reduce` parameter from
`scatter_reduce` type operations.
pull/3461/merge
Arham Khan 2024-06-14 09:31:11 -07:00 committed by GitHub
parent 2ea2bc3948
commit 09c988046c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 101 additions and 0 deletions

View File

@ -145,6 +145,8 @@ ScalarType promote_skip_undefined(ScalarType a, ScalarType b);
//===----------------------------------------------------------------------===//
enum Reduction { None, Mean, Sum, END };
Reduction get_loss_reduction_enum(const llvm::StringRef &reduce);
//===----------------------------------------------------------------------===//
// Possible values for `memory_format` argument in PyTorch ops that support it.
// Source:

View File

@ -435,6 +435,45 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"NegativeLogLikelihoodLoss", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value self, target, weight, reduction, ignore_index;
int64_t ignore_index_int;
std::string reduction_str;
if (binder.tensorOperandAtIndex(self, 0) ||
binder.tensorOperandAtIndex(target, 1) ||
binder.s64IntegerAttr(ignore_index_int, "ignore_index", -100) ||
binder.customOpNameStringAttr(reduction_str, "reduction", "mean") ||
binder.tensorResultType(resultType)) {
return failure();
}
// optional third tensor argument
if (binder.tensorOperandAtIndex(weight, 2)) {
weight = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
}
ignore_index = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(ignore_index_int));
// convert string reduction attr to standardized integer enum value
int reduction_value =
torch_upstream::get_loss_reduction_enum(reduction_str);
reduction = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(reduction_value));
Value nllLoss = rewriter
.create<Torch::AtenNllLossForwardOp>(
binder.getLoc(), resultType, resultType, self,
target, weight, reduction, ignore_index)
->getResult(0);
rewriter.replaceOp(binder.op, nllLoss);
return success();
});
patterns.onOp("NonZero", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;

View File

@ -128,6 +128,21 @@ ScalarType result_type(const ResultTypeState &in_state) {
combine_categories(in_state.zeroResult, in_state.wrappedResult));
}
Reduction get_loss_reduction_enum(const llvm::StringRef &reduce) {
if (reduce == "none") {
return torch_upstream::Reduction::None;
} else if (reduce == "mean") {
return torch_upstream::Reduction::Mean;
} else if (reduce == "sum") {
return torch_upstream::Reduction::Sum;
} else if (reduce == "end") {
return torch_upstream::Reduction::END;
} else {
llvm_unreachable(
"'reduction' argument must be either none, mean, sum or end");
}
}
ReductionType get_reduction_enum(const llvm::StringRef &reduce) {
if (reduce == "max" || reduce == "amax") {
return torch_upstream::ReductionType::MAX;

View File

@ -1095,6 +1095,51 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
// -----
// CHECK-LABEL: func.func @test_nllloss_ii
func.func @test_nllloss_ii(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
%0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.ignore_index = 1 : si64, torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func.func @test_nllloss_ii_ignore_default
func.func @test_nllloss_ii_ignore_default(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: %[[VAL_4:.*]] = torch.constant.int -100
// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
%0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func.func @test_nllloss_ii_reduction_sum
func.func @test_nllloss_ii_reduction_sum(%arg0: !torch.vtensor<[3,5,6,6],f32>, %arg1: !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_3:.*]] = torch.constant.none
// CHECK: %[[VAL_4:.*]] = torch.constant.int -100
// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
%0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "sum"} : (!torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func.func @test_nllloss_iii_reduction_none_ignore_negative
func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor<[3,5,6],f32>, %arg1: !torch.vtensor<[3,6],si64>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[VAL_4:.*]] = torch.constant.int -1
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %arg2, %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
%0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1, %arg2) {torch.onnx.ignore_index = -1 : si64, torch.onnx.reduction = "none"} : (!torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// -----
// CHECK-LABEL: func.func @test_nonzero
func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>