From b411a40b3d674be82f91b652c1939249a633f36f Mon Sep 17 00:00:00 2001 From: Jiawei Wu Date: Wed, 6 Sep 2023 14:21:51 +0800 Subject: [PATCH] [Torch Dialect] emit aten.__or__Tensor Op (#2437) * emit aten.__or__TensorOp * bug fix * remove convert to stablehlo * code style refinement --- e2e_testing/xfail_sets.py | 3 ++ .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ++++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 13 +++++ .../Transforms/AbstractInterpLibrary.cpp | 12 +++++ .../build_tools/abstract_interp_lib_gen.py | 11 ++++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 50 +++++++++++++++++++ 7 files changed, 115 insertions(+) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 3f6703684..a860d1be7 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -452,6 +452,7 @@ STABLEHLO_PASS_SET = { "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseBitwiseOrStaticShapeModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", "ElementwiseClampModule_basic", @@ -985,6 +986,8 @@ TOSA_PASS_SET = { "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt32Module_basic", "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseOrTensorModule_basic", "ElementwiseBitwiseOrModule_basic", "ElementwiseBitwiseOrStaticShapeModule_basic", "ElementwiseBitwiseXorModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 9da6a227c..97403ab0e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6141,6 +6141,31 @@ def Torch_Aten__And__TensorOp : Torch_Op<"aten.__and__.Tensor", [ }]; } +def Torch_Aten__Or__TensorOp : Torch_Op<"aten.__or__.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__Or__TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Or__TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasCanonicalizer = 1; +} + def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e1dcf4e46..a717e304e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1124,6 +1124,19 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// Aten__Or__TensorOp +//===----------------------------------------------------------------------===// + +void Aten__Or__TensorOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](Aten__Or__TensorOp op, PatternRewriter &rewriter) { + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), op.getOther()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenScalarImplicitOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 4ed92b85b..8a39364cf 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7376,6 +7376,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.__or__.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.minimum\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9154,6 +9158,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.__or__.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 6c1d59f25..f0a77d9ad 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -771,6 +771,9 @@ def aten〇atan2〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇__and__〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇__or__〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇minimum〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -2223,6 +2226,14 @@ def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇__or__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function(_check_two_tensor_op()) def aten〇add〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int: other_rank, other_dtype = other_rank_dtype diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index f44175ed0..ef8166ace 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -456,6 +456,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)") emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)") emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") + emit("aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)") emit("aten::mean : (Tensor, int?) -> (Tensor)") emit("aten::std : (Tensor, bool) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index a1ddaaeaf..f81aae63a 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2024,6 +2024,56 @@ def ElementwiseBitwiseOrStaticShapeModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseOrTensorModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.__or__(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseOrTensorModule()) +def ElementwiseOrTensorModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-10, high=10).to(torch.int32), + tu.randint(3, 4, low=-10, high=10)) + + +# ============================================================================== + + +class ElementwiseOrTensorStaticShapeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int32, True), + ([4], torch.int64, True), + ]) + def forward(self, x, y): + return torch.ops.aten.__or__(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseOrTensorStaticShapeModule()) +def ElementwiseOrTensorStaticShapeModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-10, high=10).to(torch.int32), + tu.randint(4, low=-10, high=10)) + + +# ============================================================================== + + class ElementwiseBitwiseXorModule(torch.nn.Module): def __init__(self):