mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add OnnxToTorch lowering for ops (#3049)
This commit adds the OnnxToTorch lowering for the Mish, Softplus, HardSwish, Trilu, and ThresholdedRelu ops.

Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
parent 1fcbfa87ec
commit 9ae33e482e
@@ -1501,4 +1501,28 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           binder.op, resultType, self, other);
           return success();
         });
+  patterns.onOp("Mish", 18,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value input;
+                  if (binder.tensorOperand(input) ||
+                      binder.tensorResultType(resultType)) {
+                    return failure();
+                  }
+                  rewriter.replaceOpWithNewOp<Torch::AtenMishOp>(
+                      binder.op, resultType, input);
+                  return success();
+                });
+  patterns.onOp("HardSwish", 14,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value input;
+                  if (binder.tensorOperand(input) ||
+                      binder.tensorResultType(resultType)) {
+                    return failure();
+                  }
+                  rewriter.replaceOpWithNewOp<Torch::AtenHardswishOp>(
+                      binder.op, resultType, input);
+                  return success();
+                });
 }
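Both new GtoP lowerings are direct 1:1 rewrites: ONNX Mish and HardSwish map straight onto the existing torch.aten.mish and torch.aten.hardswish ops, so the binder only has to match one tensor operand and the result type. A quick eager-mode sanity check of the underlying activation definitions, as a sketch assuming a recent PyTorch (this snippet is not part of the commit):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(10000)

# onnx.Mish (opset 18): mish(x) = x * tanh(softplus(x))
assert torch.allclose(F.mish(x), x * torch.tanh(F.softplus(x)), atol=1e-6)

# onnx.HardSwish (opset 14): hardswish(x) = x * relu6(x + 3) / 6
assert torch.allclose(F.hardswish(x), x * F.relu6(x + 3.0) / 6.0, atol=1e-6)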
@@ -2099,4 +2099,68 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
         binder.op, resultType, operand);
         return success();
       });
+  patterns.onOp(
+      "Softplus", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input;
+        if (binder.tensorOperand(input) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+        // out = ln(exp(x) + 1)
+        Value exp = rewriter.create<Torch::AtenExpOp>(binder.getLoc(),
+                                                      resultType, input);
+        rewriter.replaceOpWithNewOp<Torch::AtenLog1pOp>(binder.op, resultType,
+                                                        exp);
+        return success();
+      });
+  patterns.onOp(
+      "Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input;
+        int64_t upper;
+        if (binder.tensorOperandAtIndex(input, 0) ||
+            binder.s64IntegerAttr(upper, "upper", 1) ||
+            binder.tensorResultType(resultType)) {
+          return failure();
+        }
+
+        Value diagonal;
+        if (binder.tensorOperandAtIndex(diagonal, 1)) {
+          diagonal = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(0));
+        } else {
+          diagonal = rewriter.create<Torch::AtenItemOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(), diagonal);
+        }
+
+        if (upper) {
+          rewriter.replaceOpWithNewOp<Torch::AtenTriuOp>(binder.op, resultType,
+                                                         input, diagonal);
+          return success();
+        }
+        rewriter.replaceOpWithNewOp<Torch::AtenTrilOp>(binder.op, resultType,
+                                                       input, diagonal);
+        return success();
+      });
+  patterns.onOp("ThresholdedRelu", 10,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value input;
+                  float alpha;
+                  if (binder.tensorOperand(input) ||
+                      binder.tensorResultType(resultType) ||
+                      binder.f32FloatAttr(alpha, "alpha", 1.0)) {
+                    return failure();
+                  }
+                  Value cstAlpha = rewriter.create<Torch::ConstantFloatOp>(
+                      binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+                      rewriter.getFloatAttr(rewriter.getF64Type(), alpha));
+                  Value value = rewriter.create<Torch::ConstantFloatOp>(
+                      binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+                      rewriter.getFloatAttr(rewriter.getF64Type(), 0.0));
+                  rewriter.replaceOpWithNewOp<Torch::AtenThresholdOp>(
+                      binder.op, resultType, input, cstAlpha, value);
+                  return success();
+                });
 }
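Softplus is the only op here that needs a small decomposition (exp followed by log1p); Trilu dispatches to triu or tril based on the upper attribute; ThresholdedRelu reuses torch.aten.threshold with a zero fill value. An eager-mode sanity check of the scalar math, as a sketch assuming a recent PyTorch (not part of the commit):

import torch
import torch.nn.functional as F

x = torch.tensor([-1.5, 0.0, 2.0])

# onnx.Softplus: out = ln(exp(x) + 1), emitted above as log1p(exp(x)).
assert torch.allclose(torch.log1p(torch.exp(x)), F.softplus(x), atol=1e-6)

# onnx.ThresholdedRelu: y = x if x > alpha else 0, emitted above as
# torch.aten.threshold(input, alpha, 0.0).
alpha = 1.0
assert torch.equal(F.threshold(x, alpha, 0.0),
                   torch.where(x > alpha, x, torch.zeros_like(x)))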
@@ -1924,11 +1924,6 @@ ONNX_XFAIL_SET = {
     "EinsumStaticFourDimensionModule_basic",
     "EinsumStaticModule_basic",
-
-    # Failure - onnx_lowering: onnx.HardSwish
-    "HardswishModule_basic",
-    "HardswishRandomModule_basic",
-    "MobilenetV3Module_basic",
 
     # Failure - onnx_lowering: onnx.MaxPool
     "MaxPool2dWithIndicesAllNegativeValuesModule_basic",
     "MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
@@ -2043,10 +2038,6 @@ ONNX_XFAIL_SET = {
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
-
-    # Failure - onnx_lowering: onnx.Softplus
-    "ElementwiseMishModule_basic",
-    "SoftplusModule_basic",
 
     # Failure - onnx_lowering: onnx.Squeeze
     "SqueezeModule_allUnitDim",
     "SqueezeModule_broadcast",
@@ -2059,16 +2050,6 @@ ONNX_XFAIL_SET = {
     "SortTensorSpecificDimension_basic",
     "SortTensor_basic",
-
-    # Failure - onnx_lowering: onnx.Trilu
-    "AtenTrilModule_basic",
-    "AtenTrilWithNegDiagonalModule_basic",
-    "AtenTrilWithPosDiagonalModule_basic",
-    "AtenTriuModule_basic",
-    "AtenTriuWithNegDiagonalModule_basic",
-    "AtenTriuWithPosDiagonalModule_basic",
-    "TriuBroadcastModule_basic",
-    "TriuModule_basic",
 
     # Failure - incorrect dtype
     "ReduceMaxAlongDimUnsignedInt_basic",
 
@@ -912,3 +912,22 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1>
   %0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32>
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_mish
+func.func @test_mish(%arg0: !torch.vtensor<[10000],f32>) -> !torch.vtensor<[10000],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %none = torch.constant.none
+  // CHECK: torch.aten.mish %arg0 : !torch.vtensor<[10000],f32> -> !torch.vtensor<[10000],f32>
+  %0 = torch.operator "onnx.Mish"(%arg0) : (!torch.vtensor<[10000],f32>) -> !torch.vtensor<[10000],f32>
+  return %0 : !torch.vtensor<[10000],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_hardswish
+func.func @test_hardswish(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: torch.aten.hardswish %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
+  %0 = torch.operator "onnx.HardSwish"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
+  return %0 : !torch.vtensor<[3,4,5],f32>
+}
@@ -1664,3 +1664,102 @@ func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si32>
   return %0 : !torch.vtensor<[],si32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_softplus
+func.func @test_softplus(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[EXP:.*]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  // CHECK: torch.aten.log1p %[[EXP]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+  %0 = torch.operator "onnx.Softplus"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
+  return %0 : !torch.vtensor<[3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril
+func.func @test_tril(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[4,5],si64>, !torch.int -> !torch.vtensor<[4,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64>
+  return %0 : !torch.vtensor<[4,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril_neg
+func.func @test_tril_neg(%arg0: !torch.vtensor<[4,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[4,5],si64>, !torch.int -> !torch.vtensor<[4,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[4,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[4,5],si64>
+  return %0 : !torch.vtensor<[4,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril_one_row_neg
+func.func @test_tril_one_row_neg(%arg0: !torch.vtensor<[3,1,5],si64>) -> !torch.vtensor<[3,1,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[3,1,5],si64>, !torch.int -> !torch.vtensor<[3,1,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[3,1,5],si64>) -> !torch.vtensor<[3,1,5],si64>
+  return %0 : !torch.vtensor<[3,1,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril_square
+func.func @test_tril_square(%arg0: !torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[2,3,3],si64>, !torch.int -> !torch.vtensor<[2,3,3],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64>
+  return %0 : !torch.vtensor<[2,3,3],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_tril_zero
+func.func @test_tril_zero(%arg0: !torch.vtensor<[3,0,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[3,0,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.tril %arg0, %[[DIAGONAL]] : !torch.vtensor<[3,0,5],si64>, !torch.int -> !torch.vtensor<[3,0,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) {torch.onnx.upper = 0 : si64} : (!torch.vtensor<[3,0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[3,0,5],si64>
+  return %0 : !torch.vtensor<[3,0,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_triu
+func.func @test_triu(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[4,5],si64>, !torch.int -> !torch.vtensor<[4,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64>
+  return %0 : !torch.vtensor<[4,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_triu_one_row
+func.func @test_triu_one_row(%arg0: !torch.vtensor<[3,1,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[3,1,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[3,1,5],si64>, !torch.int -> !torch.vtensor<[3,1,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[3,1,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[3,1,5],si64>
+  return %0 : !torch.vtensor<[3,1,5],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_triu_square
+func.func @test_triu_square(%arg0: !torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.constant.int 0
+  // CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[2,3,3],si64>, !torch.int -> !torch.vtensor<[2,3,3],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[2,3,3],si64>) -> !torch.vtensor<[2,3,3],si64>
+  return %0 : !torch.vtensor<[2,3,3],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_triu_zero
+func.func @test_triu_zero(%arg0: !torch.vtensor<[0,5],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[DIAGONAL:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int
+  // CHECK: torch.aten.triu %arg0, %[[DIAGONAL]] : !torch.vtensor<[0,5],si64>, !torch.int -> !torch.vtensor<[0,5],si64>
+  %0 = torch.operator "onnx.Trilu"(%arg0, %arg1) : (!torch.vtensor<[0,5],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[0,5],si64>
+  return %0 : !torch.vtensor<[0,5],si64>
+}
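The Trilu tests above exercise both diagonal paths of the lowering: when the optional second operand is absent the rewrite materializes torch.constant.int 0, and when it is present the 0-d tensor is read out with torch.aten.item. An eager-mode sketch of the semantics being matched, assuming a recent PyTorch (not part of the commit):

import torch

x = torch.arange(20, dtype=torch.int64).reshape(4, 5)
k = torch.tensor(-1)      # the optional 0-d diagonal operand of onnx.Trilu
diagonal = int(k.item())  # what torch.aten.item extracts in the lowering

# upper = 0 maps to torch.aten.tril; upper = 1 (the default) to torch.aten.triu.
rows = torch.arange(4).unsqueeze(1)
cols = torch.arange(5).unsqueeze(0)
lower = torch.where(cols <= rows + diagonal, x, torch.zeros_like(x))
upper = torch.where(cols >= rows + diagonal, x, torch.zeros_like(x))
assert torch.equal(torch.tril(x, diagonal=diagonal), lower)
assert torch.equal(torch.triu(x, diagonal=diagonal), upper)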