diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index cf2653d0f..b1be759d8 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5688,6 +5688,32 @@ def Torch_AtenGatherOp : Torch_Op<"aten.gather", [ }]; } +def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::scatter_add : (Tensor self, int dim, Tensor index, Tensor src) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index, + AnyTorchTensorType:$src + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenScatterAddOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenScatterAddOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index aa3a4ede7..204210271 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -446,6 +446,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") + emit("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit("aten::IntImplicit : (Tensor) -> (int)") emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)") emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)