From 2c2009a13ded11b4725972d3450e1739ce1ce1c4 Mon Sep 17 00:00:00 2001 From: Zachary Cetinic Date: Fri, 3 Feb 2023 12:54:28 -0500 Subject: [PATCH] Add in-place variant of torch.scatter_add (#1836) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++++++++++++++++++ .../jit_ir/build_tools/torch_ods_gen.py | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 9aa76ff35..46e5a50e4 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7613,6 +7613,31 @@ def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [ }]; } +def Torch_AtenScatterAdd_Op : Torch_Op<"aten.scatter_add_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::scatter_add_ : (Tensor, int, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index, + AnyTorchTensorType:$src + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenScatterAdd_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenScatterAdd_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index addb4aba0..3fc94387f 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -509,7 +509,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") - emit("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") + emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit("aten::IntImplicit : (Tensor) -> (int)") emit("aten::FloatImplicit : (Tensor) -> (float)") emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")