mirror of https://github.com/llvm/torch-mlir
[MLIR][ONNX] Add OnnxToTorch support for Gelu (#2647)
This commit adds the OnnxToTorch support for Gelu op. --------- Co-authored-by: Rob Suderman <suderman@google.com>pull/2664/merge
parent
791c666479
commit
deacb8ef38
|
@ -27,6 +27,26 @@ using namespace mlir::torch::onnx_c;
|
||||||
// thing here, so we simplify.
|
// thing here, so we simplify.
|
||||||
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
OnnxCustomOpConversionPattern &patterns) {
|
OnnxCustomOpConversionPattern &patterns) {
|
||||||
|
|
||||||
|
patterns.onOp(
|
||||||
|
"Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
|
Value operand;
|
||||||
|
Torch::ValueTensorType resultType;
|
||||||
|
std::string approximate;
|
||||||
|
|
||||||
|
if (binder.tensorOperand(operand) ||
|
||||||
|
binder.tensorResultType(resultType) ||
|
||||||
|
binder.customOpNameStringAttr(approximate, "approximate", "none"))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value vApproximate = rewriter.create<Torch::ConstantStrOp>(
|
||||||
|
binder.getLoc(), rewriter.getType<Torch::StringType>(),
|
||||||
|
rewriter.getStringAttr(approximate));
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenGeluOp>(binder.op, resultType,
|
||||||
|
operand, vApproximate);
|
||||||
|
return success();
|
||||||
|
});
|
||||||
patterns.onOp("MatMul", 13,
|
patterns.onOp("MatMul", 13,
|
||||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||||
Torch::ValueTensorType resultType;
|
Torch::ValueTensorType resultType;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s
|
// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch --split-input-file | FileCheck %s
|
||||||
// Generally, the test cases accumulated here come from running the importer
|
// Generally, the test cases accumulated here come from running the importer
|
||||||
// over all included backend tests that involve simple ops with no model
|
// over all included backend tests that involve simple ops with no model
|
||||||
// level constants. This is a pragmatic choice which lets us have a lot
|
// level constants. This is a pragmatic choice which lets us have a lot
|
||||||
|
@ -131,6 +131,8 @@ func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtenso
|
||||||
return %0 : !torch.vtensor<[3,3],f32>
|
return %0 : !torch.vtensor<[3,3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_matmul_3d
|
// CHECK-LABEL: @test_matmul_3d
|
||||||
func.func @test_matmul_3d(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,4,3],f32>) -> !torch.vtensor<[2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_matmul_3d(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,4,3],f32>) -> !torch.vtensor<[2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32> -> !torch.vtensor<[2,3,3],f32>
|
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32> -> !torch.vtensor<[2,3,3],f32>
|
||||||
|
@ -138,6 +140,8 @@ func.func @test_matmul_3d(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vten
|
||||||
return %0 : !torch.vtensor<[2,3,3],f32>
|
return %0 : !torch.vtensor<[2,3,3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @test_matmul_4d
|
// CHECK-LABEL: @test_matmul_4d
|
||||||
func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32> -> !torch.vtensor<[1,2,3,3],f32>
|
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32> -> !torch.vtensor<[1,2,3,3],f32>
|
||||||
|
@ -145,6 +149,48 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt
|
||||||
return %0 : !torch.vtensor<[1,2,3,3],f32>
|
return %0 : !torch.vtensor<[1,2,3,3],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_gelu_default_1
|
||||||
|
func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[STR1:.*]] = torch.constant.str "none"
|
||||||
|
// CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3],f32>, !torch.str -> !torch.vtensor<[3],f32>
|
||||||
|
%0 = torch.operator "onnx.Gelu"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
|
||||||
|
return %0 : !torch.vtensor<[3],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_gelu_default_2
|
||||||
|
func.func @test_gelu_default_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[STR1:.*]] = torch.constant.str "none"
|
||||||
|
// CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3,4,5],f32>, !torch.str -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
%0 = torch.operator "onnx.Gelu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_gelu_tanh_1
|
||||||
|
func.func @test_gelu_tanh_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[STR1:.*]] = torch.constant.str "tanh"
|
||||||
|
// CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3],f32>, !torch.str -> !torch.vtensor<[3],f32>
|
||||||
|
%0 = torch.operator "onnx.Gelu"(%arg0) {torch.onnx.approximate = "tanh"} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32>
|
||||||
|
return %0 : !torch.vtensor<[3],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_gelu_tanh_2
|
||||||
|
func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK: %[[STR1:.*]] = torch.constant.str "tanh"
|
||||||
|
// CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3,4,5],f32>, !torch.str -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
%0 = torch.operator "onnx.Gelu"(%arg0) {torch.onnx.approximate = "tanh"} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @test_less_or_equal
|
// CHECK-LABEL: func.func @test_less_or_equal
|
||||||
func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
|
||||||
|
|
Loading…
Reference in New Issue