[Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op (#2358)

* [Torch Dialect] emit aten.masked_scatter and aten.masked_scatter_ op
pull/2372/head snapshot-20230803.919
Jiawei Wu 2023-08-03 16:21:14 +08:00 committed by GitHub
parent 48f4e8f673
commit 6db92d1b14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 66 additions and 0 deletions

View File

@ -5250,6 +5250,55 @@ def Torch_AtenScatter_ValueOp : Torch_Op<"aten.scatter_.value", [
}];
}
def Torch_AtenMaskedScatterOp : Torch_Op<"aten.masked_scatter", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$mask,
AnyTorchTensorType:$source
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedScatterOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedScatterOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::masked_scatter_ : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$mask,
AnyTorchTensorType:$source
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedScatter_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedScatter_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -7646,6 +7646,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.slice_scatter\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.int) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.masked_scatter\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.select.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -8672,6 +8675,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.masked_scatter\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.silu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -939,6 +939,9 @@ def atennarrowTensor〡shape(self: List[int], dim: int, start: List[int],
def atenslice_scatter〡shape(self: List[int], src: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
return self
def atenmasked_scatter〡shape(self: List[int], mask: List[int], source: List[int]) -> List[int]:
return self
def atenselectint〡shape(self: List[int], dim: int, index: int) -> List[int]:
return upstream_shape_functions.select(self, dim, index)
@ -1825,6 +1828,12 @@ def atenscattervalue〡dtype(self_rank_dtype: Tuple[int, int], dim: int, i
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(
[Invocation(TensorOfShape(3, dtype=dtype), TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def atenmasked_scatter〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atensilu〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -416,6 +416,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
)
emit_with_mutating_variants("aten::scatter.src : (Tensor, int, Tensor, Tensor) -> (Tensor)")
emit_with_mutating_variants("aten::scatter.value : (Tensor, int, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")