From 5aa323dd29083ef90b3956e50a6839635e7c1181 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 31 Oct 2024 17:37:25 -0700 Subject: [PATCH] [linalg] Fix torch.aten.add of `torch.bool` (#3820) Addition of bools saturate which equates to an `or` operator. Updated to avoid some noticed downstream failures. --- .../TorchToLinalg/Uncategorized.cpp | 3 ++ .../test_suite/elementwise.py | 29 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0f6f92bd7..c129c9614 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -827,6 +827,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (isa(dtype)) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); + } else if (dtype.isInteger(1)) { + Value scaled = b.create(loc, rhs, alpha); + return b.create(loc, lhs, scaled); } else { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index e9098698f..88a269a09 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -685,6 +685,35 @@ def ElementwiseAddModule_basic(module, tu: TestUtils): # ============================================================================== +# Addition is an interesting special case of a binary op, because under the hood +# it carries a third scalar "alpha" parameter, which needs special handling. +class ElementwiseAddBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4], torch.bool, True), + ([4], torch.bool, True), + ] + ) + def forward(self, a, b): + return a + b + + +@register_test_case(module_factory=lambda: ElementwiseAddBoolModule()) +def ElementwiseAddBoolModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor([False, False, True, True]), + torch.tensor([False, True, False, False]), + ) + + +# ============================================================================== + + class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module): def __init__(self): super().__init__()