mirror of https://github.com/llvm/torch-mlir
[tosa] aten.transpose.int support (#1017)
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>pull/1024/head
parent
f0c3b5a7ed
commit
d38f2cae5b
|
@ -169,4 +169,7 @@ TOSA_PASS_SET = {
|
||||||
"NumpyTRankNStaticModule_basic",
|
"NumpyTRankNStaticModule_basic",
|
||||||
"NumpyTRankNDynamicModule_basic",
|
"NumpyTRankNDynamicModule_basic",
|
||||||
"EmbeddingModuleI32Static_basic",
|
"EmbeddingModuleI32Static_basic",
|
||||||
|
"TModuleRank2_basic",
|
||||||
|
"TransposeIntModule_basic",
|
||||||
|
"TransposeIntNegDimsModule_basic",
|
||||||
}
|
}
|
||||||
|
|
|
@ -2705,6 +2705,47 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
|
||||||
|
AtenTransposeIntOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||||
|
if (!selfType)
|
||||||
|
return op.emitError("Only tensor types are supported");
|
||||||
|
|
||||||
|
// Only statically resolvable values are currently supported
|
||||||
|
int64_t dim0, dim1;
|
||||||
|
if (!matchPattern(op.dim0(), m_TorchConstantInt(&dim0)))
|
||||||
|
return op->emitError("dim0 must be a Scalar constant");
|
||||||
|
|
||||||
|
if (!matchPattern(op.dim1(), m_TorchConstantInt(&dim1)))
|
||||||
|
return op->emitError("dim1 must be a Scalar constant");
|
||||||
|
|
||||||
|
dim0 = toPositiveDim(dim0, selfType.getRank());
|
||||||
|
dim1 = toPositiveDim(dim1, selfType.getRank());
|
||||||
|
|
||||||
|
auto selfRank = selfType.getRank();
|
||||||
|
if (!isValidDim(dim0, selfRank) || !isValidDim(dim1, selfRank))
|
||||||
|
return op->emitError("dim0 and dim1 must be less than tensor rank");
|
||||||
|
|
||||||
|
SmallVector<int32_t> transposeDims;
|
||||||
|
for (auto i = 0; i < selfType.getRank(); ++i)
|
||||||
|
transposeDims.push_back(i);
|
||||||
|
|
||||||
|
transposeDims[dim0] = dim1;
|
||||||
|
transposeDims[dim1] = dim0;
|
||||||
|
|
||||||
|
auto transposeDimsConst = mlir::tosa::getConstTensor<int32_t>(
|
||||||
|
rewriter, op.getOperation(), transposeDims, {selfType.getRank()});
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
|
||||||
|
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
|
||||||
|
transposeDimsConst.getValue());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename AtenOpT, typename TosaOpT>
|
template <typename AtenOpT, typename TosaOpT>
|
||||||
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
||||||
public:
|
public:
|
||||||
|
@ -3314,6 +3355,7 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
INSERT_ATENOP_PATTERN(AtenGeluOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenTransposeIntOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
|
|
Loading…
Reference in New Issue