mirror of https://github.com/llvm/torch-mlir
Add lowering support for math::AbsIOp (#2875)
There is no lowering support for math::AbsIOp, so if the operand is an integer type, lowering fails: math::AbsFOp requires that operand #0 be floating-point-like. (branch: pull/2895/head)
parent
44f8f89826
commit
9659a436d1
|
@ -424,8 +424,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
|
||||
return createEqual(b, loc, floatDtype, self, zero);
|
||||
}
|
||||
if (isa<AtenAbsOp>(op))
|
||||
if (isa<AtenAbsOp>(op)) {
|
||||
if (payloadArgs[0].getType().isa<IntegerType>())
|
||||
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
|
||||
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
|
||||
}
|
||||
if (isa<AtenIsinfOp>(op)) {
|
||||
Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
|
||||
Value infinity = b.create<arith::ConstantOp>(
|
||||
|
|
|
@ -579,7 +579,8 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseSubScalarFloatModule_basic",
|
||||
"ElementwiseSubScalarIntModule_basic",
|
||||
"ElementwiseWhereScalarModule_basic",
|
||||
"ElementwiseAbsModule_basic",
|
||||
"ElementwiseAbsFloatModule_basic",
|
||||
"ElementwiseAbsIntModule_basic",
|
||||
"EmbeddingModule1DIndices_basic",
|
||||
"EmbeddingModuleI32Static_basic",
|
||||
"EmbeddingModuleI32_basic",
|
||||
|
@ -1060,7 +1061,8 @@ TOSA_PASS_SET = {
|
|||
"EinsumStaticContractRhsModule_basic",
|
||||
"EinsumStaticFourDimensionModule_basic",
|
||||
"EinsumStaticModule_basic",
|
||||
"ElementwiseAbsModule_basic",
|
||||
"ElementwiseAbsFloatModule_basic",
|
||||
"ElementwiseAbsIntModule_basic",
|
||||
"ElementwiseAddModule_basic",
|
||||
"ElementwiseAddScalarFloatModule_basic",
|
||||
"ElementwiseAddScalarInt64Module_basic",
|
||||
|
|
|
@ -2113,7 +2113,7 @@ def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAbsModule(torch.nn.Module):
|
||||
class ElementwiseAbsFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -2127,9 +2127,31 @@ class ElementwiseAbsModule(torch.nn.Module):
|
|||
return torch.abs(a)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAbsModule())
def ElementwiseAbsModule_basic(module, tu: TestUtils):
    """Run abs on a random float tensor spanning both signs."""
    # Draw from [-1.0, 1.0] so negative, near-zero, and positive values occur.
    inputs = tu.rand(3, 4, 5, low=-1.0, high=1.0)
    module.forward(inputs)
|
||||
@register_test_case(module_factory=lambda: ElementwiseAbsFloatModule())
def ElementwiseAbsFloatModule_basic(module, tu: TestUtils):
    """Run abs on a fixed float tensor covering negative, zero, and positive."""
    sample = torch.tensor([[[-1.0, 0.0, 1.0]]])
    module.forward(sample)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAbsIntModule(torch.nn.Module):
    """Elementwise absolute value on an int64 tensor.

    Exercises the integer lowering path (math::AbsIOp) rather than the
    float path (math::AbsFOp).
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.int64, True),
    ])
    def forward(self, a):
        # Method form of torch.abs(a); traces to the same aten::abs op.
        return a.abs()
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAbsIntModule())
def ElementwiseAbsIntModule_basic(module, tu: TestUtils):
    """Run abs on a fixed int tensor covering negative, zero, and positive."""
    data = torch.tensor([[[-1, 0, 1]]])
    module.forward(data)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
Loading…
Reference in New Issue