From 0a913bc904c2752c4357dcc1670d315339bfae8b Mon Sep 17 00:00:00 2001 From: Vidush Singhal <54336227+vid-999@users.noreply.github.com> Date: Wed, 1 Jun 2022 21:20:25 -0400 Subject: [PATCH] Add E2E support for AtenAllBoolOp (#874) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 23 +++++++++++ lib/Conversion/TorchToStd/TorchToStd.cpp | 34 +++++++++++---- .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise_comparison.py | 41 ++++++++++++++++++- 4 files changed, 91 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 79dcf8fef..c184a9243 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4146,6 +4146,29 @@ def Torch_AtenAllOp : Torch_Op<"aten.all", [ }]; } +def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::all.bool : (bool[]) -> (bool)`"; + let arguments = (ins + AnyTorchListOfTorchBoolType:$self + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAllBoolOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAllBoolOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAnyOp : Torch_Op<"aten.any", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp index 5ec956341..60627a167 100644 --- a/lib/Conversion/TorchToStd/TorchToStd.cpp +++ b/lib/Conversion/TorchToStd/TorchToStd.cpp @@ -196,11 +196,14 @@ public: } // namespace namespace { -class ConvertAtenAnyBoolOp : public OpConversionPattern { +template +class ConvertAtenAnyOrAllBoolOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpTy::Adaptor; + virtual bool reductionFunction(ArrayRef inputArray) const = 0; LogicalResult - matchAndRewrite(AtenAnyBoolOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector inputListTorchBool; @@ -216,13 +219,29 @@ public: op, "only support constant bool input list elements"); inputListBool.push_back(cst); } - bool result = llvm::any_of( - inputListBool, [](bool inputListElem) { return inputListElem; }); + bool result = reductionFunction(inputListBool); + rewriter.replaceOpWithNewOp( op, rewriter.getBoolAttr(result)); return success(); } }; + +class ConvertAtenAnyOp : public ConvertAtenAnyOrAllBoolOp { + using ConvertAtenAnyOrAllBoolOp::ConvertAtenAnyOrAllBoolOp; + bool reductionFunction(ArrayRef inputArray) const override { + return llvm::any_of(inputArray, + [](bool inputListElem) { return inputListElem; }); + } +}; + +class ConvertAtenAllOp : public ConvertAtenAnyOrAllBoolOp { + using ConvertAtenAnyOrAllBoolOp::ConvertAtenAnyOrAllBoolOp; + bool reductionFunction(ArrayRef inputArray) const override { + return llvm::all_of(inputArray, + [](bool inputListElem) { return inputListElem; }); + } +}; } // namespace namespace { @@ -340,8 +359,9 @@ public: target.addIllegalOp(); patterns.add>( typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add< ConvertAtenBoolLikeOp (Tensor)") emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)") + emit("aten::all.bool : (bool[]) -> (bool)") emit("aten::any : (Tensor) -> (Tensor)") emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)") emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 30732849e..3e7f8a79a 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -496,7 +496,6 @@ class ElementwiseNeIntScalarModule(torch.nn.Module): def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils): module.forward(torch.randint(2, 4, (8, 5))) - # ============================================================================== class AnyBoolTrueModule(torch.nn.Module): @@ -533,3 +532,43 @@ class AnyBoolFalseModule(torch.nn.Module): @register_test_case(module_factory=lambda: AnyBoolFalseModule()) def AnyBoolFalseModule_basic(module, tu: TestUtils): module.forward() + +# ================================================================================= + +class AllBoolTrueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + input = [True, True, True, True, True] + return torch.ops.aten.all(input) + + +@register_test_case(module_factory=lambda: AllBoolTrueModule()) +def AllBoolTrueModule_basic(module, tu: TestUtils): + module.forward() + +# ================================================================================= + +class AllBoolFalseModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + input = [True, False, True, True, False] + return torch.ops.aten.all(input) + +@register_test_case(module_factory=lambda: AllBoolFalseModule()) +def AllBoolFalseModule_basic(module, tu: TestUtils): + module.forward() +