mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add TorchToTosa lowering for aten.clone op (#1388)
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com> Co-authored-by: Suraj Sudhir <16977902+sjarus@users.noreply.github.com>pull/1395/head snapshot-20220921.603
parent
5090ac9359
commit
4ef6e69ed4
|
@ -204,6 +204,8 @@ MHLO_PASS_SET = {
|
||||||
# Write the TOSA set as a "passing" set as it is very early in development
|
# Write the TOSA set as a "passing" set as it is very early in development
|
||||||
# and very few tests work yet.
|
# and very few tests work yet.
|
||||||
TOSA_PASS_SET = {
|
TOSA_PASS_SET = {
|
||||||
|
"ElementwiseCloneContiguousModule_basic",
|
||||||
|
"ElementwiseCloneModule_basic",
|
||||||
"ElementwiseUnaryModule_basic",
|
"ElementwiseUnaryModule_basic",
|
||||||
"ElementwiseBinaryModule_basic",
|
"ElementwiseBinaryModule_basic",
|
||||||
"ElementwiseSigmoidModule_basic",
|
"ElementwiseSigmoidModule_basic",
|
||||||
|
@ -364,6 +366,7 @@ TOSA_PASS_SET = {
|
||||||
"ArgmaxModule_keepDim",
|
"ArgmaxModule_keepDim",
|
||||||
"ArgmaxModule_with_dim",
|
"ArgmaxModule_with_dim",
|
||||||
"_LogSoftmaxModuleStable_basic",
|
"_LogSoftmaxModuleStable_basic",
|
||||||
|
"LiftFreshCopyModule_basic",
|
||||||
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
|
||||||
"BroadcastToIdentityCaseStaticModule_basic",
|
"BroadcastToIdentityCaseStaticModule_basic",
|
||||||
}
|
}
|
||||||
|
|
|
@ -3393,6 +3393,31 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Legalizes the torch.clone op.
|
||||||
|
template <typename AtenOpT>
|
||||||
|
class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||||
|
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
int64_t memoryFormat;
|
||||||
|
if (!op.memory_format().getType().template isa<Torch::NoneType>() &&
|
||||||
|
(!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)) ||
|
||||||
|
memoryFormat != torch_upstream::MemoryFormat::Contiguous)) {
|
||||||
|
return op.emitError(
|
||||||
|
"unimplemented: only default memory format is supported");
|
||||||
|
}
|
||||||
|
auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||||
|
->convertType(op.getType())
|
||||||
|
.template dyn_cast<TensorType>();
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, adaptor.self());
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -3599,6 +3624,12 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
|
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||||
|
target.addIllegalOp<AtenOp>(); \
|
||||||
|
patterns.add<ConvertAtenCloneOp<AtenOp>>(typeConverter, context);
|
||||||
|
INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp);
|
||||||
|
#undef INSERT_CLONE_ATENOP_PATTERN
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
Loading…
Reference in New Issue