mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix shape calculation for aten::pow.Tensor_Tensor op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

Branch: pull/1320/head
Tag: snapshot-20220909.591
Parent: e35741fb1d
Commit: 326f21229e
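
Why `unary` was the wrong shape rule here: `aten::pow.Tensor_Tensor` is an elementwise op with PyTorch broadcasting semantics, so the result shape is the broadcast of the base and exponent shapes, not simply the base's shape. A quick eager-mode illustration (shapes chosen for this example):

import torch

# A (3, 1) base raised to a (3, 4) exponent broadcasts to (3, 4);
# a unary shape rule would wrongly predict (3, 1) here.
base = torch.rand(3, 1)
exponent = torch.rand(3, 4)
assert torch.pow(base, exponent).shape == torch.Size([3, 4])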

@@ -5695,7 +5695,7 @@ module {
     return %0 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.pow.Tensor_Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
-    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.leaky_relu"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {

@@ -494,7 +494,7 @@ def aten〇pow〇Tensor_Scalar(self: List[int], exponent: float) -> List[int]:
     return upstream_shape_functions.unary(self)

 def aten〇pow〇Tensor_Tensor(self: List[int], exponent: List[int]) -> List[int]:
-    return upstream_shape_functions.unary(self)
+    return upstream_shape_functions.broadcast(self, exponent)

 def aten〇rsub〇Scalar(self: List[int], other: float, alpha: float = 1) -> List[int]:
     return upstream_shape_functions.unary(self)
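
For reference, `upstream_shape_functions.broadcast` resolves to `torch.jit._shape_functions.broadcast`, the same routine the generated MLIR in the first hunk now calls. A minimal sketch of the rule it applies (hypothetical helper name; the shapes are right-aligned, each size pair must match or contain a 1, and the larger size wins):

from typing import List

def broadcast_shapes_sketch(a: List[int], b: List[int]) -> List[int]:
    # Walk both shapes from the trailing dimension; missing leading
    # dimensions are treated as size 1.
    out: List[int] = []
    for i in range(max(len(a), len(b))):
        dim_a = a[-1 - i] if i < len(a) else 1
        dim_b = b[-1 - i] if i < len(b) else 1
        assert dim_a == dim_b or dim_a == 1 or dim_b == 1, \
            "shapes are not broadcastable"
        out.insert(0, max(dim_a, dim_b))
    return out

assert broadcast_shapes_sketch([3, 1], [3, 4]) == [3, 4]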

@@ -1125,8 +1125,8 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-1, -1], torch.float32, True),
         ([-1, 1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a, b):
         return torch.pow(a, b)

@@ -1134,7 +1134,7 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module):

 @register_test_case(module_factory=lambda: ElementwisePowTensorBroadcastModule())
 def ElementwisePowTensorBroadcastModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 4), tu.rand(3, 1))
+    module.forward(tu.rand(3, 1), tu.rand(3, 4))


 # ==============================================================================
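
Note on the test update: swapping the shapes so the unit dimension sits on the base makes the result shape (3, 4) come from the exponent, which is exactly the case the old `unary(self)` rule mis-predicted as (3, 1). The reordered test therefore guards against a regression of this fix.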