mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add TorchToTosa lowering for aten.clone op (#1388)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com> Co-authored-by: Suraj Sudhir <16977902+sjarus@users.noreply.github.com>pull/1395/head snapshot-20220921.603
parent
5090ac9359
commit
4ef6e69ed4
|
@ -204,6 +204,8 @@ MHLO_PASS_SET = {
|
|||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
# and very few tests work yet.
|
||||
TOSA_PASS_SET = {
|
||||
"ElementwiseCloneContiguousModule_basic",
|
||||
"ElementwiseCloneModule_basic",
|
||||
"ElementwiseUnaryModule_basic",
|
||||
"ElementwiseBinaryModule_basic",
|
||||
"ElementwiseSigmoidModule_basic",
|
||||
|
@ -364,6 +366,7 @@ TOSA_PASS_SET = {
|
|||
"ArgmaxModule_keepDim",
|
||||
"ArgmaxModule_with_dim",
|
||||
"_LogSoftmaxModuleStable_basic",
|
||||
"LiftFreshCopyModule_basic",
|
||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||
"BroadcastToIdentityCaseStaticModule_basic",
|
||||
}
|
||||
|
|
|
@ -3393,6 +3393,31 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// Legalizes the torch.clone op.
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
int64_t memoryFormat;
|
||||
if (!op.memory_format().getType().template isa<Torch::NoneType>() &&
|
||||
(!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)) ||
|
||||
memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
|
||||
return op.emitError(
|
||||
"unimplemented: only default memory format is supported");
|
||||
}
|
||||
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template dyn_cast<TensorType>();
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, adaptor.self());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -3599,6 +3624,12 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
|
||||
#undef INSERT_CLONE_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
|
Loading…
Reference in New Issue