mirror of https://github.com/llvm/torch-mlir
Add aten.scatter.value Op ODS
Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>pull/2063/head snapshot-20230425.819
parent
0831424f52
commit
e3d876af42
|
@ -8659,6 +8659,32 @@ def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenScatterValueOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -549,6 +549,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
||||
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue