From 8fa10911ea34262b42d5643fa9e7eb421bcb6837 Mon Sep 17 00:00:00 2001
From: "wujiawei.aml"
Date: Mon, 24 Jul 2023 16:21:53 +0800
Subject: [PATCH] [Torch Dialect] emit aten.nonzero, aten.nonzero_numpy,
 aten.nonzero_static op

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 71 +++++++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      | 24 +++++++
 .../build_tools/abstract_interp_lib_gen.py    | 17 +++++
 .../jit_ir/build_tools/torch_ods_gen.py       |  3 +
 4 files changed, 115 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 51a6abc04..7999c6b51 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6070,6 +6070,77 @@ def Torch_AtenCrossEntropyLossOp : Torch_Op<"aten.cross_entropy_loss", [
   }];
 }
 
+def Torch_AtenNonzeroOp : Torch_Op<"aten.nonzero", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::nonzero : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNonzeroOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenNonzeroOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenNonzeroNumpyOp : Torch_Op<"aten.nonzero_numpy", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::nonzero_numpy : (Tensor) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNonzeroNumpyOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenNonzeroNumpyOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenNonzeroStaticOp : Torch_Op<"aten.nonzero_static", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::nonzero_static : (Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$size,
+    Torch_IntType:$fill_value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNonzeroStaticOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenNonzeroStaticOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 520c11c5e..915f4317a 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7888,6 +7888,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %none = torch.constant.none\n"
 "    return %none : !torch.none\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.nonzero\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
+"    %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.masked_select\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
+"    %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.nonzero_static\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.prim.ListConstruct %arg1, %0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.linalg_vector_norm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = torch.derefine %arg4 : !torch.optional<int> to !torch.any\n"
 "    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
@@ -9366,6 +9382,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.nonzero\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %int4 = torch.constant.int 4\n"
+"    return %int4 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.nonzero_static\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
+"    %int4 = torch.constant.int 4\n"
+"    return %int4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index a0ae1e027..2e906cb31 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -1117,6 +1117,15 @@ def hacky_get_unknown_dimension_size():
 def aten〇bincount〡shape(self: List[int], weights: Optional[List[int]] = None, minlength: int = 0) -> List[int]:
     return [hacky_get_unknown_dimension_size()]
 
+def aten〇nonzero〡shape(self: List[int]) -> List[int]:
+    return [hacky_get_unknown_dimension_size(), len(self)]
+
+def aten〇masked_select〡shape(self: List[int], mask: List[int]) -> List[int]:
+    return [hacky_get_unknown_dimension_size()]
+
+def aten〇nonzero_static〡shape(self: List[int], size: int, fill_value: int = -1) -> List[int]:
+    return [size, len(self)]
+
 def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
 
@@ -2430,6 +2439,14 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype
         return torch.int64
     return torch.float64
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device=torch.device("cpu")))
+def aten〇nonzero〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    return torch.int64
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=5, tensor_device=torch.device("cpu")))
+def aten〇nonzero_static〡dtype(self_rank_dtype: Tuple[int, int], size: int, fill_value: int = -1) -> int:
+    return torch.int64
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) +
     # Different width
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 5dbec3999..cde821522 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -446,6 +446,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
     emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
     emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)")
+    emit("aten::nonzero : (Tensor) -> (Tensor)")
+    emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])")
+    emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)")
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
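
Note: the shape functions above give aten.nonzero the result shape [num_nonzero, ndim] (the row count is only known at runtime, hence hacky_get_unknown_dimension_size()), give aten.nonzero_static the fully static [size, ndim], and both dtype functions return the constant 4, which is torch.int64 (ScalarType::Long). The eager-mode snippet below is a minimal sketch for sanity-checking those contracts against PyTorch itself; it is not part of the patch, the sample tensor is illustrative, and it assumes a PyTorch build where aten::nonzero_static is exposed (2.1+, CPU).

import torch

t = torch.tensor([[1, 0], [0, 3], [5, 0]])  # 3 nonzero elements, ndim == 2

# aten::nonzero: int64 indices of shape [num_nonzero, ndim]; the dynamic
# first dimension is why the shape function uses
# hacky_get_unknown_dimension_size().
idx = torch.nonzero(t)
assert idx.dtype == torch.int64 and idx.shape == (3, 2)

# aten::nonzero_static: the caller fixes the row count up front, so the
# shape function can return the static [size, len(self)]; missing rows are
# padded with fill_value.
idx_static = torch.nonzero_static(t, size=5, fill_value=-1)
assert idx_static.dtype == torch.int64 and idx_static.shape == (5, 2)

# aten::nonzero_numpy: numpy-style tuple with one 1-D index tensor per
# dimension (the form produced by torch.nonzero(..., as_tuple=True)).
rows, cols = torch.nonzero(t, as_tuple=True)
assert rows.shape == cols.shape == (3,)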