From 04f3a4ffce81495a2319a706c5914b152dd1aae1 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal <vivek@nod-labs.com> Date: Wed, 14 Sep 2022 19:42:48 +0530 Subject: [PATCH] [MLIR][TORCH] Add support for bool element type for aten.sum[.dim_IntList] op This commit adds bool element type support for `aten.sum` and `aten.sum.dim_IntList` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com> --- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 8 ++++ .../test_suite/reduction.py | 38 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 1a982f910..c04cd7681 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -930,6 +930,10 @@ void TypeAnalysis::visitOperation(Operation *op, if (auto sum = dyn_cast<AtenSumOp>(op)) { Type defaultDtype = operands[0]->getValue().dtype; + // If the input dtype is bool, the result type should be i64. + if (defaultDtype.isInteger(1)) + defaultDtype = + IntegerType::get(op->getContext(), 64, IntegerType::Signed); Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype); auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); @@ -939,6 +943,10 @@ void TypeAnalysis::visitOperation(Operation *op, if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) { Type defaultDtype = operands[0]->getValue().dtype; + // If the input dtype is bool, the result type should be i64. 
+ if (defaultDtype.isInteger(1)) + defaultDtype = + IntegerType::get(op->getContext(), 64, IntegerType::Signed); Type dtype = getDtypeOrDefault(sumDimIntList.getContext(), sumDimIntList.dtype(), defaultDtype); visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(), diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index a0ad28eac..90b18e6a2 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -49,6 +49,25 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceSumElementTypeBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.bool, True), + ]) + def forward(self, a): + return torch.sum(a) + + +@register_test_case(module_factory=lambda: ReduceSumElementTypeBoolModule()) +def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool)) + +# ============================================================================== + class ReduceSumDimIntListFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -125,6 +144,25 @@ def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.bool, True), + ]) + def forward(self, a): + return torch.sum(a, dim=(-1), keepdim=False) + + +@register_test_case(module_factory=lambda: ReduceSumDimIntListElementTypeBoolModule()) +def ReduceSumDimIntListElementTypeBoolModule_basic(module, tu: TestUtils): + module.forward(tu.randint(1, 128, high=2).to(dtype=torch.bool)) + +# 
============================================================================== + class ReduceSumUnsignedIntModule(torch.nn.Module): def __init__(self): super().__init__()