lower torch.aten.isinf to linalg (#2638)

Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
Xida Ren (Cedar) 2023-12-28 17:20:32 -08:00 committed by GitHub
parent 9fc212ea9a
commit 6660a26594
4 changed files with 41 additions and 3 deletions


@@ -426,6 +426,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (isa<AtenAbsOp>(op))
     return b.create<math::AbsFOp>(loc, payloadArgs[0]);
+  if (isa<AtenIsinfOp>(op)) {
+    Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
+    Value infinity = b.create<arith::ConstantOp>(
+        loc, b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
+    return createEqual(b, loc, abs.getType(), abs, infinity);
+  }
   if (isa<AtenSigmoidOp>(op)) {
     auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
         b, converter, payloadArgs[0], op);
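Note on the payload above: the lowering uses the identity isinf(x) <=> |x| == +inf. Taking the absolute value first folds -inf and +inf into a single comparison, and the ordered float equality emitted by createEqual evaluates to false for NaN, matching torch.isinf semantics. A minimal sketch of the same trick in eager PyTorch (the helper name is hypothetical, for illustration only):

import torch

def isinf_via_abs(x: torch.Tensor) -> torch.Tensor:
    # |x| == +inf is true for both +inf and -inf; NaN compares unequal
    # to everything under an ordered compare, so it correctly stays False.
    return x.abs() == float("inf")

x = torch.tensor([1.0, float("inf"), float("-inf"), float("nan")])
assert torch.equal(isinf_via_abs(x), torch.isinf(x))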
@@ -1343,7 +1349,7 @@ public:
       AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
       AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
+      AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
       AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
       AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@@ -1992,7 +1998,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
       AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
-      AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
+      AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
+      AtenTrilOp,
       AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
       AtenFillTensorOp, AtenRealOp, AtenImagOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
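With AtenIsinfOp added both to the payload dispatch and to the pattern legality list above, the op should flow through the torch-to-linalg pipeline end to end. A quick smoke test against the torch_mlir Python API of this era (a sketch; torch_mlir.compile and the "linalg-on-tensors" output type are assumptions based on the pt1 API, which has since changed):

import torch
import torch_mlir  # pt1-era API, assumed available

class IsinfModule(torch.nn.Module):
    def forward(self, x):
        return torch.isinf(x)

# The printed module should contain a linalg.generic whose payload
# computes math.absf followed by an arith.cmpf equality against +inf.
compiled = torch_mlir.compile(IsinfModule(), torch.randn(2, 5),
                              output_type="linalg-on-tensors")
print(compiled)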


@@ -39,6 +39,10 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
+  return {Shape(at::kBool, self.sizes().vec())};
+}
+
 std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
     const at::Tensor& self, at::IntArrayRef kernel_size,
     at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
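The LTC shape function only records output metadata: isinf preserves the input's sizes but always yields a boolean tensor, which is what Shape(at::kBool, self.sizes().vec()) encodes. The eager-mode contract it mirrors can be checked directly:

import torch

x = torch.randn(2, 5)
out = torch.isinf(x)
assert out.shape == x.shape      # sizes are preserved
assert out.dtype == torch.bool   # dtype is always bool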


@@ -1033,6 +1033,7 @@ TOSA_PASS_SET = {
     "ElementwiseAddScalarIntModule_basic",
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
     "ElementwiseAtenDivIntScalarModule_basic",
+    "ElementwiseAtenIsinfOpModule_basic",
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseBinaryModule_basic",
    "ElementwiseBinaryStaticShapeModule_basic",
@@ -1328,6 +1329,8 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
     "SliceWholeTensorModule_basic",
     "TensorFloatModule_basic",
     "TensorIntModule_basic",
+    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
+    "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
 }) - {
     ### Test failing in make_fx_tosa but not in tosa
@@ -1489,5 +1492,4 @@ LTC_XFAIL_SET = {
     "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
-    "ElementwiseIsinfModule_basic",
 }


@@ -3385,6 +3385,31 @@ def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(4, 5, high=2).bool())
 
 
 # ==============================================================================
 
 
+class ElementwiseAtenIsinfOpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.isinf(x)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAtenIsinfOpModule())
+def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils):
+    test_input = torch.tensor(
+        [
+            [1, float('inf'), 2, float('-inf'), float('nan')],
+            [1, float('inf'), float('-inf'), float('nan'), 3],
+        ]
+    )
+    module.forward(test_input)
+
+
+# ==============================================================================
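For intuition about what the new e2e test checks, running the same module eagerly on that literal input (a standalone snippet, without the harness decorators) produces:

import torch

class ElementwiseAtenIsinfOpModule(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.isinf(x)

x = torch.tensor([
    [1, float('inf'), 2, float('-inf'), float('nan')],
    [1, float('inf'), float('-inf'), float('nan'), 3],
])
print(ElementwiseAtenIsinfOpModule()(x))
# tensor([[False,  True, False,  True, False],
#         [False,  True,  True,  True, False]])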