mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add lowering of `aten.slice_scatter` and
`aten.select_scatter` op. This commit adds: 1. Lowering of `aten.slice_scatter` op into `tensor.insert_slice` op. 2. Decomposes the `aten.select_scatter` op into `aten.slice_scater` op. Signed-Off-By: Prateek Gupta <gprateek93@gmail.com>pull/1035/head snapshot-20220711.530
parent
a08ff0d7f2
commit
2d75654b2c
|
@ -5292,6 +5292,32 @@ def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$src,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
Torch_IntType:$index
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenSelectScatterOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void AtenSelectScatterOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
|
def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -5747,6 +5773,34 @@ def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenSliceScatterOp : Torch_Op<"aten.slice_scatter", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$src,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
AnyTorchOptionalIntType:$start,
|
||||||
|
AnyTorchOptionalIntType:$end,
|
||||||
|
Torch_IntType:$step
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenSliceScatterOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||||
|
}
|
||||||
|
void AtenSliceScatterOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 6, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenLenTensorOp : Torch_Op<"aten.len.Tensor", [
|
def Torch_AtenLenTensorOp : Torch_Op<"aten.len.Tensor", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -7,6 +7,10 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/TypeSupport.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
|
@ -29,6 +33,94 @@ using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
|
static Value toPositiveValidDim(ConversionPatternRewriter &rewriter,
|
||||||
|
Location loc, Value torchType,
|
||||||
|
Value builtinType, Value valueForNone,
|
||||||
|
Value dimSize) {
|
||||||
|
if (torchType.getType().isa<Torch::NoneType>())
|
||||||
|
return valueForNone;
|
||||||
|
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
||||||
|
Value positiveDim =
|
||||||
|
toPositiveDimDynamic(rewriter, loc, builtinType, dimSizeAsInt);
|
||||||
|
// startOrEnd < 0 ? 0 : startOrEnd
|
||||||
|
Value cst0 = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
||||||
|
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
|
||||||
|
Value startOrEndAtLeastZero =
|
||||||
|
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
|
||||||
|
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
|
||||||
|
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
|
||||||
|
Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
|
||||||
|
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
|
||||||
|
|
||||||
|
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename OpTy, typename OpAdaptor>
|
||||||
|
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter,
|
||||||
|
SmallVector<Value> &resultShape,
|
||||||
|
SmallVector<Value> &offsets,
|
||||||
|
SmallVector<Value> &strides) {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
auto input = adaptor.self();
|
||||||
|
RankedTensorType inputType =
|
||||||
|
input.getType().template cast<RankedTensorType>();
|
||||||
|
|
||||||
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||||
|
return op->emitError("unimplemented: dim is not constant");
|
||||||
|
|
||||||
|
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||||||
|
Value dimSize = inputShape[dim];
|
||||||
|
|
||||||
|
Value torchTypeStart = op.start();
|
||||||
|
Value torchTypeEnd = op.end();
|
||||||
|
Value builtinTypeStart = adaptor.start();
|
||||||
|
Value builtinTypeEnd = adaptor.end();
|
||||||
|
|
||||||
|
if (torchTypeStart.getType().isa<OptionalType>() ||
|
||||||
|
torchTypeEnd.getType().isa<OptionalType>())
|
||||||
|
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||||||
|
|
||||||
|
int64_t step;
|
||||||
|
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
|
||||||
|
if (!op.step().getType().template isa<Torch::NoneType>())
|
||||||
|
return op->emitError("unimplemented: step is not constant");
|
||||||
|
step = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
|
||||||
|
builtinTypeStart, zero, dimSize);
|
||||||
|
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
|
||||||
|
dimSize, dimSize);
|
||||||
|
|
||||||
|
// end >= start ? end : start
|
||||||
|
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
||||||
|
loc, arith::CmpIPredicate::sge, end, start);
|
||||||
|
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
|
||||||
|
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
|
||||||
|
|
||||||
|
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
|
||||||
|
resultShape = getTensorSizes(rewriter, loc, input);
|
||||||
|
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
|
||||||
|
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
|
||||||
|
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
|
||||||
|
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
|
||||||
|
resultShape[dim] = resultSize;
|
||||||
|
|
||||||
|
strides.resize(inputType.getRank(), one);
|
||||||
|
offsets.resize(inputType.getRank(), zero);
|
||||||
|
|
||||||
|
offsets[dim] = start;
|
||||||
|
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertAtenFlattenUsingIntsOp
|
class ConvertAtenFlattenUsingIntsOp
|
||||||
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
|
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
|
||||||
|
@ -742,77 +834,19 @@ public:
|
||||||
TypeConverter *typeConverter = getTypeConverter();
|
TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
|
||||||
auto input = adaptor.self();
|
auto input = adaptor.self();
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
|
||||||
RankedTensorType resultType =
|
RankedTensorType resultType =
|
||||||
typeConverter->convertType(op->getResult(0).getType())
|
typeConverter->convertType(op->getResult(0).getType())
|
||||||
.cast<RankedTensorType>();
|
.cast<RankedTensorType>();
|
||||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
|
||||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
|
||||||
|
|
||||||
int64_t dim;
|
SmallVector<Value> resultShape;
|
||||||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
SmallVector<Value> offsets;
|
||||||
return op->emitError("unimplemented: dim is not constant");
|
SmallVector<Value> strides;
|
||||||
|
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
|
||||||
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
AtenSliceTensorOpAdaptor>(
|
||||||
Value dimSize = inputShape[dim];
|
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||||
|
return failure();
|
||||||
auto adjustStartOrEnd = [&](Value startOrEndTorchType,
|
|
||||||
Value startOrEndBuiltin, Value valueForNone) {
|
|
||||||
if (startOrEndTorchType.getType().isa<Torch::NoneType>())
|
|
||||||
return valueForNone;
|
|
||||||
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
|
||||||
Value startOrEndToPositive =
|
|
||||||
toPositiveDimDynamic(rewriter, loc, startOrEndBuiltin, dimSizeAsInt);
|
|
||||||
// startOrEnd < 0 ? 0 : startOrEnd
|
|
||||||
Value cst0 = rewriter.create<arith::ConstantOp>(
|
|
||||||
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
|
||||||
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
|
||||||
loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0);
|
|
||||||
Value startOrEndAtLeastZero = rewriter.create<arith::SelectOp>(
|
|
||||||
loc, predDimSltZero, cst0, startOrEndToPositive);
|
|
||||||
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
|
|
||||||
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
|
|
||||||
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
|
|
||||||
Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
|
|
||||||
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
|
|
||||||
|
|
||||||
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
|
|
||||||
};
|
|
||||||
|
|
||||||
if (op.start().getType().isa<OptionalType>() ||
|
|
||||||
op.end().getType().isa<OptionalType>())
|
|
||||||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
|
||||||
Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
|
|
||||||
Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
|
|
||||||
|
|
||||||
// end >= start ? end : start
|
|
||||||
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
|
||||||
loc, arith::CmpIPredicate::sge, end, start);
|
|
||||||
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
|
|
||||||
|
|
||||||
int64_t step;
|
|
||||||
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
|
|
||||||
if (!op.step().getType().isa<Torch::NoneType>())
|
|
||||||
return op->emitError("unimplemented: step is not constant");
|
|
||||||
step = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
|
|
||||||
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
|
|
||||||
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
|
|
||||||
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
|
|
||||||
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
|
|
||||||
resultSize =
|
|
||||||
rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
|
|
||||||
|
|
||||||
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
|
|
||||||
resultShape[dim] = resultSize;
|
|
||||||
|
|
||||||
SmallVector<Value> offsets(inputType.getRank(), zero);
|
|
||||||
SmallVector<Value> strides(inputType.getRank(), one);
|
|
||||||
offsets[dim] = start;
|
|
||||||
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
|
|
||||||
|
|
||||||
Value result = rewriter.create<tensor::ExtractSliceOp>(
|
Value result = rewriter.create<tensor::ExtractSliceOp>(
|
||||||
loc, input, offsets, resultShape, strides);
|
loc, input, offsets, resultShape, strides);
|
||||||
|
|
||||||
|
@ -1019,6 +1053,55 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertAtenSliceScatterOp
|
||||||
|
: public OpConversionPattern<AtenSliceScatterOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
|
||||||
|
auto input = adaptor.self();
|
||||||
|
|
||||||
|
RankedTensorType resultType =
|
||||||
|
typeConverter->convertType(op->getResult(0).getType())
|
||||||
|
.cast<RankedTensorType>();
|
||||||
|
|
||||||
|
SmallVector<Value> resultShape;
|
||||||
|
SmallVector<Value> offsets;
|
||||||
|
SmallVector<Value> strides;
|
||||||
|
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
||||||
|
AtenSliceScatterOpAdaptor>(
|
||||||
|
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value src = adaptor.src();
|
||||||
|
auto srcType = src.getType().cast<RankedTensorType>();
|
||||||
|
int64_t srcRank = srcType.getRank();
|
||||||
|
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||||
|
auto abstractSrcType =
|
||||||
|
RankedTensorType::get(srcAbstractSizes, srcType.getElementType());
|
||||||
|
Value abstractSrc =
|
||||||
|
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
||||||
|
|
||||||
|
Value result = rewriter.create<tensor::InsertSliceOp>(
|
||||||
|
loc, abstractSrc, input, offsets, resultShape, strides);
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target) {
|
ConversionTarget &target) {
|
||||||
|
@ -1047,4 +1130,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||||
target.addIllegalOp<ValsemVariantAtenCopyOp>();
|
target.addIllegalOp<ValsemVariantAtenCopyOp>();
|
||||||
patterns.add<ConvertValsemVariantAtenCopyOp>(typeConverter, context);
|
patterns.add<ConvertValsemVariantAtenCopyOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<AtenSliceScatterOp>();
|
||||||
|
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,9 @@
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
#include "llvm/ADT/StringExtras.h"
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
|
@ -2120,6 +2122,55 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose the `aten.select_scatter` operation into `aten.slice_scatter` op.
|
||||||
|
class DecomposeAtenSelectScatterOp
|
||||||
|
: public OpRewritePattern<AtenSelectScatterOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenSelectScatterOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value start = op.index();
|
||||||
|
Value dim = op.dim();
|
||||||
|
Value self = op.self();
|
||||||
|
Value src = op.src();
|
||||||
|
|
||||||
|
Value one =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
Value startPlusOne =
|
||||||
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
||||||
|
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
|
||||||
|
SmallVector<int64_t> sizes;
|
||||||
|
if (!srcTensorType.hasSizes())
|
||||||
|
return rewriter.notifyMatchFailure(op, "src tensor must have size");
|
||||||
|
|
||||||
|
ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
|
||||||
|
// `src` has a reduced rank. Hence add 1.
|
||||||
|
int64_t srcRank = srcShape.size() + 1;
|
||||||
|
int64_t dimInt = 0;
|
||||||
|
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
|
||||||
|
dimInt = toPositiveDim(dimInt, srcRank);
|
||||||
|
if (!isValidDim(dimInt, srcRank))
|
||||||
|
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||||
|
|
||||||
|
sizes.append(srcShape.begin(), srcShape.end());
|
||||||
|
sizes.insert(sizes.begin() + dimInt, 1);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
sizes.resize(srcShape.size() + 1, kUnknownSize);
|
||||||
|
}
|
||||||
|
Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
|
||||||
|
srcTensorType.getDtype());
|
||||||
|
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
|
||||||
|
op, op.self().getType(), self, src, dim, start, startPlusOne,
|
||||||
|
/*step=*/one);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -2271,6 +2322,8 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<AtenFloorDivideOp>();
|
target.addIllegalOp<AtenFloorDivideOp>();
|
||||||
patterns.add<DecomposeAtenNumpyTOp>(context);
|
patterns.add<DecomposeAtenNumpyTOp>(context);
|
||||||
target.addIllegalOp<AtenNumpyTOp>();
|
target.addIllegalOp<AtenNumpyTOp>();
|
||||||
|
patterns.add<DecomposeAtenSelectScatterOp>(context);
|
||||||
|
target.addIllegalOp<AtenSelectScatterOp>();
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns)))) {
|
std::move(patterns)))) {
|
||||||
|
|
|
@ -645,10 +645,11 @@ ChangeResult TypeAnalyzer::visitOperation(
|
||||||
AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp,
|
AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp,
|
||||||
AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp,
|
AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp,
|
||||||
AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
|
AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
|
||||||
AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
|
AtenSelectScatterOp, AtenSliceTensorOp, AtenSliceScatterOp,
|
||||||
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
|
AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp,
|
||||||
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
|
AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
|
||||||
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
|
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||||
|
ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
|
||||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||||
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
|
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
|
||||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -904,9 +904,15 @@ def aten〇batch_norm(input: List[int], weight: Optional[List[int]], bias: Optio
|
||||||
def aten〇slice〇Tensor(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
def aten〇slice〇Tensor(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
||||||
return upstream_shape_functions.slice(self, dim, start, end, step)
|
return upstream_shape_functions.slice(self, dim, start, end, step)
|
||||||
|
|
||||||
|
def aten〇slice_scatter(self: List[int], src: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
||||||
|
return self
|
||||||
|
|
||||||
def aten〇select〇int(self: List[int], dim: int, index: int) -> List[int]:
|
def aten〇select〇int(self: List[int], dim: int, index: int) -> List[int]:
|
||||||
return upstream_shape_functions.select(self, dim, index)
|
return upstream_shape_functions.select(self, dim, index)
|
||||||
|
|
||||||
|
def aten〇select_scatter(self: List[int], src: List[int], dim: int, index: int) -> List[int]:
|
||||||
|
return self
|
||||||
|
|
||||||
def aten〇index_select(self: List[int], dim: int, index: List[int]) -> List[int]:
|
def aten〇index_select(self: List[int], dim: int, index: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.index_select(self, dim, index)
|
return upstream_shape_functions.index_select(self, dim, index)
|
||||||
|
|
||||||
|
|
|
@ -437,6 +437,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
|
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
|
||||||
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
|
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
|
||||||
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
|
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
|
||||||
|
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||||
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
|
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
|
||||||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||||
emit("aten::sum : (Tensor, int?) -> (Tensor)")
|
emit("aten::sum : (Tensor, int?) -> (Tensor)")
|
||||||
|
@ -455,6 +456,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)")
|
emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)")
|
emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)")
|
||||||
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
|
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
|
||||||
|
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
||||||
emit("aten::len.Tensor : (Tensor) -> (int)")
|
emit("aten::len.Tensor : (Tensor) -> (int)")
|
||||||
emit("aten::cpu : (Tensor) -> (Tensor)")
|
emit("aten::cpu : (Tensor) -> (Tensor)")
|
||||||
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
||||||
|
|
|
@ -232,3 +232,112 @@ def SelectIntModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.randint(10, (5,5)))
|
module.forward(torch.randint(10, (5,5)))
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
|
||||||
|
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
|
||||||
|
class SliceScatterModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, src):
|
||||||
|
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SliceScatterModule())
|
||||||
|
def SliceScatterModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 8), tu.rand(6, 1))
|
||||||
|
|
||||||
|
class SliceScatterZeroDimModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, src):
|
||||||
|
return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 0, end = 1, step = 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SliceScatterZeroDimModule())
|
||||||
|
def SliceScatterZeroDimModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 8), tu.rand(1, 8))
|
||||||
|
|
||||||
|
class SliceScatterStepVariationModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, src):
|
||||||
|
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 2)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SliceScatterStepVariationModule())
|
||||||
|
def SliceScatterStepVariationModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 8), tu.rand(6, 1))
|
||||||
|
|
||||||
|
class SliceScatterStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([6, 8], torch.float32, True),
|
||||||
|
([6, 1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, src):
|
||||||
|
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SliceScatterStaticModule())
|
||||||
|
def SliceScatterStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(6, 8), tu.rand(6, 1))
|
||||||
|
|
||||||
|
class SelectScatterModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, src):
|
||||||
|
return torch.ops.aten.select_scatter(x, src, dim = 0, index = 0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SelectScatterModule())
|
||||||
|
def SelectScattertModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.rand(6, 8, 5), torch.rand(8, 5))
|
||||||
|
|
||||||
|
class SelectScatterStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([6, 8, 5], torch.float32, True),
|
||||||
|
([6, 5], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, src):
|
||||||
|
return torch.ops.aten.select_scatter(x, src, dim = 1, index = 0)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: SelectScatterStaticModule())
|
||||||
|
def SelectScattertStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.rand(6, 8, 5), torch.rand(6, 5))
|
||||||
|
|
|
@ -1087,3 +1087,21 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int
|
||||||
%2 = torch.aten.repeat %arg0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
|
%2 = torch.aten.repeat %arg0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
|
||||||
return %2 : !torch.vtensor<[?,?,?],f32>
|
return %2 : !torch.vtensor<[?,?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @torch.aten.select_scatter
|
||||||
|
// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?],f32>, %[[SRC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK-NEXT: %[[START:.*]] = torch.constant.int 0
|
||||||
|
// CHECK-NEXT: %[[DIM:.*]] = torch.constant.int 1
|
||||||
|
// CHECK-NEXT: %[[STEP:.*]] = torch.constant.int 1
|
||||||
|
// CHECK-NEXT: %[[END:.*]] = torch.aten.add.int %[[START]], %[[STEP]]
|
||||||
|
// CHECK-NEXT: %[[UNSQUEEZE_SRC:.*]] = torch.aten.unsqueeze %[[SRC]], %[[DIM]]
|
||||||
|
// CHECK-NEXT: %[[SLICE_SCATTER:.*]] = torch.aten.slice_scatter %[[SELF]], %[[UNSQUEEZE_SRC]], %[[DIM]], %[[START]], %[[END]], %[[STEP]]
|
||||||
|
// CHECK-NEXT: return %[[SLICE_SCATTER]]
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%0 = torch.aten.select_scatter %arg0, %arg1, %int1, %int0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue