[TORCH][MLIR] Add lowering of `aten.slice_scatter` and `aten.select_scatter` ops.

This commit adds:
1. Lowering of the `aten.slice_scatter` op into the `tensor.insert_slice` op.
2. Decomposition of the `aten.select_scatter` op into the `aten.slice_scatter` op.

Signed-off-by: Prateek Gupta <gprateek93@gmail.com>
Prateek Gupta 2022-05-10 13:15:59 +00:00
parent a08ff0d7f2
commit 2d75654b2c
9 changed files with 766 additions and 432 deletions
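For reference, a minimal eager-mode sketch of what the two ops compute (the shapes below are illustrative, not taken from the patch):

import torch

base = torch.zeros(6, 8)
src = torch.ones(6, 1)

# slice_scatter embeds `src` into the slice base[:, 0:1] and returns a
# new tensor; the result shape always equals the input shape.
out = torch.slice_scatter(base, src, dim=1, start=0, end=1, step=1)
assert out.shape == base.shape

# select_scatter embeds a rank-reduced `src` at a single index along
# `dim`; this commit decomposes it into unsqueeze + slice_scatter.
row = torch.ones(8)
out2 = torch.select_scatter(base, row, dim=0, index=2)
assert torch.equal(out2[2], row)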


@@ -5292,6 +5292,32 @@ def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [
}];
}
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$src,
Torch_IntType:$dim,
Torch_IntType:$index
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSelectScatterOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenSelectScatterOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -5747,6 +5773,34 @@ def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
}];
}
def Torch_AtenSliceScatterOp : Torch_Op<"aten.slice_scatter", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$src,
Torch_IntType:$dim,
AnyTorchOptionalIntType:$start,
AnyTorchOptionalIntType:$end,
Torch_IntType:$step
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSliceScatterOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenSliceScatterOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenLenTensorOp : Torch_Op<"aten.len.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -7,6 +7,10 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "../PassDetail.h"
@@ -29,6 +33,94 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
static Value toPositiveValidDim(ConversionPatternRewriter &rewriter,
Location loc, Value torchType,
Value builtinType, Value valueForNone,
Value dimSize) {
if (torchType.getType().isa<Torch::NoneType>())
return valueForNone;
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
Value positiveDim =
toPositiveDimDynamic(rewriter, loc, builtinType, dimSizeAsInt);
// startOrEnd < 0 ? 0 : startOrEnd
Value cst0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, positiveDim, cst0);
Value startOrEndAtLeastZero =
rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
}
template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
SmallVector<Value> &resultShape,
SmallVector<Value> &offsets,
SmallVector<Value> &strides) {
Location loc = op.getLoc();
auto input = adaptor.self();
RankedTensorType inputType =
input.getType().template cast<RankedTensorType>();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant");
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
Value dimSize = inputShape[dim];
Value torchTypeStart = op.start();
Value torchTypeEnd = op.end();
Value builtinTypeStart = adaptor.start();
Value builtinTypeEnd = adaptor.end();
if (torchTypeStart.getType().isa<OptionalType>() ||
torchTypeEnd.getType().isa<OptionalType>())
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
int64_t step;
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
if (!op.step().getType().template isa<Torch::NoneType>())
return op->emitError("unimplemented: step is not constant");
step = 1;
}
Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
builtinTypeStart, zero, dimSize);
Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
dimSize, dimSize);
// end >= start ? end : start
Value endSgeStart = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, end, start);
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
resultShape = getTensorSizes(rewriter, loc, input);
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
resultShape[dim] = resultSize;
strides.resize(inputType.getRank(), one);
offsets.resize(inputType.getRank(), zero);
offsets[dim] = start;
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
return success();
}
namespace {
class ConvertAtenFlattenUsingIntsOp
: public OpConversionPattern<AtenFlattenUsingIntsOp> {
@@ -742,77 +834,19 @@ public:
TypeConverter *typeConverter = getTypeConverter();
auto input = adaptor.self();
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
int64_t dim;
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
return op->emitError("unimplemented: dim is not constant");
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
Value dimSize = inputShape[dim];
auto adjustStartOrEnd = [&](Value startOrEndTorchType,
Value startOrEndBuiltin, Value valueForNone) {
if (startOrEndTorchType.getType().isa<Torch::NoneType>())
return valueForNone;
auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
Value startOrEndToPositive =
toPositiveDimDynamic(rewriter, loc, startOrEndBuiltin, dimSizeAsInt);
// startOrEnd < 0 ? 0 : startOrEnd
Value cst0 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
Value predDimSltZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, startOrEndToPositive, cst0);
Value startOrEndAtLeastZero = rewriter.create<arith::SelectOp>(
loc, predDimSltZero, cst0, startOrEndToPositive);
// startOrEnd > dimSizeAsInt ? dimSizeAsInt : startOrEnd
Value startOrEndSgtDimSize = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, startOrEndAtLeastZero, dimSizeAsInt);
Value startOrEndBoundedByDimSize = rewriter.create<arith::SelectOp>(
loc, startOrEndSgtDimSize, dimSizeAsInt, startOrEndAtLeastZero);
return castIntToIndex(rewriter, loc, startOrEndBoundedByDimSize);
};
if (op.start().getType().isa<OptionalType>() ||
op.end().getType().isa<OptionalType>())
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
Value start = adjustStartOrEnd(op.start(), adaptor.start(), zero);
Value end = adjustStartOrEnd(op.end(), adaptor.end(), dimSize);
// end >= start ? end : start
Value endSgeStart = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, end, start);
end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
int64_t step;
if (!matchPattern(op.step(), m_TorchConstantInt(&step))) {
if (!op.step().getType().isa<Torch::NoneType>())
return op->emitError("unimplemented: step is not constant");
step = 1;
SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
if (failed(prepareArgumentsForSlicingOp<AtenSliceTensorOp,
AtenSliceTensorOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
return failure();
}
// Slice logic: resultSize = floordiv(end - start + step - 1, step)
Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);
Value len = rewriter.create<arith::SubIOp>(loc, end, start);
Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
resultSize =
rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
SmallVector<Value> resultShape = getTensorSizes(rewriter, loc, input);
resultShape[dim] = resultSize;
SmallVector<Value> offsets(inputType.getRank(), zero);
SmallVector<Value> strides(inputType.getRank(), one);
offsets[dim] = start;
strides[dim] = rewriter.create<arith::MulIOp>(loc, strides[dim], stepIndex);
Value result = rewriter.create<tensor::ExtractSliceOp>(
loc, input, offsets, resultShape, strides);
@@ -1019,6 +1053,55 @@ public:
};
} // namespace
namespace {
class ConvertAtenSliceScatterOp
: public OpConversionPattern<AtenSliceScatterOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
TypeConverter *typeConverter = getTypeConverter();
auto input = adaptor.self();
RankedTensorType resultType =
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
SmallVector<Value> resultShape;
SmallVector<Value> offsets;
SmallVector<Value> strides;
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
AtenSliceScatterOpAdaptor>(
op, adaptor, rewriter, resultShape, offsets, strides))) {
return failure();
}
Value src = adaptor.src();
auto srcType = src.getType().cast<RankedTensorType>();
int64_t srcRank = srcType.getRank();
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
auto abstractSrcType =
RankedTensorType::get(srcAbstractSizes, srcType.getElementType());
Value abstractSrc =
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
Value result = rewriter.create<tensor::InsertSliceOp>(
loc, abstractSrc, input, offsets, resultShape, strides);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -1047,4 +1130,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
target.addIllegalOp<ValsemVariantAtenCopyOp>();
patterns.add<ConvertValsemVariantAtenCopyOp>(typeConverter, context);
target.addIllegalOp<AtenSliceScatterOp>();
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
}
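The `offsets`/`resultShape`/`strides` computed by `toPositiveValidDim` and `prepareArgumentsForSlicingOp` above parameterize the `tensor.extract_slice`/`tensor.insert_slice` ops. The same index normalization and size computation, restated as plain Python (a sketch; these helper names are illustrative, not part of the patch):

def normalize(start_or_end, dim_size, value_for_none):
    # None selects the default bound; negative indices count from the
    # end; the result is clamped into [0, dim_size].
    if start_or_end is None:
        return value_for_none
    if start_or_end < 0:
        start_or_end += dim_size
    return max(0, min(start_or_end, dim_size))

def slice_result_size(dim_size, start, end, step=1):
    start = normalize(start, dim_size, 0)
    end = normalize(end, dim_size, dim_size)
    end = max(end, start)  # end >= start ? end : start
    # resultSize = floordiv(end - start + step - 1, step)
    return (end - start + step - 1) // step

# Worked example matching SliceScatterStepVariationModule below:
# dim_size=8, start=0, end=1, step=2 -> (1 - 0 + 2 - 1) // 2 = 1.
assert slice_result_size(8, 0, 1, step=2) == 1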


@@ -16,7 +16,9 @@
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringExtras.h"
#include <cstdint>
using namespace mlir;
using namespace mlir::torch;
@@ -2120,6 +2122,55 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
};
} // namespace
namespace {
// Decompose the `aten.select_scatter` operation into `aten.slice_scatter` op.
class DecomposeAtenSelectScatterOp
: public OpRewritePattern<AtenSelectScatterOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSelectScatterOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value start = op.index();
Value dim = op.dim();
Value self = op.self();
Value src = op.src();
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value startPlusOne =
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
SmallVector<int64_t> sizes;
if (!srcTensorType.hasSizes())
return rewriter.notifyMatchFailure(op, "src tensor must have size");
ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
// `src` has a reduced rank. Hence add 1.
int64_t srcRank = srcShape.size() + 1;
int64_t dimInt = 0;
if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, srcRank);
if (!isValidDim(dimInt, srcRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
sizes.append(srcShape.begin(), srcShape.end());
sizes.insert(sizes.begin() + dimInt, 1);
} else {
sizes.resize(srcShape.size() + 1, kUnknownSize);
}
Type srcType = srcTensorType.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
srcTensorType.getDtype());
src = rewriter.create<AtenUnsqueezeOp>(loc, srcType, src, dim);
rewriter.replaceOpWithNewOp<AtenSliceScatterOp>(
op, op.self().getType(), self, src, dim, start, startPlusOne,
/*step=*/one);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -2271,6 +2322,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenFloorDivideOp>();
patterns.add<DecomposeAtenNumpyTOp>(context);
target.addIllegalOp<AtenNumpyTOp>();
patterns.add<DecomposeAtenSelectScatterOp>(context);
target.addIllegalOp<AtenSelectScatterOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {

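In eager-mode terms, the decomposition performed by `DecomposeAtenSelectScatterOp` above amounts to the following (a sketch; `select_scatter_decomposed` is an illustrative name, only the `torch.*` calls are real APIs):

import torch

def select_scatter_decomposed(self, src, dim, index):
    # aten.select_scatter(self, src, dim, index)
    #   -> aten.slice_scatter(self, unsqueeze(src, dim), dim,
    #                         start=index, end=index + 1, step=1)
    return torch.slice_scatter(self, src.unsqueeze(dim), dim,
                               index, index + 1, 1)

x, src = torch.zeros(6, 8), torch.ones(6)
expected = torch.select_scatter(x, src, dim=1, index=0)
assert torch.equal(select_scatter_decomposed(x, src, 1, 0), expected)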

@@ -645,10 +645,11 @@ ChangeResult TypeAnalyzer::visitOperation(
AtenSqueezeDimOp, AtenUnsqueezeOp, AtenViewOp, Aten_UnsafeViewOp,
AtenReshapeOp, Aten_ReshapeAliasOp, AtenResize_Op, AtenTransposeIntOp,
AtenTOp, AtenPermuteOp, AtenIndexSelectOp, AtenSelectIntOp,
AtenSliceTensorOp, AtenGatherOp, AtenExpandOp, AtenExpandAsOp,
AtenBroadcastToOp, AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp,
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
AtenSelectScatterOp, AtenSliceTensorOp, AtenSliceScatterOp,
AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp,
AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());

File diff suppressed because it is too large.


@@ -904,9 +904,15 @@ def aten〇batch_norm(input: List[int], weight: Optional[List[int]], bias: Optio
def aten〇slice〇Tensor(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
    return upstream_shape_functions.slice(self, dim, start, end, step)
def aten〇slice_scatter(self: List[int], src: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
    return self
def aten〇select〇int(self: List[int], dim: int, index: int) -> List[int]:
    return upstream_shape_functions.select(self, dim, index)
def aten〇select_scatter(self: List[int], src: List[int], dim: int, index: int) -> List[int]:
    return self
def aten〇index_select(self: List[int], dim: int, index: List[int]) -> List[int]:
    return upstream_shape_functions.index_select(self, dim, index)
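Both scatter variants preserve the input shape, which is why the two new shape functions simply return `self`; a quick eager-mode check (shapes are illustrative):

import torch

x = torch.zeros(6, 8, 5)
assert torch.slice_scatter(x, torch.ones(6, 1, 5), dim=1, start=0, end=1).shape == x.shape
assert torch.select_scatter(x, torch.ones(8, 5), dim=0, index=3).shape == x.shape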


@@ -437,6 +437,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
emit("aten::select.int : (Tensor, int, int) -> (Tensor)")
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
emit("aten::stack : (Tensor[], int) -> (Tensor)")
emit("aten::sum : (Tensor, int?) -> (Tensor)")
@@ -455,6 +456,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)")
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::len.Tensor : (Tensor) -> (int)")
emit("aten::cpu : (Tensor) -> (Tensor)")
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")


@@ -232,3 +232,112 @@ def SelectIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (5, 5)))
# ==============================================================================
# For the aten.slice_scatter op, the arguments are: slice_scatter(input, src, dim=0, start=None, end=None, step=1).
# For the aten.select_scatter op, the arguments are: select_scatter(input, src, dim, index).
class SliceScatterModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=1)
@register_test_case(module_factory=lambda: SliceScatterModule())
def SliceScatterModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(6, 1))
class SliceScatterZeroDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.slice_scatter(x, src, dim=0, start=0, end=1, step=1)
@register_test_case(module_factory=lambda: SliceScatterZeroDimModule())
def SliceScatterZeroDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(1, 8))
class SliceScatterStepVariationModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=2)
@register_test_case(module_factory=lambda: SliceScatterStepVariationModule())
def SliceScatterStepVariationModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(6, 1))
class SliceScatterStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([6, 8], torch.float32, True),
([6, 1], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.slice_scatter(x, src, dim=1, start=0, end=1, step=1)
@register_test_case(module_factory=lambda: SliceScatterStaticModule())
def SliceScatterStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(6, 1))
class SelectScatterModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.select_scatter(x, src, dim=0, index=0)
@register_test_case(module_factory=lambda: SelectScatterModule())
def SelectScatterModule_basic(module, tu: TestUtils):
module.forward(torch.rand(6, 8, 5), torch.rand(8, 5))
class SelectScatterStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([6, 8, 5], torch.float32, True),
([6, 5], torch.float32, True),
])
def forward(self, x, src):
return torch.ops.aten.select_scatter(x, src, dim=1, index=0)
@register_test_case(module_factory=lambda: SelectScatterStaticModule())
def SelectScatterStaticModule_basic(module, tu: TestUtils):
module.forward(torch.rand(6, 8, 5), torch.rand(6, 5))


@@ -1087,3 +1087,21 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int
%2 = torch.aten.repeat %arg0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?],f32>
return %2 : !torch.vtensor<[?,?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.select_scatter
// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?],f32>, %[[SRC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-NEXT: %[[START:.*]] = torch.constant.int 0
// CHECK-NEXT: %[[DIM:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[STEP:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[END:.*]] = torch.aten.add.int %[[START]], %[[STEP]]
// CHECK-NEXT: %[[UNSQUEEZE_SRC:.*]] = torch.aten.unsqueeze %[[SRC]], %[[DIM]]
// CHECK-NEXT: %[[SLICE_SCATTER:.*]] = torch.aten.slice_scatter %[[SELF]], %[[UNSQUEEZE_SRC]], %[[DIM]], %[[START]], %[[END]], %[[STEP]]
// CHECK-NEXT: return %[[SLICE_SCATTER]]
// CHECK-NEXT: }
func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%0 = torch.aten.select_scatter %arg0, %arg1, %int1, %int0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}