[MLIR][TORCH] Add OnnxToTorch lowering for ops (#3049)

This commit adds the OnnxToTorch lowering for the Mish, Softplus,
HardSwish, Trilu, ThresholdedRelu op

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3046/merge
Vivek Khandelwal 2024-03-25 20:29:07 +05:30 committed by GitHub
parent 1fcbfa87ec
commit 9ae33e482e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 205 additions and 19 deletions

View File

@ -1501,4 +1501,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, self, other);
return success();
});
patterns.onOp("Mish", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
if (binder.tensorOperand(input) ||
binder.tensorResultType(resultType)) {
return failure();
}
rewriter.replaceOpWithNewOp<Torch::AtenMishOp>(
binder.op, resultType, input);
return success();
});
patterns.onOp("HardSwish", 14,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
if (binder.tensorOperand(input) ||
binder.tensorResultType(resultType)) {
return failure();
}
rewriter.replaceOpWithNewOp<Torch::AtenHardswishOp>(
binder.op, resultType, input);
return success();
});
}

View File

@ -2099,4 +2099,67 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"Softplus", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
if (binder.tensorOperand(input) ||
binder.tensorResultType(resultType)) {
return failure();
}
// out = ln(exp(x) + 1)
Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
resultType, input);
rewriter.replaceOpWithNewOp<Torch::AtenLog1pOp>(binder.op, resultType,
exp);
return success();
});
patterns.onOp(
"Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
int64_t upper;
if (binder.tensorOperandAtIndex(input, 0) ||
binder.s64IntegerAttr(upper, "upper", 1) ||
binder.tensorResultType(resultType)) {
return failure();
}
Value diagonal;
if (binder.tensorOperandAtIndex(diagonal, 1)) {
diagonal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
} else {
diagonal = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), diagonal);
}
if (upper) {
rewriter.replaceOpWithNewOp<Torch::AtenTriuOp>(binder.op, resultType,
input, diagonal);
return success();
}
rewriter.replaceOpWithNewOp<Torch::AtenTrilOp>(binder.op, resultType,
input, diagonal);
return success();
});
patterns.onOp("ThresholdedRelu", 10,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
float alpha;
if (binder.tensorOperand(input) ||
binder.f32FloatAttr(alpha, "alpha", 1.0)) {
return failure();
}
Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
Value value = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
rewriter.replaceOpWithNewOp<Torch::AtenThresholdOp>(
binder.op, resultType, input, cstAlpha, value);
return success();
});
}

View File

@ -1924,11 +1924,6 @@ ONNX_XFAIL_SET = {
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
# Failure - onnx_lowering: onnx.HardSwish
"HardswishModule_basic",
"HardswishRandomModule_basic",
"MobilenetV3Module_basic",
# Failure - onnx_lowering: onnx.MaxPool
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
@ -2043,10 +2038,6 @@ ONNX_XFAIL_SET = {
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
# Failure - onnx_lowering: onnx.Softplus
"ElementwiseMishModule_basic",
"SoftplusModule_basic",
# Failure - onnx_lowering: onnx.Squeeze
"SqueezeModule_allUnitDim",
"SqueezeModule_broadcast",
@ -2059,16 +2050,6 @@ ONNX_XFAIL_SET = {
"SortTensorSpecificDimension_basic",
"SortTensor_basic",
# Failure - onnx_lowering: onnx.Trilu
"AtenTrilModule_basic",
"AtenTrilWithNegDiagonalModule_basic",
"AtenTrilWithPosDiagonalModule_basic",
"AtenTriuModule_basic",
"AtenTriuWithNegDiagonalModule_basic",
"AtenTriuWithPosDiagonalModule_basic",
"TriuBroadcastModule_basic",
"TriuModule_basic",
# Failure - incorrect dtype
"ReduceMaxAlongDimUnsignedInt_basic",

View File

@ -912,3 +912,22 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],
%0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
// -----
// CHECK-LABEL: func.func @test_mish
func.func @test_mish(%arg0: !torch.vtensor<[10000],f32>) -> !torch.vtensor<[10000],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.mish %arg0 : !torch.vtensor<[10000],f32> -> !torch.vtensor<[10000],f32>
%0 = torch.operator "onnx.Mish"(%arg0) : (!torch.vtensor<[10000],f32>) -> !torch.vtensor<[10000],f32>
return %0 : !torch.vtensor<[10000],f32>
}
// -----
// CHECK-LABEL: func.func @test_hardswish
func.func @test_hardswish(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.hardswish %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.HardSwish"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}

View File

@ -1664,3 +1664,102 @@ func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si
return %0 : !torch.vtensor<[],si32>
}
// -----
// CHECK-LABEL: func.func @test_softplus
func.func @test_softplus(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[EXP:.*]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: torch.aten.log1p %[[EXP]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
%0 = torch.operator "onnx.Softplus"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
return %0 : !torch.vtensor<[3],f32>
}
// -----
// CHECK-LABEL: func.func @test_tril
func.func @test_tril(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
// CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[4,5],si64>, !torch.int -> !torch.vtensor<[4,5],si64>
%0 = torch.operator "onnx.Trilu"(%arg0) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64>
return %0 : !torch.vtensor<[4,5],si64>
}
// -----
// CHECK-LABEL: func.func @test_tril_neg
func.func @test_tril_neg(%arg0: !torch.vtensor<[4,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[4,5],si64>, !torch.int -> !torch.vtensor<[4,5],si64>
%0 = torch.operator "onnx.Trilu"(%arg0, %arg1) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[4,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[4,5],si64>
return %0 : !torch.vtensor<[4,5],si64>
}
// -----
// CHECK-LABEL: func.func @test_tril_one_row_neg
func.func @test_tril_one_row_neg(%arg0: !torch.vtensor<[3,1,5],si64>) -> !torch.vtensor<[3,1,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
// CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[3,1,5],si64>, !torch.int -> !torch.vtensor<[3,1,5],si64>
%0 = torch.operator "onnx.Trilu"(%arg0) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[3,1,5],si64>) -> !torch.vtensor<[3,1,5],si64>
return %0 : !torch.vtensor<[3,1,5],si64>
}
// -----
// CHECK-LABEL: func.func @test_tril_square
func.func @test_tril_square(%arg0: !torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
// CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[2,3,3],si64>, !torch.int -> !torch.vtensor<[2,3,3],si64>
%0 = torch.operator "onnx.Trilu"(%arg0) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64>
return %0 : !torch.vtensor<[2,3,3],si64>
}
// -----
// CHECK-LABEL: func.func @test_tril_zero
func.func @test_tril_zero(%arg0: !torch.vtensor<[3,0,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[3,0,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[3,0,5],si64>, !torch.int -> !torch.vtensor<[3,0,5],si64>
%0 = torch.operator "onnx.Trilu"(%arg0, %arg1) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[3,0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[3,0,5],si64>
return %0 : !torch.vtensor<[3,0,5],si64>
}
// -----
// CHECK-LABEL: func.func @test_triu
func.func @test_triu(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
// CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[4,5],si64>, !torch.int -> !torch.vtensor<[4,5],si64>
%0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64>
return %0 : !torch.vtensor<[4,5],si64>
}
// -----
// CHECK-LABEL: func.func @test_triu_one_row
func.func @test_triu_one_row(%arg0: !torch.vtensor<[3,1,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[3,1,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[3,1,5],si64>, !torch.int -> !torch.vtensor<[3,1,5],si64>
%0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[3,1,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[3,1,5],si64>
return %0 : !torch.vtensor<[3,1,5],si64>
}
// -----
// CHECK-LABEL: func.func @test_triu_square
func.func @test_triu_square(%arg0: !torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
// CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[2,3,3],si64>, !torch.int -> !torch.vtensor<[2,3,3],si64>
%0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64>
return %0 : !torch.vtensor<[2,3,3],si64>
}
// -----
// CHECK-LABEL: func.func @test_triu_zero
func.func @test_triu_zero(%arg0: !torch.vtensor<[0,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
// CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[0,5],si64>, !torch.int -> !torch.vtensor<[0,5],si64>
%0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64>
return %0 : !torch.vtensor<[0,5],si64>
}