[MLIR][TORCH] Add TorchToTosa lowering for aten.clone op (#1388)

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>

Co-authored-by: Suraj Sudhir <16977902+sjarus@users.noreply.github.com>
Author: Vivek Khandelwal, 2022-09-21 03:37:46 +05:30 (committed by GitHub)
commit 4ef6e69ed4 (parent 5090ac9359)
2 changed files with 34 additions and 0 deletions


@@ -204,6 +204,8 @@ MHLO_PASS_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "ElementwiseCloneContiguousModule_basic",
+    "ElementwiseCloneModule_basic",
     "ElementwiseUnaryModule_basic",
     "ElementwiseBinaryModule_basic",
     "ElementwiseSigmoidModule_basic",
@@ -364,6 +366,7 @@ TOSA_PASS_SET = {
     "ArgmaxModule_keepDim",
     "ArgmaxModule_with_dim",
     "_LogSoftmaxModuleStable_basic",
+    "LiftFreshCopyModule_basic",
     "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
     "BroadcastToIdentityCaseStaticModule_basic",
 }
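
For reference, a minimal sketch of what one of the newly passing e2e tests looks like, following the suite's export/annotate_args/register_test_case conventions. The exact import paths vary across torch-mlir versions, so treat this as an approximation rather than the test's verbatim source:

import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class ElementwiseCloneModule(torch.nn.Module):
    @export
    @annotate_args([
        None,                                  # the module itself
        ([-1, -1, -1], torch.float32, True),   # dynamic 3-D float input
    ])
    def forward(self, x):
        # aten.clone with the default (None) memory format.
        return torch.clone(x)


@register_test_case(module_factory=lambda: ElementwiseCloneModule())
def ElementwiseCloneModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))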


@@ -3393,6 +3393,31 @@ public:
   }
 };
 
+// Legalizes the torch.clone op.
+template <typename AtenOpT>
+class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int64_t memoryFormat;
+    if (!op.memory_format().getType().template isa<Torch::NoneType>() &&
+        (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)) ||
+         memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
+      return op.emitError(
+          "unimplemented: only default memory format is supported");
+    }
+
+    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
+                       ->convertType(op.getType())
+                       .template dyn_cast<TensorType>();
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, adaptor.self());
+    return success();
+  }
+};
+
 } // namespace
 
 // -----------------------------------------------------------------------------
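
The net effect of the pattern is to replace aten.clone (when the memory format is None or explicitly contiguous) with a value-preserving tosa.cast to the converted result type. A rough way to exercise this end-to-end is sketched below; torch_mlir.compile and OutputType.TOSA match the Python API around this snapshot, but treat the exact invocation as an assumption, since it changes between releases:

import torch
import torch_mlir


class CloneModule(torch.nn.Module):
    def forward(self, x):
        # Default memory format, so ConvertAtenCloneOp accepts it and
        # rewrites it into a tosa.cast that simply forwards the value.
        return torch.clone(x)


compiled = torch_mlir.compile(CloneModule(), torch.ones(2, 3),
                              output_type=torch_mlir.OutputType.TOSA)
print(compiled)  # expect a tosa.cast where aten.clone used to be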
@@ -3599,6 +3624,12 @@ public:
   INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
 #undef INSERT_ATENOP_PATTERN
 
+#define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
+  INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
+#undef INSERT_CLONE_ATENOP_PATTERN
+
   if (failed(applyPartialConversion(getOperation(), target,
                                     std::move(patterns))))
     return signalPassFailure();
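
Conversely, the guard at the top of matchAndRewrite rejects any explicit non-contiguous memory format rather than silently dropping the layout request. A hypothetical repro (the module name is made up; the diagnostic text is the one emitted by the pattern above):

import torch


class ChannelsLastCloneModule(torch.nn.Module):
    def forward(self, x):
        # Non-default memory format: the pattern bails out instead of
        # lowering, since tosa.cast cannot model the layout change.
        return x.clone(memory_format=torch.channels_last)

# Lowering this module to TOSA should fail with:
#   error: unimplemented: only default memory format is supported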