From 105241915668c15f5b8eab216c103eb501220552 Mon Sep 17 00:00:00 2001 From: TanyoKwok Date: Wed, 23 Nov 2022 14:06:27 +0800 Subject: [PATCH] remove slice_scatter from TorchToLinalg/DataMovement.cpp --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 51 ------------------- python/torch_mlir/__init__.py | 2 +- 2 files changed, 1 insertion(+), 52 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 799d0f91f..20a94bc66 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1288,55 +1288,6 @@ public: }; } // namespace -namespace { -class ConvertAtenSliceScatterOp - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); - - Location loc = op.getLoc(); - TypeConverter *typeConverter = getTypeConverter(); - - auto input = adaptor.self(); - - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); - - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; - if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { - return failure(); - } - - Value src = adaptor.src(); - auto srcType = src.getType().cast(); - int64_t srcRank = srcType.getRank(); - SmallVector srcAbstractSizes(srcRank, kUnknownSize); - auto abstractSrcType = - RankedTensorType::get(srcAbstractSizes, srcType.getElementType()); - Value abstractSrc = - rewriter.create(loc, abstractSrcType, src); - - Value result = rewriter.create( - loc, abstractSrc, input, offsets, resultShape, strides); - - rewriter.replaceOpWithNewOp(op, resultType, result); - - return success(); - } -}; -} // namespace - void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1365,6 +1316,4 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); } diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 9b135ad24..50f9aa537 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -240,7 +240,7 @@ class ExampleArgs: # compiler where each backend can "own" its set of legal ops. BACKEND_LEGAL_OPS = { OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'], - OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', 'torch.aten.slice_scatter'], + OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints'], OutputType.MHLO: [], }