[linalg] Fix torch.aten.add of `torch.bool` (#3820)

Addition of bools saturates, which equates to an `or` operator. Updated to
avoid some noticed downstream failures.
pull/3828/head
Rob Suderman 2024-10-31 17:37:25 -07:00 committed by GitHub
parent 9c1e3b8154
commit 5aa323dd29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 0 deletions

View File

@ -827,6 +827,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (isa<mlir::FloatType>(dtype)) {
  Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
  return b.create<arith::AddFOp>(loc, lhs, scaled);
} else if (dtype.isInteger(1)) {
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
return b.create<arith::OrIOp>(loc, lhs, scaled);
} else {
  Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
  return b.create<arith::AddIOp>(loc, lhs, scaled);

View File

@ -685,6 +685,35 @@ def ElementwiseAddModule_basic(module, tu: TestUtils):
# ==============================================================================
# Addition is an interesting special case of a binary op, because under the hood
# it carries a third scalar "alpha" parameter, which needs special handling.
class ElementwiseAddBoolModule(torch.nn.Module):
    """Elementwise addition of two boolean tensors.

    Boolean addition saturates in PyTorch, so `a + b` on bools is
    equivalent to a logical OR — this module exercises that lowering.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4], torch.bool, True),
            ([4], torch.bool, True),
        ]
    )
    def forward(self, a, b):
        # Saturating bool add: behaves as logical OR of the two inputs.
        return a + b
@register_test_case(module_factory=lambda: ElementwiseAddBoolModule())
def ElementwiseAddBoolModule_basic(module, tu: TestUtils):
    """Run the bool-add module on fixed inputs covering the truth table."""
    lhs = torch.tensor([False, False, True, True])
    rhs = torch.tensor([False, True, False, False])
    module.forward(lhs, rhs)
# ==============================================================================
class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
    def __init__(self):
        super().__init__()