From 48383554da9582b7dc8337f6d7d20dfd78db9068 Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Fri, 14 Jul 2023 08:23:10 +0200 Subject: [PATCH] TorchToTosa: Legalization for torch.aten.sqrt (#2234) --- e2e_testing/xfail_sets.py | 2 ++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 30 ++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 6a9a981c7..d5dddd3c9 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -1108,6 +1108,8 @@ TOSA_PASS_SET = { "TensorsConcatStaticModule_basic", "TensorsConcatNegativeDimStaticModule_basic", "AtenComplex64Module_basic", + "ElementwiseSqrtIntModule_basic", + "ElementwiseSqrtModule_basic", "SplitTensorGetItem_Module_basic", "SplitTensorListUnpackModule_basic", "ChunkListUnpack_Module_basic", diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index f5a961cfd..4ec6fe758 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -4563,6 +4563,35 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Converts AtenSqrtOp into pow(x, 0.5) + auto self = adaptor.getSelf(); + auto selfTy = self.getType().dyn_cast(); + if (!selfTy) + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); + + auto resultType = typeConverter->convertType(op.getType()) + .template cast(); + auto elementType = resultType.getElementType(); + + if (isa(selfTy.getElementType())) { + self = rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(resultType.getShape(), elementType), + self); + } + + auto oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, elementType).value(); + + rewriter.replaceOpWithNewOp(op, resultType, self, oneHalf); + return success(); +} + } // namespace // ----------------------------------------------------------------------------- @@ -4792,6 +4821,7 @@ public: INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); INSERT_ATENOP_PATTERN(AtenRemainderScalarOp); INSERT_ATENOP_PATTERN(AtenCatOp); + INSERT_ATENOP_PATTERN(AtenSqrtOp); #undef INSERT_ATENOP_PATTERN #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \