mirror of https://github.com/llvm/torch-mlir
[linalg] Fix torch.aten.add of `torch.bool` (#3820)
Addition of bools saturate which equates to an `or` operator. Updated to avoid some noticed downstream failures.pull/3828/head
parent
9c1e3b8154
commit
5aa323dd29
|
@ -827,6 +827,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
if (isa<mlir::FloatType>(dtype)) {
|
if (isa<mlir::FloatType>(dtype)) {
|
||||||
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||||
return b.create<arith::AddFOp>(loc, lhs, scaled);
|
return b.create<arith::AddFOp>(loc, lhs, scaled);
|
||||||
|
} else if (dtype.isInteger(1)) {
|
||||||
|
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
|
||||||
|
return b.create<arith::OrIOp>(loc, lhs, scaled);
|
||||||
} else {
|
} else {
|
||||||
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
|
Value scaled = b.create<arith::MulIOp>(loc, rhs, alpha);
|
||||||
return b.create<arith::AddIOp>(loc, lhs, scaled);
|
return b.create<arith::AddIOp>(loc, lhs, scaled);
|
||||||
|
|
|
@ -685,6 +685,35 @@ def ElementwiseAddModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# Addition is an interesting special case of a binary op, because under the hood
|
||||||
|
# it carries a third scalar "alpha" parameter, which needs special handling.
|
||||||
|
class ElementwiseAddBoolModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([4], torch.bool, True),
|
||||||
|
([4], torch.bool, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAddBoolModule())
|
||||||
|
def ElementwiseAddBoolModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
torch.tensor([False, False, True, True]),
|
||||||
|
torch.tensor([False, True, False, False]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
|
class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue