From abfaf8c577813316bcd2b9ec9686f1eaacfcd487 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Thu, 21 Oct 2021 11:50:01 -0400 Subject: [PATCH] Add aten.ne.bool to make CI pass --- .../Dialect/Torch/IR/GeneratedAtenOps.td | 16 ++++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 18 ++++++++++- .../jit_ir/build_tools/torch_ods_gen.py | 1 + test/Dialect/Torch/canonicalize.mlir | 30 +++++++++++++++++++ 4 files changed, 64 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td index ed6053880..5a84b70c4 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td @@ -2382,6 +2382,22 @@ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; } +def Torch_AtenNeBoolOp : Torch_Op<"aten.ne.bool", [ + AllowsTypeRefinement, + HasValueSemantics + ]> { + let summary = "Generated op for `aten::ne.bool : (bool, bool) -> (bool)`"; + let arguments = (ins + Torch_BoolType:$a, + Torch_BoolType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)"; + let hasFolder = 1; +} + def Torch_Aten__Is__Op : Torch_Op<"aten.__is__", [ AllowsTypeRefinement, HasValueSemantics diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index cf2f904c4..3d04de4b2 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -435,7 +435,23 @@ OpFoldResult Aten__Not__Op::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// -// AtenLenTOp +// AtenNeBoolOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenNeBoolOp::fold(ArrayRef operands) { + if (getOperand(0) == getOperand(1)) + return IntegerAttr::get(IntegerType::get(getContext(), 1), false); + + bool a, b; + if (!matchPattern(getOperand(0), m_TorchConstantBool(&a))) + return nullptr; + if (!matchPattern(getOperand(1), m_TorchConstantBool(&b))) + return nullptr; + return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b); +} + +//===----------------------------------------------------------------------===// +// AtenDimOp //===----------------------------------------------------------------------===// OpFoldResult AtenDimOp::fold(ArrayRef operands) { diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index fca55f443..597762bbc 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -590,6 +590,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): emit("aten::neg.float : (float) -> (float)") emit("aten::lt.float_int : (float, int) -> (bool)") emit("aten::__and__.bool : (bool, bool) -> (bool)") + emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__not__ : (bool) -> (bool)", has_folder=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 63c87418e..31d61f223 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -32,6 +32,36 @@ func @torch.aten.__isnot__$none_isnot_none(%arg0: !torch.none, %arg1: !torch.non return %0 : !torch.bool } +// CHECK-LABEL: func @torch.aten.ne.bool() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func @torch.aten.ne.bool() -> !torch.bool { + %a = torch.constant.bool true + %b = torch.constant.bool false + %0 = torch.aten.ne.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ne.bool$same_operand( +// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func @torch.aten.ne.bool$same_operand(%arg0: !torch.bool) -> !torch.bool { + %0 = torch.aten.ne.bool %arg0, %arg0: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func @torch.aten.ne.bool$different_operand( +// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[RET:.*]] = torch.aten.ne.bool %[[ARG0]], %[[FALSE]] : !torch.bool, !torch.bool -> !torch.bool +// CHECK: return %[[RET]] : !torch.bool +func @torch.aten.ne.bool$different_operand(%a: !torch.bool) -> !torch.bool { + %b = torch.constant.bool false + %0 = torch.aten.ne.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func @torch.aten.size$canonicalize_to_list( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.list { // CHECK: %[[C2:.*]] = torch.constant.int 2