mirror of https://github.com/llvm/torch-mlir
remove slice_scatter from TorchToLinalg/DataMovement.cpp
parent
329d02061b
commit
1052419156
|
@ -1288,55 +1288,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenSliceScatterOp
|
||||
: public OpConversionPattern<AtenSliceScatterOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
auto input = adaptor.self();
|
||||
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
SmallVector<Value> strides;
|
||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
||||
AtenSliceScatterOpAdaptor>(
|
||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value src = adaptor.src();
|
||||
auto srcType = src.getType().cast<RankedTensorType>();
|
||||
int64_t srcRank = srcType.getRank();
|
||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||
auto abstractSrcType =
|
||||
RankedTensorType::get(srcAbstractSizes, srcType.getElementType());
|
||||
Value abstractSrc =
|
||||
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
||||
|
||||
Value result = rewriter.create<tensor::InsertSliceOp>(
|
||||
loc, abstractSrc, input, offsets, resultShape, strides);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
|
@ -1365,6 +1316,4 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenCopyOp>();
|
||||
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSliceScatterOp>();
|
||||
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -240,7 +240,7 @@ class ExampleArgs:
|
|||
# compiler where each backend can "own" its set of legal ops.
|
||||
BACKEND_LEGAL_OPS = {
|
||||
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', 'torch.aten.slice_scatter'],
|
||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints'],
|
||||
OutputType.MHLO: [],
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue