Fixes https://github.com/llvm/torch-mlir/issues/751 where `torch.bool` is parsed as signless `i1`. (#752)

pull/756/head snapshot-20220413.387
Maksim Levental 2022-04-13 12:28:27 -05:00 committed by GitHub
parent d46f169c1a
commit 24f9de7120
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 1 deletions

View File

@ -54,6 +54,7 @@ TOSA_PASS_SET = {
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnTrueModule_basic",
"BoolTensorReturnMixedModule_basic",
"BoolTensorHandleSignless_basic",
"ElementwiseRsqrtModule_basic",
"SqueezeModule_static",
"SqueezeModule_noUnitDim",

View File

@ -50,7 +50,7 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
return torch_upstream::ScalarType::Long;
if (type.isSignedInteger(32))
return torch_upstream::ScalarType::Int;
if (type.isUnsignedInteger(1))
if (type.isSignlessInteger(1))
return torch_upstream::ScalarType::Bool;
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}

View File

@ -1131,6 +1131,28 @@ def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
# ==============================================================================
class BoolTensorHandleSignless(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.bool, True),
([-1, -1], torch.bool, True),
])
def forward(self, a, b):
return a * b
@register_test_case(module_factory=lambda: BoolTensorHandleSignless())
def BoolTensorHandleSignless_basic(module, tu: TestUtils):
a = torch.tensor([[1, 1], [1, 1]], dtype=torch.bool)
b = torch.tensor([[0, 0], [0, 0]], dtype=torch.bool)
module.forward(a, b)
# ==============================================================================
class TModuleRank2(torch.nn.Module):
def __init__(self):
super().__init__()