From 6660a26594dc82cd3dd6fc33c9269ff09ecd263a Mon Sep 17 00:00:00 2001
From: "Xida Ren (Cedar)"
Date: Thu, 28 Dec 2023 17:20:32 -0800
Subject: [PATCH] lower torch.aten.isinf to linalg (#2638)

Co-authored-by: Rob Suderman
---
 .../TorchToLinalg/Uncategorized.cpp           | 11 ++++++--
 .../base_lazy_backend/shape_inference.cpp     |  4 +++
 projects/pt1/e2e_testing/xfail_sets.py        |  4 ++-
 .../test_suite/elementwise.py                 | 25 +++++++++++++++++++
 4 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index e947ae73a..0943534db 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -426,6 +426,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (isa<AtenAbsOp>(op))
     return b.create<math::AbsFOp>(loc, payloadArgs[0]);
+  if (isa<AtenIsinfOp>(op)) {
+    Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
+    Value infinity = b.create<arith::ConstantOp>(
+        loc, b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
+    return createEqual(b, loc, abs.getType(), abs, infinity);
+  }
   if (isa<AtenSigmoidOp>(op)) {
     auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
         b, converter, payloadArgs[0], op);
@@ -1343,7 +1349,7 @@ public:
         AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
         AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
         AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
-        AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
+        AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
         AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
         AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@@ -1992,7 +1998,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
       AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
-      AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
+      AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
+      AtenTrilOp,
       AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
       AtenFillTensorOp, AtenRealOp, AtenImagOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
index d5458f9c4..244ee7b88 100644
--- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
+++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp
@@ -39,6 +39,10 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
   return {Shape(self.scalar_type(), self.sizes().vec())};
 }
 
+std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
+  return {Shape(at::kBool, self.sizes().vec())};
+}
+
 std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
     const at::Tensor& self, at::IntArrayRef kernel_size,
     at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 6f683a43c..d6cb60e57 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1033,6 +1033,7 @@ TOSA_PASS_SET = {
     "ElementwiseAddScalarIntModule_basic",
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
     "ElementwiseAtenDivIntScalarModule_basic",
+    "ElementwiseAtenIsinfOpModule_basic",
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseBinaryModule_basic",
     "ElementwiseBinaryStaticShapeModule_basic",
@@ -1328,6 +1329,8 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
     "SliceWholeTensorModule_basic",
     "TensorFloatModule_basic",
     "TensorIntModule_basic",
+    "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
+    "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
 }) - {
 ### Test failing in make_fx_tosa but not in tosa
 
@@ -1489,5 +1492,4 @@ LTC_XFAIL_SET = {
     "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
-    "ElementwiseIsinfModule_basic",
 }
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 33c420a1c..15e45b52e 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -3385,6 +3385,31 @@ def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(4, 5, high=2).bool())
 
 
+# ==============================================================================
+
+class ElementwiseAtenIsinfOpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.isinf(x)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenIsinfOpModule())
+def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils):
+    test_input = torch.tensor(
+        [
+            [1, float('inf'), 2, float('-inf'), float('nan')],
+            [1, float('inf'), float('-inf'), float('nan'), 3],
+        ]
+    )
+    module.forward(test_input)
+
+
 # ==============================================================================