mirror of https://github.com/llvm/torch-mlir
Add aten.scatter.value Op ODS
Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>pull/2063/head snapshot-20230425.819
parent
0831424f52
commit
e3d876af42
|
@ -8659,6 +8659,32 @@ def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenScatterValueOp : Torch_Op<"aten.scatter.value", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
AnyTorchTensorType:$index,
|
||||||
|
AnyTorchScalarType:$value
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenScatterValueOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void AtenScatterValueOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -549,6 +549,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
||||||
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||||
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
||||||
|
emit("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||||
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
||||||
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue