From d38f2cae5bcf833f062241f0254a1d3d09ab2609 Mon Sep 17 00:00:00 2001 From: Suraj Sudhir <16977902+sjarus@users.noreply.github.com> Date: Thu, 7 Jul 2022 13:05:33 -0700 Subject: [PATCH] [tosa] aten.transpose.int support (#1017) Signed-off-by: Suraj Sudhir --- e2e_testing/torchscript/xfail_sets.py | 3 ++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 42 ++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index b03f3abac..214b8df3c 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -169,4 +169,7 @@ TOSA_PASS_SET = { "NumpyTRankNStaticModule_basic", "NumpyTRankNDynamicModule_basic", "EmbeddingModuleI32Static_basic", + "TModuleRank2_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index fe97340a7..b3478d84e 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2705,6 +2705,47 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTransposeIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto selfType = adaptor.self().getType().dyn_cast(); + if (!selfType) + return op.emitError("Only tensor types are supported"); + + // Only statically resolvable values are currently supported + int64_t dim0, dim1; + if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0))) + return op->emitError("dim0 must be a Scalar constant"); + + if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1))) + return op->emitError("dim1 must be a Scalar constant"); + + dim0 = toPositiveDim(dim0, selfType.getRank()); + dim1 = toPositiveDim(dim1, selfType.getRank()); + + auto selfRank = selfType.getRank(); + if (!isValidDim(dim0, selfRank) || !isValidDim(dim1, selfRank)) + return op->emitError("dim0 and dim1 must be less than tensor rank"); + + SmallVector transposeDims; + for (auto i = 0; i < selfType.getRank(); ++i) + transposeDims.push_back(i); + + transposeDims[dim0] = dim1; + transposeDims[dim1] = dim0; + + auto transposeDimsConst = mlir::tosa::getConstTensor( + rewriter, op.getOperation(), transposeDims, {selfType.getRank()}); + + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), adaptor.self(), + transposeDimsConst.getValue()); + + return success(); +} + template class ConvertAtenPoolingBaseOp : public OpConversionPattern { public: @@ -3314,6 +3355,7 @@ public: INSERT_ATENOP_PATTERN(AtenGeluOp); INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); INSERT_ATENOP_PATTERN(AtenEmbeddingOp); + INSERT_ATENOP_PATTERN(AtenTransposeIntOp); #undef INSERT_ATENOP_PATTERN if (failed(applyPartialConversion(getOperation(), target,