From c129a6de9390641d65afaa6d874d594e7268ab3c Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 5 Aug 2022 11:54:00 +0530
Subject: [PATCH] [MLIR][TORCH] Add support for dim=None to Aten[Var|Std]DimOp

PyTorch recently added support for `dim=None` in the `torch.var` op
(https://github.com/pytorch/pytorch/commit/5ca9b2b6fa4d5871ad900a9d4e5ba34c02934bb3)
and the `torch.std` op
(https://github.com/pytorch/pytorch/commit/eb0e30e0bc86477f5ae17e9d5651a9fbc01a91e3).
This commit adds the corresponding support in torch-mlir.

Signed-off-by: Vivek Khandelwal
---
 e2e_testing/torchscript/xfail_sets.py         |  2 +
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  8 +--
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 50 ++++++++++---
 .../jit_ir/build_tools/shape_lib_gen.py       |  8 ++-
 .../jit_ir/build_tools/torch_ods_gen.py       |  4 +-
 .../torch_mlir_e2e_test/test_suite/stats.py   | 70 ++++++++++++++++++-
 6 files changed, 122 insertions(+), 20 deletions(-)

diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index 0ac56e9cf..62aef8991 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -317,6 +317,8 @@ LTC_XFAIL_SET = {
     "StdDimBiasedModule_basic",
     "StdDimKeepDimFalseModule_basic",
     "StdDimKeepDimTrueModule_basic",
+    "StdDimEmptyDimModule_basic",
+    "StdDimNoneDimModule_basic",
     "StdUnbiasedModule_basic",
     "SubFloatModule_basic",
     "SubIntModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 043e206dc..ec73d5d33 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4049,10 +4049,10 @@ def Torch_AtenStdDimOp : Torch_Op<"aten.std.dim", [
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)`";
+  let summary = "Generated op for `aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchListOfTorchIntType:$dim,
+    AnyTorchOptionalListOfTorchIntType:$dim,
     Torch_BoolType:$unbiased,
     Torch_BoolType:$keepdim
   );
@@ -4099,10 +4099,10 @@ def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [
     HasValueSemantics,
     ReadOnly
  ]> {
-  let summary = "Generated op for `aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)`";
+  let summary = "Generated op for `aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchListOfTorchIntType:$dim,
+    AnyTorchOptionalListOfTorchIntType:$dim,
     Torch_BoolType:$unbiased,
     Torch_BoolType:$keepdim
   );
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index dd2407bcc..ce315375b 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -5564,12 +5564,19 @@ module {
     %0 = torch.prim.ListConstruct : () -> !torch.list<int>
     return %0 : !torch.list<int>
   }
-  func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
-    %none = torch.constant.none
+  func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
     %true = torch.constant.bool true
+    %none = torch.constant.none
     %int0 = torch.constant.int 0
-    %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
-    %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
+    %0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
+    %1 = torch.prim.If %0 -> (!torch.bool) {
+      torch.prim.If.yield %true : !torch.bool
+    } else {
+      %5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
+      %6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
+      %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
+      torch.prim.If.yield %7 : !torch.bool
+    }
     %2 = torch.prim.If %1 -> (!torch.list<int>) {
       %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
       %6 = torch.prim.ListConstruct : () -> !torch.list<int>
@@ -5580,7 +5587,8 @@ module {
       } : (!torch.int, !torch.bool) -> ()
       torch.prim.If.yield %6 : !torch.list<int>
     } else {
-      torch.prim.If.yield %arg1 : !torch.list<int>
+      %5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
+      torch.prim.If.yield %5 : !torch.list<int>
     }
     %3 = torch.derefine %none : !torch.none to !torch.any
     %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
@@ -5620,11 +5628,35 @@ module {
     %0 = torch.prim.ListConstruct : () -> !torch.list<int>
     return %0 : !torch.list<int>
   }
-  func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
+  func.func @"__torch_mlir_shape_fn.aten.std.dim"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
+    %true = torch.constant.bool true
     %none = torch.constant.none
-    %0 = torch.derefine %none : !torch.none to !torch.any
-    %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
-    return %1 : !torch.list<int>
+    %int0 = torch.constant.int 0
+    %0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
+    %1 = torch.prim.If %0 -> (!torch.bool) {
+      torch.prim.If.yield %true : !torch.bool
+    } else {
+      %5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
+      %6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
+      %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
+      torch.prim.If.yield %7 : !torch.bool
+    }
+    %2 = torch.prim.If %1 -> (!torch.list<int>) {
+      %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+      %6 = torch.prim.ListConstruct : () -> !torch.list<int>
+      torch.prim.Loop %5, %true, init() {
+      ^bb0(%arg4: !torch.int):
+        %7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
+        torch.prim.Loop.condition %true, iter()
+      } : (!torch.int, !torch.bool) -> ()
+      torch.prim.If.yield %6 : !torch.list<int>
+    } else {
+      %5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
+      torch.prim.If.yield %5 : !torch.list<int>
+    }
+    %3 = torch.derefine %none : !torch.none to !torch.any
+    %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
+    return %4 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.argmax"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.list<int> {
     %none = torch.constant.none
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 58a9de829..8036a6c0b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -489,8 +489,8 @@ def aten〇mean(self: List[int], dtype: Optional[int] = None) -> List[int]:
 def aten〇var(self: List[int], unbiased: bool = True) -> List[int]:
     return []
 
-def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
-    if len(dim)==0:
+def aten〇var〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
+    if dim is None or len(dim)==0:
         dim = list(range(len(self)))
     return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
 
@@ -502,7 +502,9 @@ def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correctio
 def aten〇std(self: List[int], unbiased: bool = True) -> List[int]:
     return []
 
-def aten〇std〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
+def aten〇std〇dim(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
+    if dim is None or len(dim)==0:
+        dim = list(range(len(self)))
     return upstream_shape_functions.mean_dim(self, dim, keepdim, None)
 
 def _reduce_along_dim(self: List[int], dim: int, keepdim: bool):
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 2dbaf9a40..0f92c14b9 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -382,9 +382,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
     emit("aten::mean : (Tensor, int?) -> (Tensor)")
     emit("aten::std : (Tensor, bool) -> (Tensor)")
-    emit("aten::std.dim : (Tensor, int[], bool, bool) -> (Tensor)")
+    emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
     emit("aten::var : (Tensor, bool) -> (Tensor)")
-    emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)")
+    emit("aten::var.dim : (Tensor, int[]?, bool, bool) -> (Tensor)")
     emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)")
     emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)")
     emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/stats.py b/python/torch_mlir_e2e_test/test_suite/stats.py
index 1093a90aa..ebf9573bb 100644
--- a/python/torch_mlir_e2e_test/test_suite/stats.py
+++ b/python/torch_mlir_e2e_test/test_suite/stats.py
@@ -361,6 +361,50 @@ def StdDimBiasedModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class StdDimEmptyDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.std(x, dim=[], keepdim=False)
+
+
+@register_test_case(module_factory=lambda: StdDimEmptyDimModule())
+def StdDimEmptyDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
+class StdDimNoneDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.std(x, dim=None, keepdim=False)
+
+
+@register_test_case(module_factory=lambda: StdDimNoneDimModule())
+def StdDimNoneDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
 class VarDimModule(torch.nn.Module):
 
     def __init__(self):
@@ -416,7 +460,7 @@ class VarDimBiasedModule(torch.nn.Module):
         ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
-        return torch.ops.aten.var(x, dim=0, unbiased=False, keepdim=True)
+        return torch.ops.aten.var(x, dim=(0,1), unbiased=False, keepdim=True)
 
 
 @register_test_case(module_factory=lambda: VarDimBiasedModule())
@@ -438,7 +482,7 @@ class VarDimSingleDimModule(torch.nn.Module):
         ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
-        return torch.ops.aten.var(x, dim=0, keepdim=True)
+        return torch.ops.aten.var(x, dim=(0,), keepdim=True)
 
 
 @register_test_case(module_factory=lambda: VarDimSingleDimModule())
@@ -537,6 +581,28 @@ def VarDimEmptyDimModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class VarDimNoneDimModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, dim=None, keepdim=False)
+
+
+@register_test_case(module_factory=lambda: VarDimNoneDimModule())
+def VarDimNoneDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
 class VarCorrectionModule(torch.nn.Module):
 
     def __init__(self):
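
A quick eager-mode check of the `dim=None` / empty-`dim` behaviour that the new
e2e tests exercise, as a minimal sketch assuming a PyTorch build that already
contains the two upstream commits linked in the commit message (the tensor
shape is arbitrary):

    import torch

    x = torch.rand(3, 4, 5)

    # dim=None now dispatches to the aten.std.dim / aten.var.dim overloads and
    # reduces over every dimension, matching the no-dim overloads.
    assert torch.allclose(torch.ops.aten.std(x, dim=None, keepdim=False), torch.std(x))
    assert torch.allclose(torch.ops.aten.var(x, dim=None, keepdim=False), torch.var(x))

    # An empty dim list is treated as a full reduction by the shape functions
    # above; StdDimEmptyDimModule exercises the same path end to end.
    print(torch.ops.aten.std(x, dim=[], keepdim=False))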