[MLIR][TORCH] Fix shape calculation for aten::pow.Tensor_Tensor op

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
Vivek Khandelwal 2022-09-08 18:20:47 +05:30
parent e35741fb1d
commit 326f21229e
3 changed files with 4 additions and 4 deletions
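
For context: `aten::pow.Tensor_Tensor` takes a tensor exponent and broadcasts both operands, so its result shape is the broadcast of the two input shapes, not the shape of `self` alone. A quick illustration in plain PyTorch:

    import torch

    a = torch.rand(3, 1)
    b = torch.rand(3, 4)
    # The tensor-tensor variant of pow broadcasts its operands, so the result
    # shape is the broadcast of both input shapes, not the shape of `a` alone.
    print(torch.pow(a, b).shape)  # torch.Size([3, 4])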


@@ -5695,7 +5695,7 @@ module {
     return %0 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.pow.Tensor_Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
-    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.leaky_relu"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {


@@ -494,7 +494,7 @@ def aten〇pow〇Tensor_Scalar(self: List[int], exponent: float) -> List[int]:
     return upstream_shape_functions.unary(self)

 def aten〇pow〇Tensor_Tensor(self: List[int], exponent: List[int]) -> List[int]:
-    return upstream_shape_functions.unary(self)
+    return upstream_shape_functions.broadcast(self, exponent)

 def aten〇rsub〇Scalar(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)
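
The `upstream_shape_functions.broadcast` helper applies the standard right-aligned broadcasting rule to two shape lists. A minimal self-contained sketch of that rule (the name `broadcast_shapes` is illustrative here, not the upstream API):

    from typing import List

    def broadcast_shapes(a: List[int], b: List[int]) -> List[int]:
        # Right-align the two shapes and combine them dimension by dimension.
        ndim = max(len(a), len(b))
        out: List[int] = []
        for i in range(ndim):
            # Missing leading dimensions are treated as size 1.
            dim_a = a[i - (ndim - len(a))] if i >= ndim - len(a) else 1
            dim_b = b[i - (ndim - len(b))] if i >= ndim - len(b) else 1
            assert dim_a == dim_b or dim_a == 1 or dim_b == 1, "shapes not broadcastable"
            out.append(dim_a if dim_b == 1 else dim_b)
        return out

    assert broadcast_shapes([3, 1], [3, 4]) == [3, 4]
    assert broadcast_shapes([4], [3, 1]) == [3, 4]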


@@ -1125,8 +1125,8 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-1, -1], torch.float32, True),
         ([-1, 1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a, b):
         return torch.pow(a, b)
@@ -1134,7 +1134,7 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module):

 @register_test_case(module_factory=lambda: ElementwisePowTensorBroadcastModule())
 def ElementwisePowTensorBroadcastModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 4), tu.rand(3, 1))
+    module.forward(tu.rand(3, 1), tu.rand(3, 4))

 # ==============================================================================
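
A note on why the test swaps the operand order instead of just updating the shape function: with the old ordering the broadcast result (3, 4) happened to coincide with the shape of the first operand, so the buggy `unary` transfer function still produced the right answer. With the operands swapped, the result shape differs from `self`'s shape, so only a broadcast-based shape function passes. A quick check under standard torch semantics:

    import torch

    # Old ordering: result shape equals self's shape, masking the unary() bug.
    assert torch.pow(torch.rand(3, 4), torch.rand(3, 1)).shape == (3, 4)
    # New ordering: result shape (3, 4) differs from self's shape (3, 1),
    # so a unary shape function would now be caught by the e2e check.
    assert torch.pow(torch.rand(3, 1), torch.rand(3, 4)).shape == (3, 4)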