mirror of https://github.com/llvm/torch-mlir
remove slice_scatter from TorchToLinalg/DataMovement.cpp
parent
329d02061b
commit
1052419156
|
@ -1288,55 +1288,6 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
|
||||||
class ConvertAtenSliceScatterOp
|
|
||||||
: public OpConversionPattern<AtenSliceScatterOp> {
|
|
||||||
public:
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
|
||||||
TypeConverter *typeConverter = getTypeConverter();
|
|
||||||
|
|
||||||
auto input = adaptor.self();
|
|
||||||
|
|
||||||
RankedTensorType resultType =
|
|
||||||
typeConverter->convertType(op->getResult(0).getType())
|
|
||||||
.cast<RankedTensorType>();
|
|
||||||
|
|
||||||
SmallVector<Value> resultShape;
|
|
||||||
SmallVector<Value> offsets;
|
|
||||||
SmallVector<Value> strides;
|
|
||||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
|
||||||
AtenSliceScatterOpAdaptor>(
|
|
||||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value src = adaptor.src();
|
|
||||||
auto srcType = src.getType().cast<RankedTensorType>();
|
|
||||||
int64_t srcRank = srcType.getRank();
|
|
||||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
|
||||||
auto abstractSrcType =
|
|
||||||
RankedTensorType::get(srcAbstractSizes, srcType.getElementType());
|
|
||||||
Value abstractSrc =
|
|
||||||
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
|
||||||
|
|
||||||
Value result = rewriter.create<tensor::InsertSliceOp>(
|
|
||||||
loc, abstractSrc, input, offsets, resultShape, strides);
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target) {
|
ConversionTarget &target) {
|
||||||
|
@ -1365,6 +1316,4 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenCopyOp>();
|
target.addIllegalOp<AtenCopyOp>();
|
||||||
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
|
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenSliceScatterOp>();
|
|
||||||
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -240,7 +240,7 @@ class ExampleArgs:
|
||||||
# compiler where each backend can "own" its set of legal ops.
|
# compiler where each backend can "own" its set of legal ops.
|
||||||
BACKEND_LEGAL_OPS = {
|
BACKEND_LEGAL_OPS = {
|
||||||
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
||||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', 'torch.aten.slice_scatter'],
|
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints'],
|
||||||
OutputType.MHLO: [],
|
OutputType.MHLO: [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue