From 46a2189a417100a17659a096a18560a89a34255e Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal <vivek@nod-labs.com>
Date: Tue, 23 Nov 2021 21:54:47 +0530
Subject: [PATCH] [MLIR][TORCH] Add E2E support for aten.bitwise_and.tensor op

This commit adds lowering of `aten.bitwise_and.tensor` op.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
---
 e2e_testing/torchscript/elementwise.py             | 21 +++++++++
 .../Dialect/Torch/IR/GeneratedAtenOps.td           | 15 +++++++
 .../TorchToLinalg/TorchToLinalg.cpp                | 43 +++++++++++++------
 lib/Dialect/Torch/Transforms/RefineTypes.cpp       |  2 +-
 .../jit_ir/build_tools/torch_ods_gen.py            |  2 +
 5 files changed, 69 insertions(+), 14 deletions(-)

diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index 8cc8d0eae..d1872dd03 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -552,3 +552,24 @@ class ElementwiseDivScalarModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwiseDivScalarModule())
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+class ElementwiseAndIntegerModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+        ([-1, -1], torch.int64, True),
+    ])
+
+    def forward(self, x, y):
+        return torch.bitwise_and(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAndIntegerModule())
+def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
+    module.forward(torch.randint(-10, 10, (3, 4)).to(torch.int32),
+                   torch.randint(-10, 10, (3, 4)))
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index d24249ab6..c56091352 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -1494,6 +1494,21 @@ def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
   let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
 }
 
+def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $other attr-dict `:` type($self) `,` type($other) `->` type($result)";
+}
+
 def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
     AllowsTypeRefinement
   ]> {
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 4d6682d58..39e216dd1 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1394,6 +1394,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return b.create(loc, payloadArgs[0]);
   if (isa(op))
     return b.create(loc, payloadArgs[0]);
+  if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
+    if (bitwiseAndTensor.getType()
+            .cast<ValueTensorType>()
+            .getDtype()
+            .isa<mlir::FloatType>()) {
+      bitwiseAndTensor.emitError(
+          "Bitwise_And does not support floating point dtype");
+      return nullptr;
+    }
+    Type dtype = converter->convertType(bitwiseAndTensor.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    return b.create<arith::AndIOp>(loc, lhs, rhs);
+  }
   if (isa(op))
     return b.create(loc, payloadArgs[0]);
   if (isa(op))
@@ -1895,13 +1911,14 @@ struct ConvertElementwiseOp : ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa(op))
+    if (!isa(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -3193,12 +3210,12 @@ public:
     target.addIllegalOp();
     patterns.add(typeConverter, context);
     target.addIllegalOp<
-        AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
-        AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp,
-        AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
-        AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
-        AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
-        AtenReciprocalOp>();
+        AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
+        AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
+        AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
+        AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenSqrtOp,
+        AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
+        AtenReciprocalOp, AtenBitwiseAndTensorOp>();
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index c91ab9f8f..03f84c85e 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -304,7 +304,7 @@ public:
       return visitBinaryTensorScalarOp(op, operands);
     } else if (isa(op)) {
-                   AtenMinimumOp, AtenMaximumOp>(op)) {
+                   AtenMinimumOp, AtenMaximumOp, AtenBitwiseAndTensorOp>(op)) {
       return visitBinaryBroadcastingOp(op, operands);
     } else if (auto lerpTensor = llvm::dyn_cast<AtenLerpTensorOp>(op)) {
       return visitAtenLerpTensorOp(lerpTensor, operands);
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 7f6b6686d..419be419c 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -470,6 +470,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
         "aten::rsqrt : (Tensor) -> (Tensor)",
         "aten::abs : (Tensor) -> (Tensor)",
         "aten::reciprocal : (Tensor) -> (Tensor)",
+        "aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
+
     ]:
         emit_with_mutating_variants(key)
     # Elementwise tensor compute ops that don't have the standard mutating
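As a quick illustration of the behavior the new ElementwiseAndIntegerModule test exercises: torch.bitwise_and on mixed int32/int64 operands promotes to the wider integer type, which is why the lowering converts both payload arguments with convertScalarToDtype before creating the integer `and`, and rejects floating-point dtypes outright. A minimal plain-PyTorch sketch of that contract (shapes and values are illustrative only, no torch-mlir involved):

    import torch

    # Mixed-width integer operands, mirroring the annotate_args of the new
    # e2e test: ([-1, -1], torch.int32, True) and ([-1, -1], torch.int64, True).
    x = torch.randint(-10, 10, (3, 4), dtype=torch.int32)
    y = torch.randint(-10, 10, (3, 4), dtype=torch.int64)

    z = torch.bitwise_and(x, y)
    # PyTorch promotes the result to the wider integer dtype; the linalg
    # lowering mirrors this by casting both operands to the converted result
    # element type before emitting the bitwise `and`.
    assert z.dtype == torch.int64

    # Floating-point inputs are rejected, matching the
    # "Bitwise_And does not support floating point dtype" guard in the lowering.
    try:
        torch.bitwise_and(torch.rand(3, 4), torch.rand(3, 4))
    except RuntimeError:
        pass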