Add lowering support for math::AbsIOp (#2875)

There is no lowering support for math::AbsIOp, so if the operand is an
integer type, it will fail to lower to math::AbsFOp since the op operand
#0 must be floating-point-like.
pull/2895/head
Avinash Sharma 2024-02-08 14:53:40 -08:00 committed by GitHub
parent 44f8f89826
commit 9659a436d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 34 additions and 7 deletions

View File

@ -424,8 +424,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
return createEqual(b, loc, floatDtype, self, zero);
}
if (isa<AtenAbsOp>(op))
if (isa<AtenAbsOp>(op)) {
if (payloadArgs[0].getType().isa<IntegerType>())
return b.create<math::AbsIOp>(loc, payloadArgs[0]);
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
}
if (isa<AtenIsinfOp>(op)) {
Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
Value infinity = b.create<arith::ConstantOp>(

View File

@ -579,7 +579,8 @@ STABLEHLO_PASS_SET = {
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseAbsModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleI32Static_basic",
"EmbeddingModuleI32_basic",
@ -1060,7 +1061,8 @@ TOSA_PASS_SET = {
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"ElementwiseAbsModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"ElementwiseAddModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",

View File

@ -2113,7 +2113,7 @@ def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseAbsModule(torch.nn.Module):
class ElementwiseAbsFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -2127,9 +2127,31 @@ class ElementwiseAbsModule(torch.nn.Module):
return torch.abs(a)
@register_test_case(module_factory=lambda: ElementwiseAbsModule())
def ElementwiseAbsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0))
@register_test_case(module_factory=lambda: ElementwiseAbsFloatModule())
def ElementwiseAbsFloatModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([[[-1.0, 0.0, 1.0]]]))
# ==============================================================================
class ElementwiseAbsIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.int64, True),
])
def forward(self, a):
return torch.abs(a)
@register_test_case(module_factory=lambda: ElementwiseAbsIntModule())
def ElementwiseAbsIntModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([[[-1, 0, 1]]]))
# ==============================================================================