mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for dim=None to Aten[Var|Std]DimOp
PyTorch recently added support for `dim=None` in the `torch.var` (pull/1166/head5ca9b2b6fa
) and `torch.std`op (eb0e30e0bc
). This commit adds the corresponding support in torch-mlir. Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
parent
31727f81d8
commit
c129a6de93
|
@ -317,6 +317,8 @@ LTC_XFAIL_SET = {
|
|||
"StdDimBiasedModule_basic",
|
||||
"StdDimKeepDimFalseModule_basic",
|
||||
"StdDimKeepDimTrueModule_basic",
|
||||
"StdDimEmptyDimModule_basic",
|
||||
"StdDimNoneDimModule_basic",
|
||||
"StdUnbiasedModule_basic",
|
||||
"SubFloatModule_basic",
|
||||
"SubIntModule_basic",
|
||||
|
|
|
@ -4049,10 +4049,10 @@ def Torch_AtenStdDimOp : Torch_Op<"aten.std.dim", [
|
|||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)`";
|
||||
let summary = "Generated op for `aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$dim,
|
||||
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||
Torch_BoolType:$unbiased,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
|
@ -4099,10 +4099,10 @@ def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [
|
|||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)`";
|
||||
let summary = "Generated op for `aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$dim,
|
||||
AnyTorchOptionalListOfTorchIntType:$dim,
|
||||
Torch_BoolType:$unbiased,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
|
|
|
@ -5564,12 +5564,19 @@ module {
|
|||
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
|
||||
%none = torch.constant.none
|
||||
func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
||||
%1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
||||
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %7 : !torch.bool
|
||||
}
|
||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
|
@ -5580,7 +5587,8 @@ module {
|
|||
} : (!torch.int, !torch.bool) -> ()
|
||||
torch.prim.If.yield %6 : !torch.list<int>
|
||||
} else {
|
||||
torch.prim.If.yield %arg1 : !torch.list<int>
|
||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %5 : !torch.list<int>
|
||||
}
|
||||
%3 = torch.derefine %none : !torch.none to !torch.any
|
||||
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
|
@ -5620,11 +5628,35 @@ module {
|
|||
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
|
||||
func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
|
||||
%true = torch.constant.bool true
|
||||
%none = torch.constant.none
|
||||
%0 = torch.derefine %none : !torch.none to !torch.any
|
||||
%1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %1 : !torch.list<int>
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
|
||||
%1 = torch.prim.If %0 -> (!torch.bool) {
|
||||
torch.prim.If.yield %true : !torch.bool
|
||||
} else {
|
||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
|
||||
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If.yield %7 : !torch.bool
|
||||
}
|
||||
%2 = torch.prim.If %1 -> (!torch.list<int>) {
|
||||
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
torch.prim.Loop %5, %true, init() {
|
||||
^bb0(%arg4: !torch.int):
|
||||
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
torch.prim.If.yield %6 : !torch.list<int>
|
||||
} else {
|
||||
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
|
||||
torch.prim.If.yield %5 : !torch.list<int>
|
||||
}
|
||||
%3 = torch.derefine %none : !torch.none to !torch.any
|
||||
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
|
||||
return %4 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
|
||||
%none = torch.constant.none
|
||||
|
|
|
@ -489,8 +489,8 @@ def aten〇mean(self: List[int], dtype: Optional[int] = None) -> List[int]:
|
|||
def aten〇var(self: List[int], unbiased: bool = True) -> List[int]:
|
||||
return []
|
||||
|
||||
def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||
if len(dim)==0:
|
||||
def aten〇var〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||
if dim is None or len(dim)==0:
|
||||
dim = list(range(len(self)))
|
||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
|
||||
|
||||
|
@ -502,7 +502,9 @@ def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correctio
|
|||
def aten〇std(self: List[int], unbiased: bool = True) -> List[int]:
|
||||
return []
|
||||
|
||||
def aten〇std〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||
def aten〇std〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||
if dim is None or len(dim)==0:
|
||||
dim = list(range(len(self)))
|
||||
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
|
||||
|
||||
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
|
||||
|
|
|
@ -382,9 +382,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
|
||||
emit("aten::mean : (Tensor, int?) -> (Tensor)")
|
||||
emit("aten::std : (Tensor, bool) -> (Tensor)")
|
||||
emit("aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)")
|
||||
emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
|
||||
emit("aten::var : (Tensor, bool) -> (Tensor)")
|
||||
emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)")
|
||||
emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
|
||||
emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
|
||||
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
|
||||
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
|
||||
|
|
|
@ -361,6 +361,50 @@ def StdDimBiasedModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class StdDimEmptyDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=[], keepdim=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdDimEmptyDimModule())
|
||||
def StdDimEmptyDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdDimNoneDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, dim=None, keepdim=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdDimNoneDimModule())
|
||||
def StdDimNoneDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class VarDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -416,7 +460,7 @@ class VarDimBiasedModule(torch.nn.Module):
|
|||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=0, unbiased=False, keepdim=True)
|
||||
return torch.ops.aten.var(x, dim=(0,1), unbiased=False, keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarDimBiasedModule())
|
||||
|
@ -438,7 +482,7 @@ class VarDimSingleDimModule(torch.nn.Module):
|
|||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=0, keepdim=True)
|
||||
return torch.ops.aten.var(x, dim=(0,), keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarDimSingleDimModule())
|
||||
|
@ -537,6 +581,28 @@ def VarDimEmptyDimModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class VarDimNoneDimModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=None, keepdim=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarDimNoneDimModule())
|
||||
def VarDimNoneDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class VarCorrectionModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue