mirror of https://github.com/llvm/torch-mlir
[linalg] Fix torch.aten.add of `torch.bool` (#3820)
Addition of bools saturates, which equates to an `or` operator. Updated to avoid some noticed downstream failures.
parent
9c1e3b8154
commit
5aa323dd29
|
@ -827,6 +827,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
if (isa<mlir::FloatType>(dtype)) {
|
||||
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||
return b.create<arith::AddFOp>(loc, lhs, scaled);
|
||||
} else if (dtype.isInteger(1)) {
|
||||
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
|
||||
return b.create<arith::OrIOp>(loc, lhs, scaled);
|
||||
} else {
|
||||
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
|
||||
return b.create<arith::AddIOp>(loc, lhs, scaled);
|
||||
|
|
|
@ -685,6 +685,35 @@ def ElementwiseAddModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
# Addition is a special case among the binary ops: under the hood it takes a
# hidden scalar "alpha" multiplier, which the lowering must handle. On bool
# inputs, addition saturates, i.e. it behaves like logical `or`.
class ElementwiseAddBoolModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([4], torch.bool, True),
            ([4], torch.bool, True),
        ]
    )
    def forward(self, a, b):
        # Equivalent to `a + b`; both lower to aten.add.Tensor.
        return torch.add(a, b)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAddBoolModule())
def ElementwiseAddBoolModule_basic(module, tu: TestUtils):
    # Fixed inputs (not tu-randomized) so the bool saturation cases are
    # covered deterministically: F+F, F+T, T+F, T+F.
    lhs = torch.tensor([False, False, True, True])
    rhs = torch.tensor([False, True, False, False])
    module.forward(lhs, rhs)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue