From 24f9de7120bcd95de3d341f325a46d19a6a2dc4c Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 13 Apr 2022 12:28:27 -0500 Subject: [PATCH] Fixes https://github.com/llvm/torch-mlir/issues/751 where `torch.bool` is parsed as signless `i1`. (#752) --- e2e_testing/torchscript/xfail_sets.py | 1 + lib/Dialect/Torch/Utils/Utils.cpp | 2 +- .../torch_mlir_e2e_test/test_suite/basic.py | 22 +++++++++++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 6765872ed..2da196052 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -54,6 +54,7 @@ TOSA_PASS_SET = { "BoolTensorReturnFalseModule_basic", "BoolTensorReturnTrueModule_basic", "BoolTensorReturnMixedModule_basic", + "BoolTensorHandleSignless_basic", "ElementwiseRsqrtModule_basic", "SqueezeModule_static", "SqueezeModule_noUnitDim", diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index e818e895f..77fdf5a6c 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -50,7 +50,7 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::Long; if (type.isSignedInteger(32)) return torch_upstream::ScalarType::Int; - if (type.isUnsignedInteger(1)) + if (type.isSignlessInteger(1)) return torch_upstream::ScalarType::Bool; llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 5ba7f2780..6b7487235 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -1131,6 +1131,28 @@ def BoolTensorReturnMixedModule_basic(module, tu: TestUtils): # ============================================================================== +class BoolTensorHandleSignless(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.bool, True), + ([-1, -1], torch.bool, True), + ]) + def forward(self, a, b): + return a * b + + +@register_test_case(module_factory=lambda: BoolTensorHandleSignless()) +def BoolTensorHandleSignless_basic(module, tu: TestUtils): + a = torch.tensor([[1, 1], [1, 1]], dtype=torch.bool) + b = torch.tensor([[0, 0], [0, 0]], dtype=torch.bool) + module.forward(a, b) + +# ============================================================================== + class TModuleRank2(torch.nn.Module): def __init__(self): super().__init__()