mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix shape calculation for aten::pow.Tensor_Tensor op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1320/head snapshot-20220909.591
parent
e35741fb1d
commit
326f21229e
|
@ -5695,7 +5695,7 @@ module {
|
|||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.pow.Tensor_Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
%0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.leaky_relu"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
|
||||
|
|
|
@ -494,7 +494,7 @@ def aten〇pow〇Tensor_Scalar(self: List[int], exponent: float) -> List[int]:
|
|||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇pow〇Tensor_Tensor(self: List[int], exponent: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
return upstream_shape_functions.broadcast(self, exponent)
|
||||
|
||||
def aten〇rsub〇Scalar(self: List[int], other: float, alpha: float = 1) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
|
|
@ -1125,8 +1125,8 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, 1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.pow(a, b)
|
||||
|
@ -1134,7 +1134,7 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwisePowTensorBroadcastModule())
|
||||
def ElementwisePowTensorBroadcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(3, 1))
|
||||
module.forward(tu.rand(3, 1), tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
|
Loading…
Reference in New Issue