mirror of https://github.com/llvm/torch-mlir
Add ODS for aten.scatter.src
Signed-off-by: rahul shrivastava <rahul.shrivastava@cerebras.net>pull/2060/head
parent
a3a62a9951
commit
85916dab33
|
@ -8576,6 +8576,32 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenScatterSrcOp : Torch_Op<"aten.scatter.src", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
AnyTorchTensorType:$index,
|
||||
AnyTorchTensorType:$src
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenScatterSrcOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenScatterSrcOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -546,6 +546,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
|
||||
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue