mirror of https://github.com/llvm/torch-mlir
Add ODS for aten.scatter.src
Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>pull/2060/head
parent
a3a62a9951
commit
85916dab33
|
@ -8576,6 +8576,32 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$dim,
|
||||||
|
AnyTorchTensorType:$index,
|
||||||
|
AnyTorchTensorType:$src
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -546,6 +546,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
||||||
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||||
|
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||||
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
||||||
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue