mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add lowering of `aten.slice_scatter` and
`aten.select_scatter` op. This commit adds: 1. Lowering of `aten.slice_scatter` op into `tensor.insert_slice` op. 2. Decomposes the `aten.select_scatter` op into `aten.slice_scater` op. Signed-Off-By: Prateek Gupta <gprateek93@gmail.com>pull/1035/head snapshot-20220711.530
parent
a08ff0d7f2
commit
2d75654b2c
|
@ -5292,6 +5292,32 @@ def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$src,
|
||||
Torch_IntType:$dim,
|
||||
Torch_IntType:$index
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSelectScatterOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenSelectScatterOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -5747,6 +5773,34 @@ def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSliceScatterOp : Torch_Op<"aten.slice_scatter", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$src,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchOptionalIntType:$start,
|
||||
AnyTorchOptionalIntType:$end,
|
||||
Torch_IntType:$step
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSliceScatterOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 6, 1);
|
||||
}
|
||||
void AtenSliceScatterOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 6, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLenTensorOp : Torch_Op<"aten.len.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -7,6 +7,10 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
|
@ -29,6 +33,94 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
static Value toPositiveValidDim(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value torchType,
|
||||
Value builtinType, Value valueForNone,
|
||||
Value dimSize) {
|
||||
if (torchType.getType().isa<Torch::NoneType>())
|
||||
return valueForNone;
|
||||
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
||||
Value positiveDim =
|
||||
toPositiveDimDynamic(rewriter, loc, builtinType, dimSizeAsInt);
|
||||
// startOrEnd < 0 ? 0 : startOrEnd
|
||||
Value cst0 = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
||||
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
|
||||
Value startOrEndAtLeastZero =
|
||||
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
|
||||
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
|
||||
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
|
||||
Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
|
||||
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
|
||||
|
||||
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
|
||||
}
|
||||
|
||||
template <typename OpTy, typename OpAdaptor>
|
||||
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value> &resultShape,
|
||||
SmallVector<Value> &offsets,
|
||||
SmallVector<Value> &strides) {
|
||||
Location loc = op.getLoc();
|
||||
auto input = adaptor.self();
|
||||
RankedTensorType inputType =
|
||||
input.getType().template cast<RankedTensorType>();
|
||||
|
||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("unimplemented: dim is not constant");
|
||||
|
||||
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||||
Value dimSize = inputShape[dim];
|
||||
|
||||
Value torchTypeStart = op.start();
|
||||
Value torchTypeEnd = op.end();
|
||||
Value builtinTypeStart = adaptor.start();
|
||||
Value builtinTypeEnd = adaptor.end();
|
||||
|
||||
if (torchTypeStart.getType().isa<OptionalType>() ||
|
||||
torchTypeEnd.getType().isa<OptionalType>())
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||||
|
||||
int64_t step;
|
||||
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
|
||||
if (!op.step().getType().template isa<Torch::NoneType>())
|
||||
return op->emitError("unimplemented: step is not constant");
|
||||
step = 1;
|
||||
}
|
||||
|
||||
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
|
||||
builtinTypeStart, zero, dimSize);
|
||||
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
|
||||
dimSize, dimSize);
|
||||
|
||||
// end >= start ? end : start
|
||||
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, end, start);
|
||||
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
|
||||
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
|
||||
|
||||
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
|
||||
resultShape = getTensorSizes(rewriter, loc, input);
|
||||
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
|
||||
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
|
||||
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
|
||||
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
|
||||
resultShape[dim] = resultSize;
|
||||
|
||||
strides.resize(inputType.getRank(), one);
|
||||
offsets.resize(inputType.getRank(), zero);
|
||||
|
||||
offsets[dim] = start;
|
||||
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
|
||||
return success();
|
||||
}
|
||||
namespace {
|
||||
class ConvertAtenFlattenUsingIntsOp
|
||||
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
|
||||
|
@ -742,77 +834,19 @@ public:
|
|||
TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
auto input = adaptor.self();
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("unimplemented: dim is not constant");
|
||||
|
||||
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
|
||||
Value dimSize = inputShape[dim];
|
||||
|
||||
auto adjustStartOrEnd = [&](Value startOrEndTorchType,
|
||||
Value startOrEndBuiltin, Value valueForNone) {
|
||||
if (startOrEndTorchType.getType().isa<Torch::NoneType>())
|
||||
return valueForNone;
|
||||
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
|
||||
Value startOrEndToPositive =
|
||||
toPositiveDimDynamic(rewriter, loc, startOrEndBuiltin, dimSizeAsInt);
|
||||
// startOrEnd < 0 ? 0 : startOrEnd
|
||||
Value cst0 = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
|
||||
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0);
|
||||
Value startOrEndAtLeastZero = rewriter.create<arith::SelectOp>(
|
||||
loc, predDimSltZero, cst0, startOrEndToPositive);
|
||||
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
|
||||
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
|
||||
Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
|
||||
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
|
||||
|
||||
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
|
||||
};
|
||||
|
||||
if (op.start().getType().isa<OptionalType>() ||
|
||||
op.end().getType().isa<OptionalType>())
|
||||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||||
Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
|
||||
Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
|
||||
|
||||
// end >= start ? end : start
|
||||
Value endSgeStart = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, end, start);
|
||||
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
|
||||
|
||||
int64_t step;
|
||||
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
|
||||
if (!op.step().getType().isa<Torch::NoneType>())
|
||||
return op->emitError("unimplemented: step is not constant");
|
||||
step = 1;
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
SmallVector<Value> strides;
|
||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
|
||||
AtenSliceTensorOpAdaptor>(
|
||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
|
||||
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
|
||||
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
|
||||
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
|
||||
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
|
||||
resultSize =
|
||||
rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
|
||||
|
||||
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
|
||||
resultShape[dim] = resultSize;
|
||||
|
||||
SmallVector<Value> offsets(inputType.getRank(), zero);
|
||||
SmallVector<Value> strides(inputType.getRank(), one);
|
||||
offsets[dim] = start;
|
||||
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
|
||||
|
||||
Value result = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, input, offsets, resultShape, strides);
|
||||
|
||||
|
@ -1019,6 +1053,55 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenSliceScatterOp
|
||||
: public OpConversionPattern<AtenSliceScatterOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
auto input = adaptor.self();
|
||||
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
SmallVector<Value> strides;
|
||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
||||
AtenSliceScatterOpAdaptor>(
|
||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value src = adaptor.src();
|
||||
auto srcType = src.getType().cast<RankedTensorType>();
|
||||
int64_t srcRank = srcType.getRank();
|
||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||
auto abstractSrcType =
|
||||
RankedTensorType::get(srcAbstractSizes, srcType.getElementType());
|
||||
Value abstractSrc =
|
||||
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
||||
|
||||
Value result = rewriter.create<tensor::InsertSliceOp>(
|
||||
loc, abstractSrc, input, offsets, resultShape, strides);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
|
@ -1047,4 +1130,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<ValsemVariantAtenCopyOp>();
|
||||
patterns.add<ConvertValsemVariantAtenCopyOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSliceScatterOp>();
|
||||
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -16,7 +16,9 @@
|
|||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include <cstdint>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::torch;
|
||||
|
@ -2120,6 +2122,55 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose the `aten.select_scatter` operation into `aten.slice_scatter` op.
|
||||
class DecomposeAtenSelectScatterOp
|
||||
: public OpRewritePattern<AtenSelectScatterOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenSelectScatterOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value start = op.index();
|
||||
Value dim = op.dim();
|
||||
Value self = op.self();
|
||||
Value src = op.src();
|
||||
|
||||
Value one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value startPlusOne =
|
||||
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
||||
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
|
||||
SmallVector<int64_t> sizes;
|
||||
if (!srcTensorType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(op, "src tensor must have size");
|
||||
|
||||
ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
|
||||
// `src` has a reduced rank. Hence add 1.
|
||||
int64_t srcRank = srcShape.size() + 1;
|
||||
int64_t dimInt = 0;
|
||||
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
|
||||
dimInt = toPositiveDim(dimInt, srcRank);
|
||||
if (!isValidDim(dimInt, srcRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||
|
||||
sizes.append(srcShape.begin(), srcShape.end());
|
||||
sizes.insert(sizes.begin() + dimInt, 1);
|
||||
|
||||
} else {
|
||||
sizes.resize(srcShape.size() + 1, kUnknownSize);
|
||||
}
|
||||
Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
|
||||
srcTensorType.getDtype());
|
||||
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
|
||||
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
|
||||
op, op.self().getType(), self, src, dim, start, startPlusOne,
|
||||
/*step=*/one);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -2271,6 +2322,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenFloorDivideOp>();
|
||||
patterns.add<DecomposeAtenNumpyTOp>(context);
|
||||
target.addIllegalOp<AtenNumpyTOp>();
|
||||
patterns.add<DecomposeAtenSelectScatterOp>(context);
|
||||
target.addIllegalOp<AtenSelectScatterOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
|
|
@ -645,10 +645,11 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp,
|
||||
AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp,
|
||||
AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
|
||||
AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
|
||||
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
|
||||
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
|
||||
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
|
||||
AtenSelectScatterOp, AtenSliceTensorOp, AtenSliceScatterOp,
|
||||
AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp,
|
||||
AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
|
||||
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
|
||||
ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
|
||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
|
||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -904,9 +904,15 @@ def aten〇batch_norm(input: List[int], weight: Optional[List[int]], bias: Optio
|
|||
def aten〇slice〇Tensor(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
||||
return upstream_shape_functions.slice(self, dim, start, end, step)
|
||||
|
||||
def aten〇slice_scatter(self: List[int], src: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇select〇int(self: List[int], dim: int, index: int) -> List[int]:
|
||||
return upstream_shape_functions.select(self, dim, index)
|
||||
|
||||
def aten〇select_scatter(self: List[int], src: List[int], dim: int, index: int) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇index_select(self: List[int], dim: int, index: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.index_select(self, dim, index)
|
||||
|
||||
|
|
|
@ -437,6 +437,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
|
||||
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
|
||||
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
|
||||
emit("aten::stack : (Tensor[], int) -> (Tensor)")
|
||||
emit("aten::sum : (Tensor, int?) -> (Tensor)")
|
||||
|
@ -455,6 +456,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)")
|
||||
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
|
||||
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
||||
emit("aten::len.Tensor : (Tensor) -> (int)")
|
||||
emit("aten::cpu : (Tensor) -> (Tensor)")
|
||||
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
||||
|
|
|
@ -232,3 +232,112 @@ def SelectIntModule_basic(module, tu: TestUtils):
|
|||
module.forward(torch.randint(10, (5,5)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
# For aten.slice_scatter op, The arguments are: SliceScatter(input, src, dim=0, start=None, end=None, step=1).
|
||||
# For aten.select_scatter op, The arguments are: SelectScatter(input, src, dim=0, index).
|
||||
class SliceScatterModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1)
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceScatterModule())
|
||||
def SliceScatterModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(6, 1))
|
||||
|
||||
class SliceScatterZeroDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 0, end = 1, step = 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceScatterZeroDimModule())
|
||||
def SliceScatterZeroDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(1, 8))
|
||||
|
||||
class SliceScatterStepVariationModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceScatterStepVariationModule())
|
||||
def SliceScatterStepVariationModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(6, 1))
|
||||
|
||||
class SliceScatterStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([6, 8], torch.float32, True),
|
||||
([6, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SliceScatterStaticModule())
|
||||
def SliceScatterStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 8), tu.rand(6, 1))
|
||||
|
||||
class SelectScatterModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.select_scatter(x, src, dim = 0, index = 0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SelectScatterModule())
|
||||
def SelectScattertModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(6, 8, 5), torch.rand(8, 5))
|
||||
|
||||
class SelectScatterStaticModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([6, 8, 5], torch.float32, True),
|
||||
([6, 5], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, src):
|
||||
return torch.ops.aten.select_scatter(x, src, dim = 1, index = 0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SelectScatterStaticModule())
|
||||
def SelectScattertStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(6, 8, 5), torch.rand(6, 5))
|
||||
|
|
|
@ -1087,3 +1087,21 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int
|
|||
%2 = torch.aten.repeat %arg0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
|
||||
return %2 : !torch.vtensor<[?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.select_scatter
|
||||
// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?],f32>, %[[SRC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK-NEXT: %[[START:.*]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[DIM:.*]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[STEP:.*]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[END:.*]] = torch.aten.add.int %[[START]], %[[STEP]]
|
||||
// CHECK-NEXT: %[[UNSQUEEZE_SRC:.*]] = torch.aten.unsqueeze %[[SRC]], %[[DIM]]
|
||||
// CHECK-NEXT: %[[SLICE_SCATTER:.*]] = torch.aten.slice_scatter %[[SELF]], %[[UNSQUEEZE_SRC]], %[[DIM]], %[[START]], %[[END]], %[[STEP]]
|
||||
// CHECK-NEXT: return %[[SLICE_SCATTER]]
|
||||
// CHECK-NEXT: }
|
||||
func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.select_scatter %arg0, %arg1, %int1, %int0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue