[MLIR][TORCH] Add support for dim=None to Aten[Var|Std]DimOp

PyTorch recently added support for `dim=None` in the `torch.var`
(5ca9b2b6fa)
and `torch.std`op (eb0e30e0bc).
This commit adds the corresponding support in torch-mlir.

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1166/head
Vivek Khandelwal 2022-08-05 11:54:00 +05:30
parent 31727f81d8
commit c129a6de93
6 changed files with 122 additions and 20 deletions

View File

@ -317,6 +317,8 @@ LTC_XFAIL_SET = {
"StdDimBiasedModule_basic", "StdDimBiasedModule_basic",
"StdDimKeepDimFalseModule_basic", "StdDimKeepDimFalseModule_basic",
"StdDimKeepDimTrueModule_basic", "StdDimKeepDimTrueModule_basic",
"StdDimEmptyDimModule_basic",
"StdDimNoneDimModule_basic",
"StdUnbiasedModule_basic", "StdUnbiasedModule_basic",
"SubFloatModule_basic", "SubFloatModule_basic",
"SubIntModule_basic", "SubIntModule_basic",

View File

@ -4049,10 +4049,10 @@ def Torch_AtenStdDimOp : Torch_Op<"aten.std.dim", [
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
]> { ]> {
let summary = "Generated op for `aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)`"; let summary = "Generated op for `aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$self, AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim, AnyTorchOptionalListOfTorchIntType:$dim,
Torch_BoolType:$unbiased, Torch_BoolType:$unbiased,
Torch_BoolType:$keepdim Torch_BoolType:$keepdim
); );
@ -4099,10 +4099,10 @@ def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
]> { ]> {
let summary = "Generated op for `aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)`"; let summary = "Generated op for `aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$self, AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim, AnyTorchOptionalListOfTorchIntType:$dim,
Torch_BoolType:$unbiased, Torch_BoolType:$unbiased,
Torch_BoolType:$keepdim Torch_BoolType:$keepdim
); );

View File

@ -5564,12 +5564,19 @@ module {
%0 = torch.prim.ListConstruct : () -> !torch.list<int> %0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int> return %0 : !torch.list<int>
} }
func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> { func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
%none = torch.constant.none
%true = torch.constant.bool true %true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0 %int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int %0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool %1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) { %2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int> %6 = torch.prim.ListConstruct : () -> !torch.list<int>
@ -5580,7 +5587,8 @@ module {
} : (!torch.int, !torch.bool) -> () } : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int> torch.prim.If.yield %6 : !torch.list<int>
} else { } else {
torch.prim.If.yield %arg1 : !torch.list<int> %5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
} }
%3 = torch.derefine %none : !torch.none to !torch.any %3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int> %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
@ -5620,11 +5628,35 @@ module {
%0 = torch.prim.ListConstruct : () -> !torch.list<int> %0 = torch.prim.ListConstruct : () -> !torch.list<int>
return %0 : !torch.list<int> return %0 : !torch.list<int>
} }
func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> { func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none %none = torch.constant.none
%0 = torch.derefine %none : !torch.none to !torch.any %int0 = torch.constant.int 0
%1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int> %0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
return %1 : !torch.list<int> %1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
}
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
%3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
} }
func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> { func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
%none = torch.constant.none %none = torch.constant.none

View File

@ -489,8 +489,8 @@ def atenmean(self: List[int], dtype: Optional[int] = None) -> List[int]:
def atenvar(self: List[int], unbiased: bool = True) -> List[int]: def atenvar(self: List[int], unbiased: bool = True) -> List[int]:
return [] return []
def atenvardim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]: def atenvardim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if len(dim)==0: if dim is None or len(dim)==0:
dim = list(range(len(self))) dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, None) return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
@ -502,7 +502,9 @@ def atenvarcorrection(self: List[int], dim: Optional[List[int]], correctio
def atenstd(self: List[int], unbiased: bool = True) -> List[int]: def atenstd(self: List[int], unbiased: bool = True) -> List[int]:
return [] return []
def atenstddim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]: def atenstddim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, None) return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
def _reduce_along_dim(self: List[int], dim: int, keepdim: bool): def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):

View File

@ -382,9 +382,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)") emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
emit("aten::mean : (Tensor, int?) -> (Tensor)") emit("aten::mean : (Tensor, int?) -> (Tensor)")
emit("aten::std : (Tensor, bool) -> (Tensor)") emit("aten::std : (Tensor, bool) -> (Tensor)")
emit("aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)") emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
emit("aten::var : (Tensor, bool) -> (Tensor)") emit("aten::var : (Tensor, bool) -> (Tensor)")
emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)") emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)") emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")

View File

@ -361,6 +361,50 @@ def StdDimBiasedModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class StdDimEmptyDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=[], keepdim=False)
@register_test_case(module_factory=lambda: StdDimEmptyDimModule())
def StdDimEmptyDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class StdDimNoneDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.std(x, dim=None, keepdim=False)
@register_test_case(module_factory=lambda: StdDimNoneDimModule())
def StdDimNoneDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class VarDimModule(torch.nn.Module): class VarDimModule(torch.nn.Module):
def __init__(self): def __init__(self):
@ -416,7 +460,7 @@ class VarDimBiasedModule(torch.nn.Module):
([-1, -1, -1], torch.float64, True), ([-1, -1, -1], torch.float64, True),
]) ])
def forward(self, x): def forward(self, x):
return torch.ops.aten.var(x, dim=0, unbiased=False, keepdim=True) return torch.ops.aten.var(x, dim=(0,1), unbiased=False, keepdim=True)
@register_test_case(module_factory=lambda: VarDimBiasedModule()) @register_test_case(module_factory=lambda: VarDimBiasedModule())
@ -438,7 +482,7 @@ class VarDimSingleDimModule(torch.nn.Module):
([-1, -1, -1], torch.float64, True), ([-1, -1, -1], torch.float64, True),
]) ])
def forward(self, x): def forward(self, x):
return torch.ops.aten.var(x, dim=0, keepdim=True) return torch.ops.aten.var(x, dim=(0,), keepdim=True)
@register_test_case(module_factory=lambda: VarDimSingleDimModule()) @register_test_case(module_factory=lambda: VarDimSingleDimModule())
@ -537,6 +581,28 @@ def VarDimEmptyDimModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class VarDimNoneDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=None, keepdim=False)
@register_test_case(module_factory=lambda: VarDimNoneDimModule())
def VarDimNoneDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class VarCorrectionModule(torch.nn.Module): class VarCorrectionModule(torch.nn.Module):
def __init__(self): def __init__(self):