mirror of https://github.com/llvm/torch-mlir
[ONNX] Add OnnxToTorch lowering for Onnx.NegativeLogLikelihoodLoss Op (#3380)
This implements the Onnx.NegativeLogLikelihoodLoss op using the signature provided [here](https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html) by replacing it with a `NLLLossForward` op. Additionally, I included a helper function `get_loss_reduction_enum` to convert from a string `reduction` parameter to the corresponding intended integer value since this is an operation that will be reused for any loss function module. This differs from `get_reduction_enum` in `TorchUpstream.cpp` which handles the `reduce` parameter from `scatter_reduce` type operations.pull/3461/merge
parent
2ea2bc3948
commit
09c988046c
|
@ -145,6 +145,8 @@ ScalarType promote_skip_undefined(ScalarType a, ScalarType b);
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
enum Reduction { None, Mean, Sum, END };
|
enum Reduction { None, Mean, Sum, END };
|
||||||
|
|
||||||
|
Reduction get_loss_reduction_enum(const llvm::StringRef &reduce);
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Possible values for `memory_format` argument in PyTorch ops that support it.
|
// Possible values for `memory_format` argument in PyTorch ops that support it.
|
||||||
// Source:
|
// Source:
|
||||||
|
|
|
@ -435,6 +435,45 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
binder.op, resultType, lhs, rhs);
|
binder.op, resultType, lhs, rhs);
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
patterns.onOp(
|
||||||
|
"NegativeLogLikelihoodLoss", 13,
|
||||||
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
Value self, target, weight, reduction, ignore_index;
|
||||||
|
int64_t ignore_index_int;
|
||||||
|
std::string reduction_str;
|
||||||
|
|
||||||
|
if (binder.tensorOperandAtIndex(self, 0) ||
|
||||||
|
binder.tensorOperandAtIndex(target, 1) ||
|
||||||
|
binder.s64IntegerAttr(ignore_index_int, "ignore_index", -100) ||
|
||||||
|
binder.customOpNameStringAttr(reduction_str, "reduction", "mean") ||
|
||||||
|
binder.tensorResultType(resultType)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// optional third tensor argument
|
||||||
|
if (binder.tensorOperandAtIndex(weight, 2)) {
|
||||||
|
weight = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||||
|
}
|
||||||
|
|
||||||
|
ignore_index = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(ignore_index_int));
|
||||||
|
|
||||||
|
// convert string reduction attr to standardized integer enum value
|
||||||
|
int reduction_value =
|
||||||
|
torch_upstream::get_loss_reduction_enum(reduction_str);
|
||||||
|
reduction = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
binder.getLoc(), rewriter.getI64IntegerAttr(reduction_value));
|
||||||
|
|
||||||
|
Value nllLoss = rewriter
|
||||||
|
.create<Torch::AtenNllLossForwardOp>(
|
||||||
|
binder.getLoc(), resultType, resultType, self,
|
||||||
|
target, weight, reduction, ignore_index)
|
||||||
|
->getResult(0);
|
||||||
|
|
||||||
|
rewriter.replaceOp(binder.op, nllLoss);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp("NonZero", 13,
|
patterns.onOp("NonZero", 13,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
|
|
|
@ -128,6 +128,21 @@ ScalarType result_type(const ResultTypeState &in_state) {
|
||||||
combine_categories(in_state.zeroResult, in_state.wrappedResult));
|
combine_categories(in_state.zeroResult, in_state.wrappedResult));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reduction get_loss_reduction_enum(const llvm::StringRef &reduce) {
|
||||||
|
if (reduce == "none") {
|
||||||
|
return torch_upstream::Reduction::None;
|
||||||
|
} else if (reduce == "mean") {
|
||||||
|
return torch_upstream::Reduction::Mean;
|
||||||
|
} else if (reduce == "sum") {
|
||||||
|
return torch_upstream::Reduction::Sum;
|
||||||
|
} else if (reduce == "end") {
|
||||||
|
return torch_upstream::Reduction::END;
|
||||||
|
} else {
|
||||||
|
llvm_unreachable(
|
||||||
|
"'reduction' argument must be either none, mean, sum or end");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ReductionType get_reduction_enum(const llvm::StringRef &reduce) {
|
ReductionType get_reduction_enum(const llvm::StringRef &reduce) {
|
||||||
if (reduce == "max" || reduce == "amax") {
|
if (reduce == "max" || reduce == "amax") {
|
||||||
return torch_upstream::ReductionType::MAX;
|
return torch_upstream::ReductionType::MAX;
|
||||||
|
|
|
@ -1095,6 +1095,51 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_nllloss_ii
|
||||||
|
func.func @test_nllloss_ii(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
|
||||||
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
|
||||||
|
%0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.ignore_index = 1 : si64, torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32>
|
||||||
|
return %0 : !torch.vtensor<[],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_nllloss_ii_ignore_default
|
||||||
|
func.func @test_nllloss_ii_ignore_default(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int -100
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
|
||||||
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
|
||||||
|
%0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32>
|
||||||
|
return %0 : !torch.vtensor<[],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_nllloss_ii_reduction_sum
|
||||||
|
func.func @test_nllloss_ii_reduction_sum(%arg0: !torch.vtensor<[3,5,6,6],f32>, %arg1: !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VAL_3:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int -100
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
|
||||||
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
|
||||||
|
%0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "sum"} : (!torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32>
|
||||||
|
return %0 : !torch.vtensor<[],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_nllloss_iii_reduction_none_ignore_negative
|
||||||
|
func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor<[3,5,6],f32>, %arg1: !torch.vtensor<[3,6],si64>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[VAL_4:.*]] = torch.constant.int -1
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %arg2, %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
|
||||||
|
// CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32>
|
||||||
|
%0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1, %arg2) {torch.onnx.ignore_index = -1 : si64, torch.onnx.reduction = "none"} : (!torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32>
|
||||||
|
return %0 : !torch.vtensor<[],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_nonzero
|
// CHECK-LABEL: func.func @test_nonzero
|
||||||
func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
|
// CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64>
|
||||||
|
|
Loading…
Reference in New Issue