From 2f6fdb7f0b5b4361e0053c1872aef6fae6aeb490 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Sat, 11 Feb 2023 05:58:15 +0800
Subject: [PATCH] [Torch Dialect] add folder for prim.min.int (#1864)

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td   |  1 +
 lib/Dialect/Torch/IR/TorchOps.cpp           | 19 +++++++++++++++++++
 .../jit_ir/build_tools/torch_ods_gen.py     |  2 +-
 .../torch_mlir_e2e_test/test_suite/basic.py | 19 +++++++++++++++++++
 test/Dialect/Torch/canonicalize.mlir        | 18 ++++++++++++++++++
 5 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 1b7464be6..aa9565111 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10894,6 +10894,7 @@ def Torch_PrimMinIntOp : Torch_Op<"prim.min.int", [
       printDefaultTorchOp(printer, *this, 2, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index e3fb59fc0..e55b21f43 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2348,6 +2348,25 @@ OpFoldResult PrimMinSelfIntOp::fold(FoldAdaptor adaptor) {
                          *std::min_element(values.begin(), values.end()));
 }
 
+//===----------------------------------------------------------------------===//
+// PrimMinIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) {
+  // If both operands are the same, then the operation is an identity.
+  if (getA() == getB())
+    return getA();
+
+  auto lhs = adaptor.getA().dyn_cast_or_null<IntegerAttr>();
+  auto rhs = adaptor.getB().dyn_cast_or_null<IntegerAttr>();
+  if (!lhs || !rhs)
+    return nullptr;
+  // Torch semantics are that !torch.int is 64-bit signed.
+  return IntegerAttr::get(
+      lhs.getType(),
+      std::min(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue()));
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCalculateOp
 //===----------------------------------------------------------------------===//
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index f4ba271b9..3c12ede1d 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -662,7 +662,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True)
     emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)")
     emit("prim::min.self_int : (int[]) -> (int)", has_folder=True)
-    emit("prim::min.int : (int, int) -> (int)")
+    emit("prim::min.int : (int, int) -> (int)", has_folder=True)
     emit("prim::max.self_int : (int[]) -> (int)")
     emit("prim::max.int : (int, int) -> (int)", has_folder=True)
     emit("prim::RaiseException : (str, str?) -> ()")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index ef00c0d15..b38dd7f24 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1270,6 +1270,25 @@ def LogSoftmaxIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class PrimMinIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        return torch.ops.prim.min(1, -1)
+
+
+@register_test_case(module_factory=lambda: PrimMinIntModule())
+def PrimMinIntModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
 
 class NumToTensorIntModule(torch.nn.Module):
 
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index cb5e3ead1..9f976f513 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -502,6 +502,24 @@ func.func @torch.prim.max.int$constant() -> !torch.int {
   return %0 : !torch.int
 }
 
+// CHECK-LABEL:   func.func @torch.prim.min.int$identity(
+// CHECK-SAME:        %[[ARG:.*]]: !torch.int) -> !torch.int {
+// CHECK:           return %[[ARG]] : !torch.int
+func.func @torch.prim.min.int$identity(%arg0: !torch.int) -> !torch.int {
+  %0 = torch.prim.min.int %arg0, %arg0 : !torch.int, !torch.int -> !torch.int
+  return %0 : !torch.int
+}
+
+// CHECK-LABEL:   func.func @torch.prim.min.int$constant() -> !torch.int {
+// CHECK:           %[[INT1:.*]] = torch.constant.int -1
+// CHECK:           return %[[INT1]] : !torch.int
+func.func @torch.prim.min.int$constant() -> !torch.int {
+  %int-1 = torch.constant.int -1
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.min.int %int-1, %int3 : !torch.int, !torch.int -> !torch.int
+  return %0 : !torch.int
+}
+
 // CHECK-LABEL:   func.func @torch.prim.min.self_int$basic() -> !torch.int {
 // CHECK:           %[[M1:.*]] = torch.constant.int -1
 // CHECK:           return %[[M1]] : !torch.int