mirror of https://github.com/llvm/torch-mlir
Fixes https://github.com/llvm/torch-mlir/issues/751 where `torch.bool` is parsed as signless `i1`. (#752)
parent
d46f169c1a
commit
24f9de7120
|
@ -54,6 +54,7 @@ TOSA_PASS_SET = {
|
|||
"BoolTensorReturnFalseModule_basic",
|
||||
"BoolTensorReturnTrueModule_basic",
|
||||
"BoolTensorReturnMixedModule_basic",
|
||||
"BoolTensorHandleSignless_basic",
|
||||
"ElementwiseRsqrtModule_basic",
|
||||
"SqueezeModule_static",
|
||||
"SqueezeModule_noUnitDim",
|
||||
|
|
|
@ -50,7 +50,7 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
|||
return torch_upstream::ScalarType::Long;
|
||||
if (type.isSignedInteger(32))
|
||||
return torch_upstream::ScalarType::Int;
|
||||
if (type.isUnsignedInteger(1))
|
||||
if (type.isSignlessInteger(1))
|
||||
return torch_upstream::ScalarType::Bool;
|
||||
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
||||
}
|
||||
|
|
|
@ -1131,6 +1131,28 @@ def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class BoolTensorHandleSignless(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.bool, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return a * b
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BoolTensorHandleSignless())
|
||||
def BoolTensorHandleSignless_basic(module, tu: TestUtils):
|
||||
a = torch.tensor([[1, 1], [1, 1]], dtype=torch.bool)
|
||||
b = torch.tensor([[0, 0], [0, 0]], dtype=torch.bool)
|
||||
module.forward(a, b)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class TModuleRank2(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue