Add aten.scatter.value Op ODS

Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>
pull/2063/head snapshot-20230425.819
rahul shrivastava 2023-04-21 00:03:21 -07:00 committed by rahuls-cerebras
parent 0831424f52
commit e3d876af42
2 changed files with 27 additions and 0 deletions

View File

@ -8659,6 +8659,32 @@ def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
}];
}
def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenScatterValueOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -549,6 +549,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")