remove slice_scatter from TorchToLinalg/DataMovement.cpp

tanyo/slice_scatter_stage
TanyoKwok 2022-11-23 14:06:27 +08:00
parent 329d02061b
commit 1052419156
2 changed files with 1 additions and 52 deletions

View File

@ -1288,55 +1288,6 @@ public:
};
} // namespace
namespace {
class ConvertAtenSliceScatterOp
: public OpConversionPattern<AtenSliceScatterOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
TypeConverter *typeConverter = getTypeConverter();
auto input = adaptor.self();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
AtenSliceScatterOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
return failure();
}
Value src = adaptor.src();
auto srcType = src.getType().cast<RankedTensorType>();
int64_t srcRank = srcType.getRank();
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
auto abstractSrcType =
RankedTensorType::get(srcAbstractSizes, srcType.getElementType());
Value abstractSrc =
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
Value result = rewriter.create<tensor::InsertSliceOp>(
loc, abstractSrc, input, offsets, resultShape, strides);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@ -1365,6 +1316,4 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
target.addIllegalOp<AtenCopyOp>();
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
target.addIllegalOp<AtenSliceScatterOp>();
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
}

View File

@ -240,7 +240,7 @@ class ExampleArgs:
# compiler where each backend can "own" its set of legal ops.
BACKEND_LEGAL_OPS = {
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', 'torch.aten.slice_scatter'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints'],
OutputType.MHLO: [],
}