[LINALG] Add E2E support for `aten.mean.dim` op

- This commit adds support for the `aten.mean.dim` op.
- It also adds a new test script, `stats.py`, for statistics-related ops.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
pull/829/head
Gaurav Shukla 2022-03-10 22:55:21 +05:30
parent 32159c4e54
commit 4b911ada40
5 changed files with 325 additions and 151 deletions
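
For reference, `aten.mean.dim` is the overload of `torch.mean` that reduces over an explicit list of dimensions, with optional `keepdim` and `dtype` arguments. A minimal eager-mode sketch of the behavior the new E2E tests exercise (illustrative only; the variable names below are not part of this change):

import torch

x = torch.rand(3, 4, 5)

# Reduce over dims 0 and 2; the result has shape [4].
y = torch.ops.aten.mean(x, (0, 2))

# keepdim=True preserves the rank; the result has shape [3, 1, 1].
z = torch.ops.aten.mean(x, (1, 2), keepdim=True)

# An explicit dtype casts the input to that type before reducing.
w = torch.ops.aten.mean(x.to(torch.float64), (0,), dtype=torch.float32)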


@@ -22,6 +22,18 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// Helper function to check whether the `dtype` is None or Float type.
static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
  if (dtype.getType().isa<Torch::NoneType>())
    return true;
  int64_t dtypeInt;
  if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
    return false;
  Type resDtype =
      getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
  return resDtype.isa<mlir::FloatType>();
}
// Helper function to compute the return type of the reduction function.
// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
// the input tensor.
@@ -840,6 +852,54 @@ public:
};
} // namespace
// productDimSize = product(size(dim) for dim in dims)
// aten.mean(x, dims) = aten.sum(x, dims) / productDimSize.
namespace {
class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenMeanDimOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.self();
    Value dimList = op.dim();
    Value keepDim = op.keepdim();
    Value dtype = op.dtype();
    Type outputType = op.getType();
    MLIRContext *context = op.getContext();
    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
        !isNoneOrFloatDtype(context, dtype)) {
      return rewriter.notifyMatchFailure(
          op, "only floating-point type is supported");
    }
    auto dimListConstruct = dimList.getDefiningOp<PrimListConstructOp>();
    if (!dimListConstruct) {
      return rewriter.notifyMatchFailure(
          op, "expect dimList to be constructed from list construct");
    }
    // Compute sum along dimensions specified in `dimList`.
    Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
        loc, outputType, input, dimList, keepDim, dtype);
    // `productDimSize` is product of sizes of dimensions to be reduced.
    Value productDimSize = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));
    for (Value dim : dimListConstruct.elements()) {
      Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
      productDimSize =
          rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
    }
    rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
                                                 productDimSize);
    return success();
  }
};
} // namespace
namespace {
class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
public:
@@ -1776,6 +1836,8 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenAddmmOp>();
    patterns.add<DecomposeAtenMeanOp>(context);
    target.addIllegalOp<AtenMeanOp>();
    patterns.add<DecomposeAtenMeanDimOp>(context);
    target.addIllegalOp<AtenMeanDimOp>();
    patterns.add<DecomposeAtenSelectIntOp>(context);
    target.addIllegalOp<AtenSelectIntOp>();
    patterns.add<DecomposeAtenMatmulOp>(context);
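
The new DecomposeAtenMeanDimOp pattern above implements the rewrite stated in its comment, `aten.mean(x, dims) = aten.sum(x, dims) / productDimSize`. A minimal eager-mode sketch of the same decomposition (the helper name `mean_dim_decomposed` is made up for illustration and is not part of the commit):

import torch

def mean_dim_decomposed(x, dims, keepdim=False):
    # Sum over the requested dims, then divide by the number of elements
    # reduced away: the product of the sizes of those dims.
    summed = torch.ops.aten.sum(x, dims, keepdim)
    product_dim_size = 1
    for d in dims:
        product_dim_size *= x.size(d)
    return summed / product_dim_size

x = torch.rand(3, 4, 7)
assert torch.allclose(mean_dim_decomposed(x, (0, 2)),
                      torch.ops.aten.mean(x, (0, 2)))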


@@ -52,3 +52,4 @@ def register_all_tests():
    from . import pooling
    from . import return_types
    from . import control_flow
    from . import stats


@@ -759,6 +759,9 @@ def _LogSoftmaxModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4))
# ==============================================================================
class _LogSoftmaxModuleStable(torch.nn.Module):
    def __init__(self):
@@ -1143,50 +1146,6 @@ def DropoutTrainModule_basic(module, tu: TestUtils):
# ==============================================================================
class MeanModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([3, 4], torch.float32, True),
    ])
    def forward(self, x):
        return torch.mean(x)
@register_test_case(module_factory=lambda: MeanModule())
def MeanModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 4))
# ==============================================================================
class MeanDynamicSizesModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.mean(x)
@register_test_case(module_factory=lambda: MeanDynamicSizesModule())
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 4))
# ==============================================================================
class NumelModule(torch.nn.Module):
    def __init__(self):
@@ -1502,94 +1461,6 @@ def SquareModule_basic(module, tu: TestUtils):
# ==============================================================================
class VarUnbiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, unbiased=True)
@register_test_case(module_factory=lambda: VarUnbiasedModule())
def VarUnbiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class VarBiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, unbiased=False)
@register_test_case(module_factory=lambda: VarBiasedModule())
def VarBiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class StdUnbiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.std(x, unbiased=True)
@register_test_case(module_factory=lambda: StdUnbiasedModule())
def StdUnbiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class StdBiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.std(x, unbiased=False)
@register_test_case(module_factory=lambda: StdBiasedModule())
def StdBiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class HardswishModule(torch.nn.Module):
    def __init__(self):
@@ -1609,6 +1480,9 @@ def HardswishModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]]))
# ==============================================================================
class HardswishRandomModule(torch.nn.Module):
    def __init__(self):
@@ -1693,6 +1567,7 @@ class HardTanhIntModule(torch.nn.Module):
def HardTanhIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(-5, 5, (100, 100)))
# ==============================================================================
class BincountModule(torch.nn.Module):
@@ -1712,6 +1587,7 @@ class BincountModule(torch.nn.Module):
def BincountModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, (1000, )))
# ==============================================================================
class BincountStaticSizeModule(torch.nn.Module):
@@ -1731,6 +1607,7 @@ class BincountStaticSizeModule(torch.nn.Module):
def BincountStaticSizeModule_basic(module, tu: TestUtils):
    module.forward(torch.randint(100, (200, )))
# ==============================================================================
class BincountMinlengthModule(torch.nn.Module):


@@ -220,25 +220,6 @@ def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils):
# ==============================================================================
class ReduceMeanDtypeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, a):
        return torch.mean(a, dtype=torch.float32)
@register_test_case(module_factory=lambda: ReduceMeanDtypeModule())
def ReduceMeanDtypeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceMaxAlongDim(torch.nn.Module):
    def __init__(self):
        super().__init__()


@@ -0,0 +1,253 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class MeanModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([3, 4], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.mean(x)
@register_test_case(module_factory=lambda: MeanModule())
def MeanModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))
# ==============================================================================
class MeanDynamicSizesModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.mean(x)
@register_test_case(module_factory=lambda: MeanDynamicSizesModule())
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))
# ==============================================================================
class MeanDtypeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, x):
        return torch.ops.aten.mean(x, dtype=torch.float32)
@register_test_case(module_factory=lambda: MeanDtypeModule())
def MeanDtypeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class MeanDimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.mean(x, (0, 2))
@register_test_case(module_factory=lambda: MeanDimModule())
def MeanDimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 7))
# ==============================================================================
class MeanDimDtypeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, x):
        return torch.ops.aten.mean(x, 0, dtype=torch.float32)
@register_test_case(module_factory=lambda: MeanDimDtypeModule())
def MeanDimDtypeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class MeanDimKeepdimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.mean(x, (1, 2), keepdim=True)
@register_test_case(module_factory=lambda: MeanDimKeepdimModule())
def MeanDimKeepdimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class MeanDimAllReduceModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.mean(x, (0, 1, 2))
@register_test_case(module_factory=lambda: MeanDimAllReduceModule())
def MeanDimAllReduceModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class MeanDimAllReduceKeepdimModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.mean(x, (0, 1, 2), keepdim=True)
@register_test_case(module_factory=lambda: MeanDimAllReduceKeepdimModule())
def MeanDimAllReduceKeepdimModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class MeanDimNegativeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.mean(x, (-1, 1))
@register_test_case(module_factory=lambda: MeanDimNegativeModule())
def MeanDimNegativeModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5))
# ==============================================================================
class VarUnbiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, unbiased=True)
@register_test_case(module_factory=lambda: VarUnbiasedModule())
def VarUnbiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class VarBiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.var(x, unbiased=False)
@register_test_case(module_factory=lambda: VarBiasedModule())
def VarBiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class StdUnbiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.std(x, unbiased=True)
@register_test_case(module_factory=lambda: StdUnbiasedModule())
def StdUnbiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class StdBiasedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.std(x, unbiased=False)
@register_test_case(module_factory=lambda: StdBiasedModule())
def StdBiasedModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 3, 4))