mirror of https://github.com/llvm/torch-mlir

[MLIR][TORCH] Fix empty dim cases for the .dim ops

This commit fixes the shape calculation for:
1.) aten.mean.dim
2.) aten.var.dim
3.) aten.sum.dim_IntList op

Also, it fixes the lowering of `aten.mean.dim` and `aten.sum.dim_IntList`
to handle the case of an empty dim list.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>

Branch: pull/1122/head
parent d386b8f9e5
commit c681c3497a
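For context, the semantics these fixes target, sketched in eager PyTorch (illustrative only, not part of the commit; empty-dim behavior has shifted slightly across PyTorch releases): an empty `dim` list is treated like `dim=None`, i.e. reduce over all dimensions.

import torch

a = torch.rand(3, 4, 5)

# An empty dim list reduces over every dimension, matching the
# no-dim overloads; the result is a rank-0 tensor (shape []).
assert torch.equal(torch.sum(a, dim=[]), torch.sum(a))
assert torch.equal(torch.mean(a, dim=[]), torch.mean(a))
assert torch.sum(a, dim=[]).shape == torch.Size([])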
@@ -270,6 +270,8 @@ private:
           "`keepdim` must be a constant bool");

     SmallVector<int64_t> dimList;
+    bool isNoneOrEmptyDimList =
+        op.dim().getType().template isa<Torch::NoneType>();
     if (matchPattern(op.dim(), m_TorchConstantIntList(dimList))) {
       // Fix negative dimensions, if any, before adding to the list.
       for (int64_t dim : dimList) {
@@ -278,13 +280,16 @@ private:
         if (isValidDim(dim, inputType.getRank()))
           opInfo.dimSet.insert(dim);
       }
-    } else if (op.dim().getType().template isa<Torch::NoneType>()) {
+      if (dimList.empty())
+        isNoneOrEmptyDimList = true;
+    } else if (!isNoneOrEmptyDimList) {
+      return rewriter.notifyMatchFailure(
+          op, "`dim` argument must be a constant int list or None");
+    }
+    if (isNoneOrEmptyDimList) {
       // If no dimensions were specified, reduce along all dimensions
       for (int64_t i = 0; i < inputType.getRank(); i++)
         opInfo.dimSet.insert(i);
-    } else {
-      return rewriter.notifyMatchFailure(
-          op, "`dim` argument must be a constant int list or None");
-    }
+    }

     return opInfo;
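In outline, the revised dim handling above treats a `None` or empty `dim` argument as "reduce everything" and only rejects the op when `dim` is neither a constant int list nor `None`. A minimal Python sketch of the same control flow (the name `compute_reduction_dims` is illustrative, not from the codebase):

from typing import List, Optional, Set

def compute_reduction_dims(dim: Optional[List[int]], rank: int) -> Set[int]:
    # Mirrors the revised C++ logic: None or an empty dim list selects
    # every dimension; otherwise negative dims are normalized and
    # out-of-range dims are dropped.
    is_none_or_empty = dim is None
    dim_set: Set[int] = set()
    if dim is not None:
        for d in dim:
            d = d + rank if d < 0 else d  # fix negative dimensions
            if 0 <= d < rank:             # drop invalid dimensions
                dim_set.add(d)
        if len(dim) == 0:
            is_none_or_empty = True
    if is_none_or_empty:
        dim_set = set(range(rank))        # reduce along all dimensions
    return dim_set

assert compute_reduction_dims([], 3) == {0, 1, 2}
assert compute_reduction_dims(None, 3) == {0, 1, 2}
assert compute_reduction_dims([-1], 3) == {2}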
@@ -1012,6 +1012,7 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.self();
+    unsigned inputRank = getTensorRank(input);
     Value dimList = op.dim();
     Value keepDim = op.keepdim();
     Value dtype = op.dtype();
@@ -1036,12 +1037,18 @@ public:
         loc, outputType, input, dimList, keepDim, dtype);

     // `productDimSize` is product of sizes of dimensions to be reduced.
-    Value productDimSize = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(1));
-    for (Value dim : dimListConstruct.elements()) {
-      Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
-      productDimSize =
-          rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
+    Value productDimSize;
+    // Case: Reduce along all dims.
+    if (dimListConstruct.elements().empty() && inputRank != 0) {
+      productDimSize = rewriter.create<AtenNumelOp>(loc, input);
+    } else {
+      productDimSize = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1));
+      for (Value dim : dimListConstruct.elements()) {
+        Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+        productDimSize =
+            rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
+      }
     }
     rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
                                                  productDimSize);
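The decomposition above turns `aten.mean.dim` into `aten.sum.dim_IntList` followed by a scalar divide, using `aten.numel` as the divisor when the dim list is empty (and the input is not rank-0). A hedged eager-mode sketch of the emitted arithmetic, relying on the all-dims behavior of an empty dim list that this commit mirrors:

import torch

def decomposed_mean_dim(x: torch.Tensor, dims, keepdim: bool = False):
    # mean.dim decomposes into sum.dim_IntList followed by a scalar divide.
    summed = torch.sum(x, dim=dims, keepdim=keepdim)
    if len(dims) == 0 and x.dim() != 0:
        # Case: reduce along all dims -- the divisor is the element count.
        product_dim_size = x.numel()
    else:
        product_dim_size = 1
        for d in dims:
            product_dim_size *= x.size(d)
    return summed / product_dim_size

x = torch.rand(3, 4, 5)
assert torch.allclose(decomposed_mean_dim(x, []), torch.mean(x))
assert torch.allclose(decomposed_mean_dim(x, [0, 2]), torch.mean(x, dim=[0, 2]))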
@@ -5566,9 +5566,25 @@ module {
   }
   func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
     %none = torch.constant.none
-    %0 = torch.derefine %none : !torch.none to !torch.any
-    %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
-    return %1 : !torch.list<int>
+    %true = torch.constant.bool true
+    %int0 = torch.constant.int 0
+    %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
+    %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
+    %2 = torch.prim.If %1 -> (!torch.list<int>) {
+      %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+      %6 = torch.prim.ListConstruct : () -> !torch.list<int>
+      torch.prim.Loop %5, %true, init() {
+      ^bb0(%arg4: !torch.int):
+        %7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
+        torch.prim.Loop.condition %true, iter()
+      } : (!torch.int, !torch.bool) -> ()
+      torch.prim.If.yield %6 : !torch.list<int>
+    } else {
+      torch.prim.If.yield %arg1 : !torch.list<int>
+    }
+    %3 = torch.derefine %none : !torch.none to !torch.any
+    %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
+    return %4 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
     %true = torch.constant.bool true
@@ -5657,25 +5673,55 @@ module {
     return %1 : !torch.tuple<list<int>, list<int>>
   }
   func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
-    %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
-    %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
-    return %1 : !torch.list<int>
+    %true = torch.constant.bool true
+    %int0 = torch.constant.int 0
+    %0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
+    %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
+    %2 = torch.prim.If %1 -> (!torch.list<int>) {
+      %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+      %6 = torch.prim.ListConstruct : () -> !torch.list<int>
+      torch.prim.Loop %5, %true, init() {
+      ^bb0(%arg4: !torch.int):
+        %7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
+        torch.prim.Loop.condition %true, iter()
+      } : (!torch.int, !torch.bool) -> ()
+      torch.prim.If.yield %6 : !torch.list<int>
+    } else {
+      torch.prim.If.yield %arg1 : !torch.list<int>
+    }
+    %3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
+    %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
+    return %4 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
+    %true = torch.constant.bool true
     %none = torch.constant.none
+    %int0 = torch.constant.int 0
     %0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
-    %1 = torch.prim.If %0 -> (!torch.list<int>) {
-      %2 = torch.prim.ListConstruct : () -> !torch.list<int>
-      %3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
-      %4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
-      torch.prim.If.yield %4 : !torch.list<int>
+    %1 = torch.prim.If %0 -> (!torch.bool) {
+      torch.prim.If.yield %true : !torch.bool
     } else {
-      %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
-      %3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
-      %4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
-      torch.prim.If.yield %4 : !torch.list<int>
+      %5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
+      %6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
+      %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
+      torch.prim.If.yield %7 : !torch.bool
     }
-    return %1 : !torch.list<int>
+    %2 = torch.prim.If %1 -> (!torch.list<int>) {
+      %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+      %6 = torch.prim.ListConstruct : () -> !torch.list<int>
+      torch.prim.Loop %5, %true, init() {
+      ^bb0(%arg4: !torch.int):
+        %7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
+        torch.prim.Loop.condition %true, iter()
+      } : (!torch.int, !torch.bool) -> ()
+      torch.prim.If.yield %6 : !torch.list<int>
+    } else {
+      %5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
+      torch.prim.If.yield %5 : !torch.list<int>
+    }
+    %3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
+    %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
+    return %4 : !torch.list<int>
   }
   func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
@@ -490,6 +490,8 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]:
     return []

 def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
+    if len(dim)==0:
+        dim = list(range(len(self)))
     return upstream_shape_functions.mean_dim(self, dim, keepdim, None)

 def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
@@ -533,13 +535,14 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
     return reduced_shape, reduced_shape

 def aten〇mean〇dim(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
+    if len(dim)==0:
+        dim = list(range(len(self)))
     return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)

 def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
-    if dim is None:
-        return upstream_shape_functions.mean_dim(self, [], keepdim, dtype)
-    else:
-        return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
+    if dim is None or len(dim)==0:
+        dim = list(range(len(self)))
+    return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)

 def aten〇permute(self: List[int], dims: List[int]) -> List[int]:
     return upstream_shape_functions.permute(self, dims)
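As a sanity check, the fixed shape functions can be exercised directly. Expected values under the new behavior, assuming the calls run inside the shape-lib module where upstream_shape_functions is in scope (illustrative, not part of the commit):

# Rank-3 input with an empty dim list now reduces over all dims,
# yielding a rank-0 shape.
assert aten〇mean〇dim([3, 4, 5], []) == []
assert aten〇var〇dim([3, 4, 5], []) == []
# dim=None with keepdim=True keeps every dim at size 1.
assert aten〇sum〇dim_IntList([3, 4, 5], None, keepdim=True) == [1, 1, 1]
# Non-empty dim lists behave as before.
assert aten〇mean〇dim([3, 4, 5], [1]) == [3, 5]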
@@ -106,6 +106,25 @@ def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils):

 # ==============================================================================

+class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.sum(a, dim=[])
+
+
+@register_test_case(module_factory=lambda: ReduceSumDimIntListEmptyDimModule())
+def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
 class ReduceSumUnsignedIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -180,6 +180,26 @@ class MeanDimNegativeModule(torch.nn.Module):
 def MeanDimNegativeModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))

+
+# ==============================================================================
+
+class MeanDimEmptyDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, dim=[])
+
+
+@register_test_case(module_factory=lambda: MeanDimEmptyDimModule())
+def MeanDimEmptyDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
 # ==============================================================================

 class VarUnbiasedModule(torch.nn.Module):
@@ -410,7 +430,7 @@ def VarDimNegativeModule_basic(module, tu: TestUtils):
 # ==============================================================================


-class VarDimKeepDimFalseModule(torch.nn.Module):
+class VarDimEmptyDimModule(torch.nn.Module):

     def __init__(self):
         super().__init__()
@@ -421,11 +441,11 @@ class VarDimKeepDimFalseModule(torch.nn.Module):
         ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
-        return torch.ops.aten.var(x, dim=(0, 1, 2), keepdim=False)
+        return torch.ops.aten.var(x, dim=[], keepdim=False)


-@register_test_case(module_factory=lambda: VarDimKeepDimFalseModule())
-def VarDimKeepDimFalseModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: VarDimEmptyDimModule())
+def VarDimEmptyDimModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5))
