diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 85d6f805f..d842ea77b 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -76,6 +76,13 @@ struct OpBinder { return failure(); return success(); } + + ParseResult tensorOperandsList( llvm::SmallVectorImpl &values) { + for (int i = 0; i < op->getNumOperands(); i++) { + values.push_back(op->getOperand(i)); + } + return success(); + } // Result type matchers of different arities. ParseResult tensorResultType(Torch::ValueTensorType &type0) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index c24fd0c65..d154edb1a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -88,8 +88,45 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( operand, vApproximate); return success(); }); - patterns.onOp("MatMul", 13, + patterns.onOp("Less", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + + patterns.onOp("LessOrEqual", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("Log", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("MatMul", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || @@ -135,19 +172,67 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); - patterns.onOp("Less", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value lhs, rhs; - if (binder.tensorOperands(lhs, rhs) || - binder.tensorResultType(resultType)) { + patterns.onOp("Max", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + operands.size() == 0) { return failure(); - } - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - return success(); + } + Value result = operands[0]; + for (int i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp( + binder.op, result.getDefiningOp()); + return success(); }); - patterns.onOp("LessOrEqual", 16, + patterns.onOp("Min", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (int i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp( + binder.op, result.getDefiningOp()); + return success(); + }); + patterns.onOp("Neg", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Not", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Or", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; @@ -155,9 +240,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) { return failure(); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); - return success(); + return success(); }); patterns.onOp( "GatherElements", 13, diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 085c6ea6a..e224ddfa2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -317,3 +317,46 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 %0 = torch.operator "onnx.GlobalAveragePool"(%arg0) : (!torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> return %0 : !torch.vtensor<[1,1,1,1],f32> } + +// CHECK-LABEL: func.func @test_max_example + func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Max"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// CHECK-LABEL: func.func @test_min_example + func.func @test_min_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Min"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + + +// CHECK-LABEL: func.func @test_log + func.func @test_log(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.log %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Log"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// CHECK-LABEL: func.func @test_neg + func.func @test_neg(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Neg"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// CHECK-LABEL: func.func @test_not_2d +func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Not"(%arg0) : (!torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> + } + +// CHECK-LABEL: func.func @test_or2d + func.func @test_or2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Or"(%arg0, %arg1) : (!torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> + }