From 552887783a58376842d3b2ca64f97f8dcd84a347 Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Sun, 16 Oct 2022 03:16:06 +0530
Subject: [PATCH] [TM_TENSOR] Add `aten.scatter.[src|value]` op

This commit adds support for the `aten.scatter.src` and `aten.scatter.value` ops.

Signed-off-by: Gaurav Shukla
---
 e2e_testing/xfail_sets.py                     |   7 +
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 154 ++++++++++++------
 .../TorchToTMTensor/TorchToTMTensor.cpp       |  52 ++++++
 .../Transforms/AbstractInterpLibrary.cpp      |  14 ++
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  45 +++++
 .../Transforms/LowerToBackendContract.cpp     |   1 +
 .../build_tools/abstract_interp_lib_gen.py    |  18 ++
 .../jit_ir/build_tools/torch_ods_gen.py       |   4 +-
 .../torch_mlir_e2e_test/test_suite/scatter.py |  96 +++++++++++
 9 files changed, 337 insertions(+), 54 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 8e120b658..b780aad79 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -263,6 +263,11 @@ TORCHDYNAMO_XFAIL_SET = {
     # ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
     "UnbindIntListUnpack_Module_basic",
     "UnbindIntGetItem_Module_basic",
+
+    # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
+    "ScatterValueFloatModule_basic",
+    # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
+    "ScatterValueIntModule_basic",
 }
 
 TORCHDYNAMO_CRASHING_SET = {
@@ -1247,4 +1252,6 @@ LTC_XFAIL_SET = {
     "ChunkListUnpackUneven_Module_basic",
     "ChunkListUnpackDynamic_Module_basic",
     "ChunkListUnpackUnevenDynamic_Module_basic",
+    "ScatterValueFloatModule_basic",
+    "ScatterValueIntModule_basic",
 }
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 3cb4fb642..2c5b6b659 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5073,6 +5073,108 @@ def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [
   }];
 }
 
+def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    AnyTorchTensorType:$index,
+    AnyTorchTensorType:$src
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
+def Torch_AtenScatter_SrcOp : Torch_Op<"aten.scatter_.src", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::scatter_.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    AnyTorchTensorType:$index,
+    AnyTorchTensorType:$src
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenScatter_SrcOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenScatter_SrcOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
+def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    AnyTorchTensorType:$index,
+    AnyTorchScalarType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenScatterValueOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
+def Torch_AtenScatter_ValueOp : Torch_Op<"aten.scatter_.value", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::scatter_.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    AnyTorchTensorType:$index,
+    AnyTorchScalarType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenScatter_ValueOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenScatter_ValueOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -8899,58 +9001,6 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
   }];
 }
 
-def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    Torch_IntType:$dim,
-    AnyTorchTensorType:$index,
-    AnyTorchTensorType:$src
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 4, 1);
-    }
-    void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 4, 1);
-    }
-  }];
-}
-
-def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    Torch_IntType:$dim,
-    AnyTorchTensorType:$index,
-    AnyTorchScalarType:$value
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 4, 1);
-    }
-    void AtenScatterValueOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 4, 1);
-    }
-  }];
-}
-
 def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index 7bb565861..a34e2db83 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -299,6 +299,55 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
   return SmallVector<Value>(sortOp.getResults());
 }
 
+namespace {
+class ConvertAtenScatterSrcOp : public OpConversionPattern<AtenScatterSrcOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = getTypeConverter();
+    Value self = adaptor.getSelf();
+    Value index = adaptor.getIndex();
+    Value src = adaptor.getSrc();
+
+    RankedTensorType selfType = self.getType().cast<RankedTensorType>();
+    RankedTensorType indexType = index.getType().cast<RankedTensorType>();
+    RankedTensorType srcType = src.getType().cast<RankedTensorType>();
+    if (selfType.getRank() != indexType.getRank() ||
+        indexType.getRank() != srcType.getRank())
+      return rewriter.notifyMatchFailure(op,
+                                         "'self', 'index' and 'src' should all "
+                                         "have the same number of dimensions.");
+
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(op,
+                                         "unimplemented: dim is not constant");
+
+    // Get the inputs reformatted for the TMScatterOp
+    auto [indices, updates] =
+        convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(rewriter, index,
+                                                             src, dim);
+    Value scatterOp = createTMTensorScatterOp(
+        rewriter, loc, updates, indices, self,
+        /*uniqueIndices=*/false,
+        [&](OpBuilder &b, Location loc, Value updatesElement,
+            Value inputElement) {
+          b.create<TMTensor::YieldOp>(loc, updatesElement);
+        });
+
+    auto resultType = typeConverter->convertType(op->getResult(0).getType())
+                          .cast<RankedTensorType>();
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, scatterOp);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // aten::bincount op counts the frequency of each value in a 1-d input tensor of
 // non-negative ints.
@@ -1606,6 +1655,9 @@ public:
     patterns.add(typeConverter, context);
 
+    target.addIllegalOp<AtenScatterSrcOp>();
+    patterns.add<ConvertAtenScatterSrcOp>(typeConverter, context);
+
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 5fd0b44fc..ac7afbb82 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7403,6 +7403,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.scatter_reduce.two\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.scatter.src\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.float) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -8371,6 +8377,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.scatter.src\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.scatter.value\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.union<float, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9e03056d1..5dce34b9e 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4399,6 +4399,50 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.scatter.value` op into `aten.scatter.src` op.
+class DecomposeAtenScatterValueOp
+    : public OpRewritePattern<AtenScatterValueOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenScatterValueOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+    Value self = op.getSelf();
+    Value index = op.getIndex();
+    std::optional<unsigned> maybeIndexRank = getTensorRank(index);
+    if (!maybeIndexRank) {
+      return rewriter.notifyMatchFailure(
+          op, "expected index tensor to have a rank");
+    }
+    unsigned indexRank = *maybeIndexRank;
+    SmallVector<Value> sizes;
+    for (int64_t i = 0; i < indexRank; ++i) {
+      Value dim =
+          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, index, /*dim=*/dim));
+    }
+    Value sizeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), sizes);
+
+    auto selfType = self.getType().cast<BaseTensorType>();
+    auto indexType = index.getType().cast<BaseTensorType>();
+    BaseTensorType srcType =
+        selfType
+            .getWithSizesAndDtype(indexType.getOptionalSizes(),
+                                  selfType.getOptionalDtype())
+            .cast<BaseTensorType>();
+    Value src =
+        createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList);
+    rewriter.replaceOpWithNewOp<AtenScatterSrcOp>(op, op.getType(), self,
+                                                  op.getDim(), index, src);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -4563,6 +4607,7 @@ public:
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
 
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index f7cf3c95a..6fd5c3379 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -479,6 +479,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenScatterValueOp>();
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
         OperationName(kTorchOpPrefix + opName.first().str(), context));
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index b2d251365..bb0be8940 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -886,6 +886,12 @@ def aten〇select_scatter〡shape(self: List[int], src: List[int], dim: int, ind
 def aten〇scatter_reduce〇two〡shape(self: List[int], dim: int, index: List[int], src: List[int], reduce: str, include_self: bool = True) -> List[int]:
     return self
 
+def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
+    return self
+
+def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]:
+    return self
+
 def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
     return upstream_shape_functions.index_select(self, dim, index)
 
@@ -1716,6 +1722,18 @@ def aten〇select_scatter〡dtype(self_rank_dtype: Tuple[int, int], src_rank_dty
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(
+    [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
+def aten〇scatter〇src〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
+@check_dtype_function(
+    [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), 1.0) for dtype in _SORTED_TORCH_TYPES])
+def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], value: Union[int, float]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇silu〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 368f34777..d848ced3a 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -410,6 +410,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit(
         "aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
     )
+    emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
+    emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
     emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
     emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
     emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
@@ -559,8 +561,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
     emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
     emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
-    emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
-    emit("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
     emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
     emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
     emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/scatter.py b/python/torch_mlir_e2e_test/test_suite/scatter.py
index 784ea2ac8..bc1df1832 100644
--- a/python/torch_mlir_e2e_test/test_suite/scatter.py
+++ b/python/torch_mlir_e2e_test/test_suite/scatter.py
@@ -821,6 +821,102 @@ def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ScatterSrcStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([10, 8, 6], torch.float32, True),
+        ([2, 4, 3], torch.int64, True),
+        ([5, 8, 6], torch.float32, True),
+    ])
+    def forward(self, input, index, src):
+        return torch.ops.aten.scatter(input, 0, index, src)
+
+
+@register_test_case(
+    module_factory=lambda: ScatterSrcStaticModule())
+def ScatterSrcStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
+                   tu.rand(5, 8, 6))
+
+# ==============================================================================
+
+class ScatterSrcModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.int64, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, input, index, src):
+        return torch.ops.aten.scatter(input, 1, index, src)
+
+
+@register_test_case(
+    module_factory=lambda: ScatterSrcModule())
+def ScatterSrcModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
+                   tu.rand(3, 4, 3))
+
+# ==============================================================================
+
+class ScatterValueFloatModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.int64, True),
+        ([], torch.float64, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten.scatter(input, 2, index, float(value))
+
+
+@register_test_case(
+    module_factory=lambda: ScatterValueFloatModule())
+def ScatterValueFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
+                   tu.rand().double())
+
+# ==============================================================================
+
+class ScatterValueIntModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.int64, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, input, index, value):
+        return torch.ops.aten.scatter(input, 0, index, int(value))
+
+
+@register_test_case(
+    module_factory=lambda: ScatterValueIntModule())
+def ScatterValueIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4),
+                   tu.randint(high=10))
+
+# ==============================================================================
+
 class ScatterReduceFloatModule(torch.nn.Module):
     include_self: bool
     reduce_type: str