diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 61bea1d86..3d2cf8aae 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -377,7 +377,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "Cast", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; int64_t dtypeIntOnnx, dtypeIntTorch; @@ -848,7 +848,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, lhs, rhs); return success(); }); - patterns.onOp("Equal", 19, + patterns.onOp("Equal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 710c9823f..e7544d6c1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -169,7 +169,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); patterns.onOp( - "Gemm", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value a, b, c; float alpha, beta; @@ -313,7 +313,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } return failure(); }); - patterns.onOp("LeakyRelu", 16, + patterns.onOp("LeakyRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir new file mode 100644 index 000000000..8401c378b --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir @@ -0,0 +1,40 @@ +// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch +// FB OPT OPS from https://github.com/llvm/torch-mlir/issues/2689 + +// ----- +// Fixed unecessarily high since-opset value +func.func @cast_operation(%arg0: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %208 = torch.operator "onnx.Cast"(%arg0) { + torch.onnx.to = 1 : si64 + } : (!torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %208 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- +func.func @div_operation(%arg0: !torch.vtensor<[1,64,768],f32>, + %arg1: !torch.vtensor<[1,64,1],f32>) + -> !torch.vtensor<[1,64,768],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %209 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[1,64,768],f32>, !torch.vtensor<[1,64,1],f32>) -> !torch.vtensor<[1,64,768],f32> + return %209 : !torch.vtensor<[1,64,768],f32> +} + +// ----- +// Fixed. +// this is the onnx opset 1 version of Equal, only int types. +// this used to fail to legalize because the "since" value is set unecessarily high (19) +func.func @equal_operation(%arg0: !torch.vtensor<[4],si64>, + %arg1: !torch.vtensor<[4],si64>) + -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %205 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> + return %205 : !torch.vtensor<[4],i1> +} + + +// ----- +func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) + -> !torch.vtensor<[1,64,1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // The ReduceMean operation as provided. + // expected-error @+1 {{failed to legalize operation 'torch.operator'}} + %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32> + return %211 : !torch.vtensor<[1,64,1],f32> +} \ No newline at end of file