diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index bab4ad600..643df60dc 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -204,6 +204,8 @@ MHLO_PASS_SET = { # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneModule_basic", "ElementwiseUnaryModule_basic", "ElementwiseBinaryModule_basic", "ElementwiseSigmoidModule_basic", @@ -364,6 +366,7 @@ TOSA_PASS_SET = { "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "_LogSoftmaxModuleStable_basic", + "LiftFreshCopyModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "BroadcastToIdentityCaseStaticModule_basic", } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index a6c9851d7..906f462bd 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3393,6 +3393,31 @@ public: } }; +// Legalizes the torch.clone op. +template +class ConvertAtenCloneOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int64_t memoryFormat; + if (!op.memory_format().getType().template isa() && + (!matchPattern(op.memory_format(), m_TorchConstantInt(&memoryFormat)) || + memoryFormat != torch_upstream::MemoryFormat::Contiguous)) { + return op.emitError( + "unimplemented: only default memory format is supported"); + } + auto outType = OpConversionPattern::getTypeConverter() + ->convertType(op.getType()) + .template dyn_cast(); + rewriter.replaceOpWithNewOp(op, outType, adaptor.self()); + + return success(); + } +}; + } // namespace // ----------------------------------------------------------------------------- @@ -3599,6 +3624,12 @@ public: INSERT_ATENOP_PATTERN(AtenBroadcastToOp); #undef INSERT_ATENOP_PATTERN +#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); +#undef INSERT_CLONE_ATENOP_PATTERN + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure();