[TM_TENSOR] Add `aten.scatter.[src|value]` op

This commit adds support for the `aten.scatter.src` and `aten.scatter.value` ops.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/2183/head snapshot-20230529.853
Gaurav Shukla 2022-10-16 03:16:06 +05:30
parent b9d29dc055
commit 552887783a
9 changed files with 337 additions and 54 deletions
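
For context, the two overloads mirror `torch.Tensor.scatter`: the `src` variant copies elements from a tensor, while the `value` variant writes a single scalar at the indexed positions. A minimal PyTorch sketch of the semantics (illustrative only, not part of this patch):

import torch

inp = torch.zeros(3, 4)
index = torch.tensor([[0, 1, 2]])      # positions along `dim`
src = torch.tensor([[1.0, 2.0, 3.0]])

# aten::scatter.src: out[index[i][j]][j] = src[i][j]  (for dim == 0)
out_src = torch.ops.aten.scatter(inp, 0, index, src)

# aten::scatter.value: out[index[i][j]][j] = value    (for dim == 0)
out_val = torch.ops.aten.scatter(inp, 0, index, 2.5)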


@@ -263,6 +263,11 @@ TORCHDYNAMO_XFAIL_SET = {
# ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
"ScatterValueFloatModule_basic",
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
"ScatterValueIntModule_basic",
}
TORCHDYNAMO_CRASHING_SET = { TORCHDYNAMO_CRASHING_SET = {
@@ -1247,4 +1252,6 @@ LTC_XFAIL_SET = {
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpackDynamic_Module_basic",
"ChunkListUnpackUnevenDynamic_Module_basic",
"ScatterValueFloatModule_basic",
"ScatterValueIntModule_basic",
}
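
Note: the two new TorchDynamo xfails come from the test bodies (see the test suite changes below) calling `float(value)` / `int(value)` on 0-d tensors, which TorchDynamo could not trace at this snapshot. A minimal reproducer sketch (assuming stock `torch.compile` behavior of the era):

import torch

@torch.compile(fullgraph=True)
def f(x):
    return float(x)  # builtin float() applied to a TensorVariable

# f(torch.rand(())) raises torch._dynamo.exc.Unsupported:
#   call_function BuiltinVariable(float) [TensorVariable()] {}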


@@ -5073,6 +5073,108 @@ def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [
}];
}
def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenScatter_SrcOp : Torch_Op<"aten.scatter_.src", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::scatter_.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatter_SrcOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatter_SrcOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatterValueOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenScatter_ValueOp : Torch_Op<"aten.scatter_.value", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::scatter_.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatter_ValueOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatter_ValueOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -8899,58 +9001,6 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
}];
}
def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatterValueOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -299,6 +299,55 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
return SmallVector<Value>(sortOp.getResults());
}
namespace {
class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
TypeConverter *typeConverter = getTypeConverter();
Value self = adaptor.getSelf();
Value index = adaptor.getIndex();
Value src = adaptor.getSrc();
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
RankedTensorType indexType = index.getType().cast<RankedTensorType>();
RankedTensorType srcType = src.getType().cast<RankedTensorType>();
if (selfType.getRank() != indexType.getRank() ||
indexType.getRank() != srcType.getRank())
return rewriter.notifyMatchFailure(op,
"'self', 'index' and 'src' should all "
"have the same number of dimensions.");
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op,
"unimplemented: dim is not constant");
// Get the inputs reformatted for the TMScatterOp
auto [indices, updates] =
convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(rewriter, index,
src, dim);
Value scatterOp = createTMTensorScatterOp(
rewriter, loc, updates, indices, self,
/*uniqueIndices=*/false,
[&](OpBuilder &b, Location loc, Value updatesElement,
Value inputElement) {
b.create<TMTensor::YieldOp>(loc, updatesElement);
});
auto resultType = typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
return success();
}
};
} // namespace
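
For intuition: TMTensor's scatter takes a flat list of updates plus one full coordinate per update, so the helper `convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc` has to expand PyTorch's per-dimension `index` tensor into explicit coordinates. A rough NumPy model of that reformatting (function and variable names here are illustrative, not the helper's actual API):

import numpy as np

def torch_scatter_to_tm_scatter(index, src, dim):
    # Every element of `index` yields one update: its own position with
    # the `dim` axis replaced by the stored index value.
    coords, updates = [], []
    for pos in np.ndindex(index.shape):
        coord = list(pos)
        coord[dim] = index[pos]
        coords.append(coord)
        updates.append(src[pos])
    return np.array(coords), np.array(updates)

index = np.array([[0, 2], [1, 0]])
src = np.array([[10., 20.], [30., 40.]])
coords, updates = torch_scatter_to_tm_scatter(index, src, dim=0)
# coords  -> [[0, 0], [2, 1], [1, 0], [0, 1]]
# updates -> [10., 20., 30., 40.]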
namespace {
// aten::bincount op counts the frequency of each value in a 1-d input tensor of
// non-negative ints.
@@ -1606,6 +1655,9 @@ public:
patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
context);
target.addIllegalOp<AtenScatterSrcOp>();
patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();


@@ -7403,6 +7403,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.scatter_reduce.two\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scatter.src\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@@ -8371,6 +8377,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.src\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"


@@ -4399,6 +4399,50 @@ public:
};
} // namespace
namespace {
// Decompose `aten.scatter.value` op into `aten.scatter.src` op.
class DecomposeAtenScatterValueOp
: public OpRewritePattern<AtenScatterValueOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenScatterValueOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
MLIRContext *context = op.getContext();
Value self = op.getSelf();
Value index = op.getIndex();
std::optional<unsigned> maybeIndexRank = getTensorRank(index);
if (!maybeIndexRank) {
return rewriter.notifyMatchFailure(
op, "expected index tensor to have a rank");
}
unsigned indexRank = *maybeIndexRank;
SmallVector<Value> sizes;
for (int64_t i = 0; i < indexRank; ++i) {
Value dim =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, index, /*dim=*/dim));
}
Value sizeList = rewriter.create<PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), sizes);
auto selfType = self.getType().cast<BaseTensorType>();
auto indexType = index.getType().cast<BaseTensorType>();
BaseTensorType srcType =
selfType
.getWithSizesAndDtype(indexType.getOptionalSizes(),
selfType.getOptionalDtype())
.cast<BaseTensorType>();
Value src =
createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self,
op.getDim(), index, src);
return success();
}
};
} // namespace
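
The rewrite relies on the identity that scattering a scalar equals scattering a tensor of that scalar shaped like `index`. In plain PyTorch terms (a sketch of the idea, not the pattern itself):

import torch

def scatter_value_via_src(inp, dim, index, value):
    # Materialize the scalar as an `index`-shaped tensor with the dtype
    # of `inp`, then reuse the tensor-src overload.
    src = torch.full(index.shape, value, dtype=inp.dtype)
    return torch.ops.aten.scatter(inp, dim, index, src)

inp = torch.zeros(3, 4)
index = torch.tensor([[0, 1, 2]])
assert torch.equal(scatter_value_via_src(inp, 0, index, 2.5),
                   torch.ops.aten.scatter(inp, 0, index, 2.5))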
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -4563,6 +4607,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;


@@ -479,6 +479,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenCrossEntropyLossOp>();
target.addIllegalOp<AtenVarMeanDimOp>();
target.addIllegalOp<AtenTopkOp>();
target.addIllegalOp<AtenScatterValueOp>();
for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context));


@@ -886,6 +886,12 @@ def aten〇select_scatter〡shape(self: List[int], src: List[int], dim: int, ind
def aten〇scatter_reduce〇two〡shape(self: List[int], dim: int, index: List[int], src: List[int], reduce: str, include_self: bool = True) -> List[int]:
return self
def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
return self
def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]:
return self
def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
return upstream_shape_functions.index_select(self, dim, index)
@@ -1716,6 +1722,18 @@ def aten〇select_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dty
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇scatter〇src〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(
[Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), 1.0) for dtype in _SORTED_TORCH_TYPES])
def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇silu〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
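
Both the new shape and dtype functions encode the invariant that scatter never changes the metadata of `self`; a quick plain-PyTorch illustration of that invariant (not part of the patch):

import torch

inp = torch.zeros(4, 5, dtype=torch.float32)
index = torch.zeros(2, 3, dtype=torch.int64)
out = torch.ops.aten.scatter(inp, 0, index, 1.0)
assert out.shape == inp.shape and out.dtype == inp.dtype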


@@ -410,6 +410,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
)
emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
@@ -559,8 +561,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")


@@ -821,6 +821,102 @@ def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils):
# ==============================================================================
class ScatterSrcStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([10, 8, 6], torch.float32, True),
([2, 4, 3], torch.int64, True),
([5, 8, 6], torch.float32, True),
])
def forward(self, input, index, src):
return torch.ops.aten.scatter(input, 0, index, src)
@register_test_case(
module_factory=lambda: ScatterSrcStaticModule())
def ScatterSrcStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
tu.rand(5, 8, 6))
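
The static shapes above are deliberately non-uniform: scatter only requires `index.size(d) <= src.size(d)` for every `d`, and `index.size(d) <= self.size(d)` for every `d != dim`. The same shapes in plain PyTorch (illustrative):

import torch

# index [2, 4, 3] fits inside src [5, 8, 6] in every dim and inside
# self [10, 8, 6] in the non-scatter dims, so this is legal:
out = torch.zeros(10, 8, 6).scatter(0, torch.randint(4, (2, 4, 3)),
                                    torch.rand(5, 8, 6))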
# ==============================================================================
class ScatterSrcModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.int64, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, input, index, src):
return torch.ops.aten.scatter(input, 1, index, src)
@register_test_case(
module_factory=lambda: ScatterSrcModule())
def ScatterSrcModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
tu.rand(3, 4, 3))
# ==============================================================================
class ScatterValueFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.int64, True),
([], torch.float64, True),
])
def forward(self, input, index, value):
return torch.ops.aten.scatter(input, 2, index, float(value))
@register_test_case(
module_factory=lambda: ScatterValueFloatModule())
def ScatterValueFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
tu.rand().double())
# ==============================================================================
class ScatterValueIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.int64, True),
([], torch.int64, True),
])
def forward(self, input, index, value):
return torch.ops.aten.scatter(input, 0, index, int(value))
@register_test_case(
module_factory=lambda: ScatterValueIntModule())
def ScatterValueIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
tu.randint(high=10))
# ==============================================================================
class ScatterReduceFloatModule(torch.nn.Module):
include_self: bool
reduce_type: str