mirror of https://github.com/llvm/torch-mlir
[LINALG] Add E2E support for `aten.mean.dim` op
- This commit adds support for `aten.mean.dim` op. - It also adds a new test script `stats.py` for statistics related ops. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/829/head
parent
32159c4e54
commit
4b911ada40
|
@ -22,6 +22,18 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// Helper function to check whether the `dtype` is None or Float type.
|
||||
static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
||||
if (dtype.getType().isa<Torch::NoneType>())
|
||||
return true;
|
||||
int64_t dtypeInt;
|
||||
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||
return false;
|
||||
Type resDtype =
|
||||
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
return resDtype.isa<mlir::FloatType>();
|
||||
}
|
||||
|
||||
// Helper function to compute the return type of the reduction function.
|
||||
// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
|
||||
// the input tensor.
|
||||
|
@ -840,6 +852,54 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// productDimSize = product(size(dim) for dim in dims)
|
||||
// aten.mean(x, dims) = aten.sum(x, dims) / productDimSize.
|
||||
namespace {
|
||||
class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenMeanDimOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.self();
|
||||
Value dimList = op.dim();
|
||||
Value keepDim = op.keepdim();
|
||||
Value dtype = op.dtype();
|
||||
Type outputType = op.getType();
|
||||
MLIRContext *context = op.getContext();
|
||||
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
|
||||
!isNoneOrFloatDtype(context, dtype)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only floating-point type is supported");
|
||||
}
|
||||
|
||||
auto dimListConstruct = dimList.getDefiningOp<PrimListConstructOp>();
|
||||
if (!dimListConstruct) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "expect dimList to be constructed from list construct");
|
||||
}
|
||||
|
||||
// Compute sum along dimensions specified in `dimList`.
|
||||
Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
|
||||
loc, outputType, input, dimList, keepDim, dtype);
|
||||
|
||||
// `productDimSize` is product of sizes of dimensions to be reduced.
|
||||
Value productDimSize = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
for (Value dim : dimListConstruct.elements()) {
|
||||
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
|
||||
productDimSize =
|
||||
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
|
||||
productDimSize);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
|
||||
public:
|
||||
|
@ -1776,6 +1836,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenAddmmOp>();
|
||||
patterns.add<DecomposeAtenMeanOp>(context);
|
||||
target.addIllegalOp<AtenMeanOp>();
|
||||
patterns.add<DecomposeAtenMeanDimOp>(context);
|
||||
target.addIllegalOp<AtenMeanDimOp>();
|
||||
patterns.add<DecomposeAtenSelectIntOp>(context);
|
||||
target.addIllegalOp<AtenSelectIntOp>();
|
||||
patterns.add<DecomposeAtenMatmulOp>(context);
|
||||
|
|
|
@ -52,3 +52,4 @@ def register_all_tests():
|
|||
from . import pooling
|
||||
from . import return_types
|
||||
from . import control_flow
|
||||
from . import stats
|
||||
|
|
|
@ -759,6 +759,9 @@ def _LogSoftmaxModule_basic(module, tu: TestUtils):
|
|||
module.forward(torch.randn(3, 2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class _LogSoftmaxModuleStable(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1143,50 +1146,6 @@ def DropoutTrainModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.mean(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanModule())
|
||||
def MeanModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class MeanDynamicSizesModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.mean(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDynamicSizesModule())
|
||||
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class NumelModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1502,94 +1461,6 @@ def SquareModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class VarUnbiasedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, unbiased=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarUnbiasedModule())
|
||||
def VarUnbiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class VarBiasedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, unbiased=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: VarBiasedModule())
|
||||
def VarBiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdUnbiasedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, unbiased=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdUnbiasedModule())
|
||||
def StdUnbiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class StdBiasedModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, unbiased=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: StdBiasedModule())
|
||||
def StdBiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class HardswishModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1609,6 +1480,9 @@ def HardswishModule_basic(module, tu: TestUtils):
|
|||
module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class HardswishRandomModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1693,6 +1567,7 @@ class HardTanhIntModule(torch.nn.Module):
|
|||
def HardTanhIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-5, 5, (100, 100)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BincountModule(torch.nn.Module):
|
||||
|
||||
|
@ -1712,6 +1587,7 @@ class BincountModule(torch.nn.Module):
|
|||
def BincountModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (1000, )))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BincountStaticSizeModule(torch.nn.Module):
|
||||
|
||||
|
@ -1731,6 +1607,7 @@ class BincountStaticSizeModule(torch.nn.Module):
|
|||
def BincountStaticSizeModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(100, (200, )))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class BincountMinlengthModule(torch.nn.Module):
|
||||
|
||||
|
|
|
@ -220,25 +220,6 @@ def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceMeanDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.mean(a, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceMeanDtypeModule())
|
||||
def ReduceMeanDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceMaxAlongDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -0,0 +1,253 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.torchscript.framework import TestUtils
|
||||
from torch_mlir_e2e_test.torchscript.registry import register_test_case
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanModule())
|
||||
def MeanModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDynamicSizesModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDynamicSizesModule())
|
||||
def MeanDynamicSizesModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDtypeModule())
|
||||
def MeanDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (0, 2))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDimModule())
|
||||
def MeanDimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 7))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDimDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, 0, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDimDtypeModule())
|
||||
def MeanDimDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDimKeepdimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (1, 2), keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDimKeepdimModule())
|
||||
def MeanDimKeepdimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDimAllReduceModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (0, 1, 2))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDimAllReduceModule())
|
||||
def MeanDimAllReduceModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDimAllReduceKeepdimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (0, 1, 2), keepdim=True)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDimAllReduceKeepdimModule())
|
||||
def MeanDimAllReduceKeepdimModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MeanDimNegativeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mean(x, (-1, 1))
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MeanDimNegativeModule())
|
||||
def MeanDimNegativeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class VarUnbiasedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, unbiased=True)
|
||||
|
||||
@register_test_case(module_factory=lambda: VarUnbiasedModule())
|
||||
def VarUnbiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class VarBiasedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, unbiased=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: VarBiasedModule())
|
||||
def VarBiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class StdUnbiasedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, unbiased=True)
|
||||
|
||||
@register_test_case(module_factory=lambda: StdUnbiasedModule())
|
||||
def StdUnbiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class StdBiasedModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.std(x, unbiased=False)
|
||||
|
||||
@register_test_case(module_factory=lambda: StdBiasedModule())
|
||||
def StdBiasedModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
Loading…
Reference in New Issue