Add E2E support for AtenAllBoolOp (#874)

Vidush Singhal 2022-06-01 21:20:25 -04:00 committed by GitHub
parent 7fdc1cff02
commit 0a913bc904
4 changed files with 91 additions and 8 deletions

[File 1 of 4: generated Torch op definitions (TableGen)]

@@ -4146,6 +4146,29 @@ def Torch_AtenAllOp : Torch_Op<"aten.all", [
   }];
 }
 
+def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::all.bool : (bool[]) -> (bool)`";
+  let arguments = (ins
+    AnyTorchListOfTorchBoolType:$self
+  );
+  let results = (outs
+    Torch_BoolType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAllBoolOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenAllBoolOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenAnyOp : Torch_Op<"aten.any", [
     AllowsTypeRefinement,
     HasValueSemantics,
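As context for readers, here is a sketch of how the new op should print in the Torch dialect, assuming the default assembly format that `parseDefaultTorchOp`/`printDefaultTorchOp` implement for a one-operand, one-result op (the printed form itself is not part of this diff):

    // Hypothetical round-trip of the generated op; %self is a !torch.list<bool>.
    %all = torch.aten.all.bool %self : !torch.list<bool> -> !torch.bool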

[File 2 of 4: Torch-to-arith scalar conversion patterns (C++)]

@@ -196,11 +196,14 @@ public:
 } // namespace
 
 namespace {
-class ConvertAtenAnyBoolOp : public OpConversionPattern<AtenAnyBoolOp> {
+template <typename OpTy>
+class ConvertAtenAnyOrAllBoolOp : public OpConversionPattern<OpTy> {
 public:
-  using OpConversionPattern::OpConversionPattern;
+  using OpConversionPattern<OpTy>::OpConversionPattern;
+  using OpAdaptor = typename OpTy::Adaptor;
+  virtual bool reductionFunction(ArrayRef<bool> inputArray) const = 0;
 
   LogicalResult
-  matchAndRewrite(AtenAnyBoolOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value> inputListTorchBool;
@@ -216,13 +219,29 @@ public:
             op, "only support constant bool input list elements");
       inputListBool.push_back(cst);
     }
-    bool result = llvm::any_of(
-        inputListBool, [](bool inputListElem) { return inputListElem; });
+    bool result = reductionFunction(inputListBool);
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(
         op, rewriter.getBoolAttr(result));
     return success();
   }
 };
+
+class ConvertAtenAnyOp : public ConvertAtenAnyOrAllBoolOp<AtenAnyBoolOp> {
+  using ConvertAtenAnyOrAllBoolOp<AtenAnyBoolOp>::ConvertAtenAnyOrAllBoolOp;
+  bool reductionFunction(ArrayRef<bool> inputArray) const override {
+    return llvm::any_of(inputArray,
+                        [](bool inputListElem) { return inputListElem; });
+  }
+};
+
+class ConvertAtenAllOp : public ConvertAtenAnyOrAllBoolOp<AtenAllBoolOp> {
+  using ConvertAtenAnyOrAllBoolOp<AtenAllBoolOp>::ConvertAtenAnyOrAllBoolOp;
+  bool reductionFunction(ArrayRef<bool> inputArray) const override {
+    return llvm::all_of(inputArray,
+                        [](bool inputListElem) { return inputListElem; });
+  }
+};
 } // namespace
 
 namespace {
@@ -340,8 +359,9 @@ public:
     target.addIllegalOp<AtenSqrtIntOp>();
     patterns.add<ConvertAtenUnaryOpToFloatMathOp<AtenSqrtIntOp, math::SqrtOp>>(
         typeConverter, context);
-    target.addIllegalOp<AtenAnyBoolOp>();
-    patterns.add<ConvertAtenAnyBoolOp>(typeConverter, context);
+    target.addIllegalOp<AtenAnyBoolOp, AtenAllBoolOp>();
+    patterns.add<ConvertAtenAnyOp>(typeConverter, context);
+    patterns.add<ConvertAtenAllOp>(typeConverter, context);
     target.addIllegalOp<AtenBoolFloatOp, AtenBoolIntOp>();
     patterns.add<
         ConvertAtenBoolLikeOp<AtenBoolFloatOp, arith::CmpFOp,
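To make the new pattern's effect concrete, here is a hedged before/after sketch of the conversion, assuming the usual torch-mlir spellings for bool constants and list construction. Note the pattern only fires when every list element is a compile-time `torch.constant.bool`; otherwise `matchAndRewrite` fails with the "only support constant bool input list elements" diagnostic above.

    // Before the conversion pass runs (sketch, not taken from the diff):
    %true = torch.constant.bool true
    %false = torch.constant.bool false
    %list = torch.prim.ListConstruct %true, %false : (!torch.bool, !torch.bool) -> !torch.list<bool>
    %all = torch.aten.all.bool %list : !torch.list<bool> -> !torch.bool

    // After (sketch): ConvertAtenAllOp folds the whole reduction to an i1 constant.
    %all = arith.constant false

The same skeleton handles `aten.any.bool`: the only difference between the two subclasses is whether `reductionFunction` applies `llvm::all_of` or `llvm::any_of` to the collected constants.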

[File 3 of 4: ODS op emitter registry (Python)]

@@ -387,6 +387,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
     emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
     emit("aten::all : (Tensor) -> (Tensor)")
+    emit("aten::all.bool : (bool[]) -> (bool)")
     emit("aten::any : (Tensor) -> (Tensor)")
     emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
     emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")

[File 4 of 4: e2e test modules (Python)]

@@ -496,7 +496,6 @@ class ElementwiseNeIntScalarModule(torch.nn.Module):
 def ElementwiseNeIntScalarModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(2, 4, (8, 5)))
 
-
 # ==============================================================================
 
 class AnyBoolTrueModule(torch.nn.Module):
@@ -533,3 +532,43 @@ class AnyBoolFalseModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: AnyBoolFalseModule())
 def AnyBoolFalseModule_basic(module, tu: TestUtils):
     module.forward()
+
+# =================================================================================
+
+class AllBoolTrueModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        input = [True, True, True, True, True]
+        return torch.ops.aten.all(input)
+
+
+@register_test_case(module_factory=lambda: AllBoolTrueModule())
+def AllBoolTrueModule_basic(module, tu: TestUtils):
+    module.forward()
+
+# =================================================================================
+
+class AllBoolFalseModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        input = [True, False, True, True, False]
+        return torch.ops.aten.all(input)
+
+
+@register_test_case(module_factory=lambda: AllBoolFalseModule())
+def AllBoolFalseModule_basic(module, tu: TestUtils):
+    module.forward()
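Note that `input` in these tests is a plain Python list of bools, so `torch.ops.aten.all(input)` dispatches to the `aten::all.bool : (bool[]) -> (bool)` overload rather than the tensor variant; that is what routes the tests through the new op. When imported, `AllBoolFalseModule.forward` should produce IR roughly like this sketch (not taken from the diff):

    %true = torch.constant.bool true
    %false = torch.constant.bool false
    %list = torch.prim.ListConstruct %true, %false, %true, %true, %false
        : (!torch.bool, !torch.bool, !torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
    %result = torch.aten.all.bool %list : !torch.list<bool> -> !torch.bool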