[MLIR][TORCH] Add support for bool element type for aten.sum[.dim_IntList] op

This commit adds bool element type support for the `aten.sum` and
`aten.sum.dim_IntList` ops.
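
For context (not an excerpt from the patch): eager PyTorch already promotes a `torch.bool` input to `torch.int64` when summing, which is the behavior the type refinement below mirrors. A minimal illustration:

import torch

a = torch.tensor([[True, False], [True, True]])
print(torch.sum(a))        # tensor(3) -- counts the True elements
print(torch.sum(a).dtype)  # torch.int64, not torch.bool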

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/1380/head
Vivek Khandelwal 2022-09-14 19:42:48 +05:30
parent 1895b581c4
commit 04f3a4ffce
2 changed files with 46 additions and 0 deletions


@@ -930,6 +930,10 @@ void TypeAnalysis::visitOperation(Operation *op,
  if (auto sum = dyn_cast<AtenSumOp>(op)) {
    Type defaultDtype = operands[0]->getValue().dtype;
    // If the input dtype is bool, the result type should be i64.
    if (defaultDtype.isInteger(1))
      defaultDtype =
          IntegerType::get(op->getContext(), 64, IntegerType::Signed);
    Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
@@ -939,6 +943,10 @@ void TypeAnalysis::visitOperation(Operation *op,
  }
  if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
    Type defaultDtype = operands[0]->getValue().dtype;
    // If the input dtype is bool, the result type should be i64.
    if (defaultDtype.isInteger(1))
      defaultDtype =
          IntegerType::get(op->getContext(), 64, IntegerType::Signed);
    Type dtype = getDtypeOrDefault(sumDimIntList.getContext(),
                                   sumDimIntList.dtype(), defaultDtype);
    visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
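
Note that `getDtypeOrDefault` falls back to the i64 default only when the op carries no explicit `dtype`; an explicit `dtype` argument still takes precedence. A hedged sketch of the intended semantics in eager PyTorch terms (illustrative only, not code from this patch):

import torch

a = torch.zeros(2, 3, dtype=torch.bool)
print(torch.sum(a).dtype)                       # torch.int64 (bool promoted by default)
print(torch.sum(a, dtype=torch.float32).dtype)  # torch.float32 (explicit dtype wins)
print(torch.sum(a, dim=[0]).dtype)              # torch.int64 (same default for the dim_IntList variant)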


@@ -49,6 +49,25 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
# ==============================================================================

class ReduceSumElementTypeBoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.bool, True),
    ])
    def forward(self, a):
        return torch.sum(a)

@register_test_case(module_factory=lambda: ReduceSumElementTypeBoolModule())
def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool))

# ==============================================================================

class ReduceSumDimIntListFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -125,6 +144,25 @@ def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils):
# ==============================================================================

class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.bool, True),
    ])
    def forward(self, a):
        return torch.sum(a, dim=(-1), keepdim=False)

@register_test_case(module_factory=lambda: ReduceSumDimIntListElementTypeBoolModule())
def ReduceSumDimIntListElementTypeBoolModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(1, 128, high=2).to(dtype=torch.bool))

# ==============================================================================

class ReduceSumUnsignedIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
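
For reference, the e2e harness compares these modules against eager PyTorch, where the new cases reduce to counting True elements. A rough eager-mode equivalent of what the two new tests exercise (shapes taken from the test inputs above; not part of the patch):

import torch

full = torch.randint(0, 2, (3, 4, 5)).to(torch.bool)
rows = torch.randint(0, 2, (1, 128)).to(torch.bool)
print(torch.sum(full))                         # scalar int64 count of True elements
print(torch.sum(rows, dim=-1, keepdim=False))  # per-row counts, shape [1], dtype int64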