Add ODS for aten.scatter.src

Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>
pull/2060/head
rahul shrivastava 2023-04-21 08:45:27 -07:00 committed by rahuls-cerebras
parent a3a62a9951
commit 85916dab33
2 changed files with 27 additions and 0 deletions

View File

@ -8576,6 +8576,32 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
}]; }];
} }
def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [ def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -546,6 +546,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)") emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)") emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)") emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")