[Torch Dialect] emit aten.nonzero, aten.nonzero_numpy, aten.nonzero_static ops

wjw.emit_nonzero
wujiawei.aml 2023-07-24 16:21:53 +08:00
parent 238c0501da
commit 8fa10911ea
4 changed files with 115 additions and 0 deletions

GeneratedTorchOps.td

@@ -6070,6 +6070,77 @@ def Torch_AtenCrossEntropyLossOp : Torch_Op<"aten.cross_entropy_loss", [
  }];
}
def Torch_AtenNonzeroOp : Torch_Op<"aten.nonzero", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::nonzero : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenNonzeroOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenNonzeroOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenNonzeroNumpyOp : Torch_Op<"aten.nonzero_numpy", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::nonzero_numpy : (Tensor) -> (Tensor[])`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchListOfTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenNonzeroNumpyOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenNonzeroNumpyOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenNonzeroStaticOp : Torch_Op<"aten.nonzero_static", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::nonzero_static : (Tensor, int, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    Torch_IntType:$size,
    Torch_IntType:$fill_value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenNonzeroStaticOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenNonzeroStaticOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
    AllowsTypeRefinement,
    HasValueSemantics,
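For orientation, a quick eager-mode sketch of what these three ops compute (torch.nonzero_static assumes a PyTorch recent enough to ship aten::nonzero_static, roughly 2.1+):

# Eager-mode semantics of the three emitted ops (sketch).
import torch

x = torch.tensor([[0, 1], [2, 0]])

# aten::nonzero -> one int64 index tensor of shape [num_nonzero, rank].
print(torch.nonzero(x))                 # tensor([[0, 1], [1, 0]])

# aten::nonzero_numpy -> rank-many 1-D index tensors, like numpy.nonzero;
# surfaced in Python via as_tuple=True.
print(torch.nonzero(x, as_tuple=True))  # (tensor([0, 1]), tensor([1, 0]))

# aten::nonzero_static -> leading dim fixed to `size`, padded with rows of
# `fill_value` (or truncated), so the result shape stays static.
print(torch.nonzero_static(x, size=4, fill_value=-1))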

AbstractInterpLibrary.cpp

@@ -7888,6 +7888,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" return %none : !torch.none\n" " return %none : !torch.none\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nonzero\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.masked_select\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nonzero_static\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.prim.ListConstruct %arg1, %0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linalg_vector_norm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.linalg_vector_norm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg4 : !torch.optional<int> to !torch.any\n" " %0 = torch.derefine %arg4 : !torch.optional<int> to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@@ -9366,6 +9382,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n" " }\n"
" return %4 : !torch.int\n" " return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.nonzero\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.nonzero_static\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

abstract_interp_lib_gen.py

@@ -1117,6 +1117,15 @@ def hacky_get_unknown_dimension_size():
def aten〇bincount〡shape(self: List[int], weights: Optional[List[int]] = None, minlength: int = 0) -> List[int]:
    return [hacky_get_unknown_dimension_size()]
def aten〇nonzero〡shape(self: List[int]) -> List[int]:
    return [hacky_get_unknown_dimension_size(), len(self)]

def aten〇masked_select〡shape(self: List[int], mask: List[int]) -> List[int]:
    return [hacky_get_unknown_dimension_size()]

def aten〇nonzero_static〡shape(self: List[int], size: int, fill_value: int = -1) -> List[int]:
    return [size, len(self)]
def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
    return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
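The new shape functions are simple: nonzero has a data-dependent row count (hence the unknown-size hack), while nonzero_static is fully static. A standalone sketch, with a stand-in for hacky_get_unknown_dimension_size:

# Standalone sketch of the shape logic above. UNKNOWN is a stand-in for the
# sentinel hacky_get_unknown_dimension_size() returns (the shape library
# cannot express an unknown dimension directly).
from typing import List

UNKNOWN = -1  # stand-in sentinel, not the library's actual value

def nonzero_shape(self: List[int]) -> List[int]:
    return [UNKNOWN, len(self)]  # [num_nonzero (data-dependent), rank]

def nonzero_static_shape(self: List[int], size: int, fill_value: int = -1) -> List[int]:
    return [size, len(self)]     # fully static: [size, rank]

assert nonzero_shape([2, 3]) == [UNKNOWN, 2]
assert nonzero_static_shape([2, 3], size=4) == [4, 2]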
@@ -2430,6 +2439,14 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype
        return torch.int64
    return torch.float64
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device=torch.device("cpu")))
def aten〇nonzero〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    return torch.int64

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=5, tensor_device=torch.device("cpu")))
def aten〇nonzero_static〡dtype(self_rank_dtype: Tuple[int, int], size: int, fill_value: int = -1) -> int:
    return torch.int64
@check_dtype_function(
    _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) +
    # Different width
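Both dtype functions pin the result to torch.int64, since the ops return indices. A quick eager-mode check (again assuming a PyTorch that ships nonzero_static):

# Eager-mode confirmation that both results are int64 index tensors.
import torch

x = torch.rand(3, 3)
assert torch.nonzero(x).dtype is torch.int64
assert torch.nonzero_static(x, size=5).dtype is torch.int64  # PyTorch >= ~2.1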

torch_ods_gen.py

@@ -446,6 +446,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)") emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)") emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)") emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)")
emit("aten::nonzero : (Tensor) -> (Tensor)")
emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])")
emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)")
    # Misc tensor ops.
    emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")