Fix since-opset too high (#2701)

Addresses two of the ops from
https://github.com/llvm/torch-mlir/issues/2689

https://github.com/llvm/torch-mlir/issues/2700
pull/2703/head
Xida Ren (Cedar) 2023-12-27 10:08:09 -08:00 committed by GitHub
parent abc6b0a25a
commit 6847fc1fc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 4 deletions

View File

@ -377,7 +377,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return success();
});
patterns.onOp(
"Cast", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
"Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
int64_t dtypeIntOnnx, dtypeIntTorch;
@ -848,7 +848,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("Equal", 19,
patterns.onOp("Equal", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;

View File

@ -169,7 +169,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return success();
});
patterns.onOp(
"Gemm", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
"Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value a, b, c;
float alpha, beta;
@ -313,7 +313,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}
return failure();
});
patterns.onOp("LeakyRelu", 16,
patterns.onOp("LeakyRelu", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;

View File

@ -0,0 +1,40 @@
// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch
// FB OPT OPS from https://github.com/llvm/torch-mlir/issues/2689
// -----
// Fixed unecessarily high since-opset value
func.func @cast_operation(%arg0: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%208 = torch.operator "onnx.Cast"(%arg0) {
torch.onnx.to = 1 : si64
} : (!torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %208 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
func.func @div_operation(%arg0: !torch.vtensor<[1,64,768],f32>,
%arg1: !torch.vtensor<[1,64,1],f32>)
-> !torch.vtensor<[1,64,768],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%209 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[1,64,768],f32>, !torch.vtensor<[1,64,1],f32>) -> !torch.vtensor<[1,64,768],f32>
return %209 : !torch.vtensor<[1,64,768],f32>
}
// -----
// Fixed.
// this is the onnx opset 1 version of Equal, only int types.
// this used to fail to legalize because the "since" value is set unecessarily high (19)
func.func @equal_operation(%arg0: !torch.vtensor<[4],si64>,
%arg1: !torch.vtensor<[4],si64>)
-> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%205 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1>
return %205 : !torch.vtensor<[4],i1>
}
// -----
func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>)
-> !torch.vtensor<[1,64,1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// The ReduceMean operation as provided.
// expected-error @+1 {{failed to legalize operation 'torch.operator'}}
%211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32>
return %211 : !torch.vtensor<[1,64,1],f32>
}