[tosa] aten.transpose.int support (#1017)

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
pull/1024/head
Suraj Sudhir 2022-07-07 13:05:33 -07:00 committed by GitHub
parent f0c3b5a7ed
commit d38f2cae5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 0 deletions

View File

@ -169,4 +169,7 @@ TOSA_PASS_SET = {
"NumpyTRankNStaticModule_basic",
"NumpyTRankNDynamicModule_basic",
"EmbeddingModuleI32Static_basic",
"TModuleRank2_basic",
"TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic",
}

View File

@ -2705,6 +2705,47 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
AtenTransposeIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType)
return op.emitError("Only tensor types are supported");
// Only statically resolvable values are currently supported
int64_t dim0, dim1;
if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0)))
return op->emitError("dim0 must be a Scalar constant");
if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1)))
return op->emitError("dim1 must be a Scalar constant");
dim0 = toPositiveDim(dim0, selfType.getRank());
dim1 = toPositiveDim(dim1, selfType.getRank());
auto selfRank = selfType.getRank();
if (!isValidDim(dim0, selfRank) || !isValidDim(dim1, selfRank))
return op->emitError("dim0 and dim1 must be less than tensor rank");
SmallVector<int32_t> transposeDims;
for (auto i = 0; i < selfType.getRank(); ++i)
transposeDims.push_back(i);
transposeDims[dim0] = dim1;
transposeDims[dim1] = dim0;
auto transposeDimsConst = mlir::tosa::getConstTensor<int32_t>(
rewriter, op.getOperation(), transposeDims, {selfType.getRank()});
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
transposeDimsConst.getValue());
return success();
}
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
@ -3314,6 +3355,7 @@ public:
INSERT_ATENOP_PATTERN(AtenGeluOp);
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
#undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,