mirror of https://github.com/llvm/torch-mlir
Fix since-opset too high (#2701)
Addresses two of the ops from https://github.com/llvm/torch-mlir/issues/2689 https://github.com/llvm/torch-mlir/issues/2700pull/2703/head
parent
abc6b0a25a
commit
6847fc1fc6
|
@ -377,7 +377,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Cast", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
"Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
int64_t dtypeIntOnnx, dtypeIntTorch;
|
||||
|
@ -848,7 +848,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
binder.op, resultType, lhs, rhs);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("Equal", 19,
|
||||
patterns.onOp("Equal", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value lhs, rhs;
|
||||
|
|
|
@ -169,7 +169,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Gemm", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
"Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value a, b, c;
|
||||
float alpha, beta;
|
||||
|
@ -313,7 +313,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
}
|
||||
return failure();
|
||||
});
|
||||
patterns.onOp("LeakyRelu", 16,
|
||||
patterns.onOp("LeakyRelu", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value operand;
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch
|
||||
// FB OPT OPS from https://github.com/llvm/torch-mlir/issues/2689
|
||||
|
||||
// -----
|
||||
// Fixed unecessarily high since-opset value
|
||||
func.func @cast_operation(%arg0: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%208 = torch.operator "onnx.Cast"(%arg0) {
|
||||
torch.onnx.to = 1 : si64
|
||||
} : (!torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32>
|
||||
return %208 : !torch.vtensor<[?,?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
func.func @div_operation(%arg0: !torch.vtensor<[1,64,768],f32>,
|
||||
%arg1: !torch.vtensor<[1,64,1],f32>)
|
||||
-> !torch.vtensor<[1,64,768],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%209 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[1,64,768],f32>, !torch.vtensor<[1,64,1],f32>) -> !torch.vtensor<[1,64,768],f32>
|
||||
return %209 : !torch.vtensor<[1,64,768],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Fixed.
|
||||
// this is the onnx opset 1 version of Equal, only int types.
|
||||
// this used to fail to legalize because the "since" value is set unecessarily high (19)
|
||||
func.func @equal_operation(%arg0: !torch.vtensor<[4],si64>,
|
||||
%arg1: !torch.vtensor<[4],si64>)
|
||||
-> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%205 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1>
|
||||
return %205 : !torch.vtensor<[4],i1>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>)
|
||||
-> !torch.vtensor<[1,64,1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// The ReduceMean operation as provided.
|
||||
// expected-error @+1 {{failed to legalize operation 'torch.operator'}}
|
||||
%211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32>
|
||||
return %211 : !torch.vtensor<[1,64,1],f32>
|
||||
}
|
Loading…
Reference in New Issue