From 85916dab331bd854e4c181b3aee1e8129daf8f6d Mon Sep 17 00:00:00 2001 From: rahul shrivastava Date: Fri, 21 Apr 2023 08:45:27 -0700 Subject: [PATCH] Add ODS for aten.scatter.src Signed-off-by: rahul shrivastava --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 +++++++++++++++++++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + 2 files changed, 27 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4f794b98c..cd583f1fe 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8576,6 +8576,32 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [ }]; } +def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index, + AnyTorchTensorType:$src + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenScatterSrcOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index d0804de02..9e834f62f 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -546,6 +546,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::view_copy : (Tensor, int[]) -> (Tensor)") emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)") emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)") + emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")