mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] emit aten.nonzero, aten.nonzero_numpy, aten.nonzero_static op
parent
238c0501da
commit
8fa10911ea
|
@ -6070,6 +6070,77 @@ def Torch_AtenCrossEntropyLossOp : Torch_Op<"aten.cross_entropy_loss", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenNonzeroOp : Torch_Op<"aten.nonzero", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::nonzero : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenNonzeroOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenNonzeroOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenNonzeroNumpyOp : Torch_Op<"aten.nonzero_numpy", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::nonzero_numpy : (Tensor) -> (Tensor[])`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchListOfTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenNonzeroNumpyOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenNonzeroNumpyOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenNonzeroStaticOp : Torch_Op<"aten.nonzero_static", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::nonzero_static : (Tensor, int, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
Torch_IntType:$size,
|
||||||
|
Torch_IntType:$fill_value
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenNonzeroStaticOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||||
|
}
|
||||||
|
void AtenNonzeroStaticOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -7888,6 +7888,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" return %none : !torch.none\n"
|
" return %none : !torch.none\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.nonzero\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
|
||||||
|
" %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||||
|
" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" return %2 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.masked_select\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n"
|
||||||
|
" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list<int>\n"
|
||||||
|
" return %1 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.nonzero_static\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
|
||||||
|
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||||
|
" %1 = torch.prim.ListConstruct %arg1, %0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" return %1 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.linalg_vector_norm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.linalg_vector_norm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.optional<list<int>>, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = torch.derefine %arg4 : !torch.optional<int> to !torch.any\n"
|
" %0 = torch.derefine %arg4 : !torch.optional<int> to !torch.any\n"
|
||||||
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||||
|
@ -9366,6 +9382,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" }\n"
|
" }\n"
|
||||||
" return %4 : !torch.int\n"
|
" return %4 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.nonzero\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
|
" %int4 = torch.constant.int 4\n"
|
||||||
|
" return %int4 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.nonzero_static\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
|
||||||
|
" %int4 = torch.constant.int 4\n"
|
||||||
|
" return %int4 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.addmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n"
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
|
|
@ -1117,6 +1117,15 @@ def hacky_get_unknown_dimension_size():
|
||||||
def aten〇bincount〡shape(self: List[int], weights: Optional[List[int]] = None, minlength: int = 0) -> List[int]:
|
def aten〇bincount〡shape(self: List[int], weights: Optional[List[int]] = None, minlength: int = 0) -> List[int]:
|
||||||
return [hacky_get_unknown_dimension_size()]
|
return [hacky_get_unknown_dimension_size()]
|
||||||
|
|
||||||
|
def aten〇nonzero〡shape(self: List[int]) -> List[int]:
|
||||||
|
return [hacky_get_unknown_dimension_size(), len(self)]
|
||||||
|
|
||||||
|
def aten〇masked_select〡shape(self: List[int], mask: List[int]) -> List[int]:
|
||||||
|
return [hacky_get_unknown_dimension_size()]
|
||||||
|
|
||||||
|
def aten〇nonzero_static〡shape(self: List[int], size: int, fill_value: int = -1) -> List[int]:
|
||||||
|
return [size, len(self)]
|
||||||
|
|
||||||
def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
|
||||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
|
||||||
|
|
||||||
|
@ -2430,6 +2439,14 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype
|
||||||
return torch.int64
|
return torch.int64
|
||||||
return torch.float64
|
return torch.float64
|
||||||
|
|
||||||
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device=torch.device("cpu")))
|
||||||
|
def aten〇nonzero〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
|
return torch.int64
|
||||||
|
|
||||||
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=5, tensor_device=torch.device("cpu")))
|
||||||
|
def aten〇nonzero_static〡dtype(self_rank_dtype: Tuple[int, int], size: int, fill_value: int = -1) -> int:
|
||||||
|
return torch.int64
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) +
|
_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) +
|
||||||
# Different width
|
# Different width
|
||||||
|
|
|
@ -446,6 +446,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
||||||
emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
|
emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
|
||||||
emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)")
|
emit("aten::cross_entropy_loss : (Tensor, Tensor, Tensor?, int, int, float) -> (Tensor)")
|
||||||
|
emit("aten::nonzero : (Tensor) -> (Tensor)")
|
||||||
|
emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])")
|
||||||
|
emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)")
|
||||||
|
|
||||||
# Misc tensor ops.
|
# Misc tensor ops.
|
||||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue