mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for bool element type for aten.sum[.dim_IntList] op
This commit adds bool element type support for `aten.sum` and `aten.sum.dim_IntList` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1380/head
parent
1895b581c4
commit
04f3a4ffce
|
@ -930,6 +930,10 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
|
|
||||||
if (auto sum = dyn_cast<AtenSumOp>(op)) {
|
if (auto sum = dyn_cast<AtenSumOp>(op)) {
|
||||||
Type defaultDtype = operands[0]->getValue().dtype;
|
Type defaultDtype = operands[0]->getValue().dtype;
|
||||||
|
// If the input dtype is bool, the result type should be i64.
|
||||||
|
if (defaultDtype.isInteger(1))
|
||||||
|
defaultDtype =
|
||||||
|
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||||
Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
|
Type dtype = getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
|
@ -939,6 +943,10 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
}
|
}
|
||||||
if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
|
if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
|
||||||
Type defaultDtype = operands[0]->getValue().dtype;
|
Type defaultDtype = operands[0]->getValue().dtype;
|
||||||
|
// If the input dtype is bool, the result type should be i64.
|
||||||
|
if (defaultDtype.isInteger(1))
|
||||||
|
defaultDtype =
|
||||||
|
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||||
Type dtype = getDtypeOrDefault(sumDimIntList.getContext(),
|
Type dtype = getDtypeOrDefault(sumDimIntList.getContext(),
|
||||||
sumDimIntList.dtype(), defaultDtype);
|
sumDimIntList.dtype(), defaultDtype);
|
||||||
visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
|
visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
|
||||||
|
|
|
@ -49,6 +49,25 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceSumElementTypeBoolModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.bool, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.sum(a)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceSumElementTypeBoolModule())
|
||||||
|
def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(3, 4, 5, high=2).to(torch.bool))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ReduceSumDimIntListFloatModule(torch.nn.Module):
|
class ReduceSumDimIntListFloatModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -125,6 +144,25 @@ def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils):
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.bool, True),
|
||||||
|
])
|
||||||
|
def forward(self, a):
|
||||||
|
return torch.sum(a, dim=(-1), keepdim=False)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ReduceSumDimIntListElementTypeBoolModule())
|
||||||
|
def ReduceSumDimIntListElementTypeBoolModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(1, 128, high=2).to(dtype=torch.bool))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
class ReduceSumUnsignedIntModule(torch.nn.Module):
|
class ReduceSumUnsignedIntModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue