From 326f21229e684c3adcdcde88f439dca5aeb6e6d7 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 8 Sep 2022 18:20:47 +0530 Subject: [PATCH] [MLIR][TORCH] Fix shape calculation for aten::pow.Tensor_Tensor op Signed-Off By: Vivek Khandelwal --- lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 2 +- .../torch/importer/jit_ir/build_tools/shape_lib_gen.py | 2 +- python/torch_mlir_e2e_test/test_suite/elementwise.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 839ea16b9..c53acdf82 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -5695,7 +5695,7 @@ module { return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.pow.Tensor_Tensor"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list return %0 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.leaky_relu"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list { diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index a530d1ecd..677454fdd 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -494,7 +494,7 @@ def aten〇pow〇Tensor_Scalar(self: List[int], exponent: float) -> List[int]: return upstream_shape_functions.unary(self) def aten〇pow〇Tensor_Tensor(self: List[int], exponent: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) + return upstream_shape_functions.broadcast(self, exponent) def aten〇rsub〇Scalar(self: List[int], other: float, alpha: float = 1) -> List[int]: return upstream_shape_functions.unary(self) diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index b21a06dbf..68e84a07f 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1125,8 +1125,8 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module): @export @annotate_args([ None, - ([-1, -1], torch.float32, True), ([-1, 1], torch.float32, True), + ([-1, -1], torch.float32, True), ]) def forward(self, a, b): return torch.pow(a, b) @@ -1134,7 +1134,7 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module): @register_test_case(module_factory=lambda: ElementwisePowTensorBroadcastModule()) def ElementwisePowTensorBroadcastModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4), tu.rand(3, 1)) + module.forward(tu.rand(3, 1), tu.rand(3, 4)) # ==============================================================================