mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for bool element type for aten.sum[.dim_IntList] op
This commit adds bool element type support for the `aten.sum` and `aten.sum.dim_IntList` ops.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

pull/1380/head
parent 1895b581c4
commit 04f3a4ffce
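For context, the change brings the shape/dtype analysis in line with eager PyTorch, where summing a `torch.bool` tensor yields an `int64` result unless an explicit `dtype` is requested. A quick illustration in plain PyTorch (tensor names are only for the example):

```python
import torch

# PyTorch promotes bool (and other integer) inputs to int64 when summing,
# unless an explicit dtype is requested.
a = torch.randint(0, 2, (3, 4, 5)).to(torch.bool)

print(torch.sum(a).dtype)                       # torch.int64  -> aten.sum
print(torch.sum(a, dim=-1).dtype)               # torch.int64  -> aten.sum.dim_IntList
print(torch.sum(a, dtype=torch.float32).dtype)  # torch.float32 (explicit dtype wins)
```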
@@ -930,6 +930,10 @@ void TypeAnalysis::visitOperation(Operation *op,

  if (auto sum = dyn_cast<AtenSumOp>(op)) {
    Type defaultDtype = operands[0]->getValue().dtype;
    // If the input dtype is bool, the result type should be i64.
    if (defaultDtype.isInteger(1))
      defaultDtype =
          IntegerType::get(op->getContext(), 64, IntegerType::Signed);
    Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
    auto knowledge =
        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
@@ -939,6 +943,10 @@ void TypeAnalysis::visitOperation(Operation *op,

  }
  if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
    Type defaultDtype = operands[0]->getValue().dtype;
    // If the input dtype is bool, the result type should be i64.
    if (defaultDtype.isInteger(1))
      defaultDtype =
          IntegerType::get(op->getContext(), 64, IntegerType::Signed);
    Type dtype = getDtypeOrDefault(sumDimIntList.getContext(),
                                   sumDimIntList.dtype(), defaultDtype);
    visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
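Both hunks encode the same dtype rule in the type analysis: an explicit `dtype` argument takes priority (that is what `getDtypeOrDefault` resolves), and otherwise a bool (`i1`) input defaults to signed 64-bit. Below is a minimal Python sketch of that rule, purely for illustration; the helper name `result_dtype_for_sum` is made up and is not part of torch-mlir:

```python
import torch
from typing import Optional

def result_dtype_for_sum(input_dtype: torch.dtype,
                         explicit_dtype: Optional[torch.dtype] = None) -> torch.dtype:
    """Hypothetical helper mirroring the rule in the two hunks above."""
    if explicit_dtype is not None:   # getDtypeOrDefault: the dtype operand takes priority
        return explicit_dtype
    if input_dtype == torch.bool:    # defaultDtype.isInteger(1) -> signed 64-bit
        return torch.int64
    return input_dtype               # otherwise keep the input element type

assert result_dtype_for_sum(torch.bool) == torch.int64
assert result_dtype_for_sum(torch.bool, torch.float32) == torch.float32
assert result_dtype_for_sum(torch.float32) == torch.float32
```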
@@ -49,6 +49,25 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):

# ==============================================================================

class ReduceSumElementTypeBoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.bool, True),
    ])
    def forward(self, a):
        return torch.sum(a)


@register_test_case(module_factory=lambda: ReduceSumElementTypeBoolModule())
def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool))

# ==============================================================================

class ReduceSumDimIntListFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -125,6 +144,25 @@ def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils):

# ==============================================================================

class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.bool, True),
    ])
    def forward(self, a):
        return torch.sum(a, dim=(-1), keepdim=False)


@register_test_case(module_factory=lambda: ReduceSumDimIntListElementTypeBoolModule())
def ReduceSumDimIntListElementTypeBoolModule_basic(module, tu: TestUtils):
    module.forward(tu.randint(1, 128, high=2).to(dtype=torch.bool))

# ==============================================================================

class ReduceSumUnsignedIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
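One note on the `dim_IntList` test: `dim=(-1)` is the plain integer `-1`, since parentheses alone do not make a tuple (that would be `(-1,)`). `torch.sum` accepts either an `int` or a sequence of ints for `dim`, and a bool input reduces to `int64` in both cases, as the sketch below checks in plain eager PyTorch, outside the e2e harness:

```python
import torch

a = torch.randint(0, 2, (1, 128)).to(torch.bool)

# dim=(-1) is the integer -1; dim=(-1,) is a one-element tuple.
# torch.sum accepts both, and both give an int64 result for a bool input.
out_int = torch.sum(a, dim=(-1), keepdim=False)     # same as dim=-1, shape (1,)
out_tuple = torch.sum(a, dim=(-1,), keepdim=False)  # sequence form, shape (1,)

assert out_int.dtype == out_tuple.dtype == torch.int64
assert torch.equal(out_int, out_tuple)
```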