diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 4f76f8164..a38bcc73c 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -22,6 +22,18 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
 
+// Helper function to check whether the `dtype` is None or Float type.
+static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
+  if (dtype.getType().isa<Torch::NoneType>())
+    return true;
+  int64_t dtypeInt;
+  if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
+    return false;
+  Type resDtype =
+      getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
+  return resDtype.isa<mlir::FloatType>();
+}
+
 // Helper function to compute the return type of the reduction function.
 // `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
 // the input tensor.
@@ -840,6 +852,54 @@ public:
 };
 } // namespace
 
+// productDimSize = product(size(dim) for dim in dims)
+// aten.mean(x, dims) = aten.sum(x, dims) / productDimSize.
+namespace {
+class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenMeanDimOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.self();
+    Value dimList = op.dim();
+    Value keepDim = op.keepdim();
+    Value dtype = op.dtype();
+    Type outputType = op.getType();
+    MLIRContext *context = op.getContext();
+
+    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
+        !isNoneOrFloatDtype(context, dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "only floating-point type is supported");
+    }
+
+    auto dimListConstruct = dimList.getDefiningOp<PrimListConstructOp>();
+    if (!dimListConstruct) {
+      return rewriter.notifyMatchFailure(
+          op, "expect dimList to be constructed from list construct");
+    }
+
+    // Compute sum along dimensions specified in `dimList`.
+    Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
+        loc, outputType, input, dimList, keepDim, dtype);
+
+    // `productDimSize` is product of sizes of dimensions to be reduced.
+    Value productDimSize = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(1));
+    for (Value dim : dimListConstruct.elements()) {
+      Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+      productDimSize =
+          rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
+    }
+    rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
+                                                 productDimSize);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenSquareOp : public OpRewritePattern<AtenSquareOp> {
 public:
@@ -1776,6 +1836,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenMeanDimOp>(context);
+    target.addIllegalOp<AtenMeanDimOp>();
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/python/torch_mlir_e2e_test/test_suite/__init__.py b/python/torch_mlir_e2e_test/test_suite/__init__.py
index c8696a3e4..e088e4fad 100644
--- a/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -52,3 +52,4 @@ def register_all_tests():
     from . import pooling
     from . import return_types
     from . import control_flow
+    from . import stats
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index a39014caa..3cf5837fc 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -759,6 +759,9 @@ def _LogSoftmaxModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4))
 
 
+# ==============================================================================
+
+
 class _LogSoftmaxModuleStable(torch.nn.Module):
 
     def __init__(self):
@@ -1143,50 +1146,6 @@ def DropoutTrainModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class MeanModule(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([3, 4], torch.float32, True),
-    ])
-    def forward(self, x):
-        return torch.mean(x)
-
-
-@register_test_case(module_factory=lambda: MeanModule())
-def MeanModule_basic(module, tu: TestUtils):
-    module.forward(torch.randn(3, 4))
-
-
-# ==============================================================================
-
-
-class MeanDynamicSizesModule(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.float32, True),
-    ])
-    def forward(self, x):
-        return torch.mean(x)
-
-
-@register_test_case(module_factory=lambda: MeanDynamicSizesModule())
-def MeanDynamicSizesModule_basic(module, tu: TestUtils):
-    module.forward(torch.randn(3, 4))
-
-
-# ==============================================================================
-
-
 class NumelModule(torch.nn.Module):
 
     def __init__(self):
@@ -1502,94 +1461,6 @@ def SquareModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class VarUnbiasedModule(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1], torch.float32, True),
-    ])
-    def forward(self, x):
-        return torch.ops.aten.var(x, unbiased=True)
-
-
-@register_test_case(module_factory=lambda: VarUnbiasedModule())
-def VarUnbiasedModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3, 4))
-
-
-# ==============================================================================
-
-
-class VarBiasedModule(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1], torch.float32, True),
-    ])
-    def forward(self, x):
-        return torch.ops.aten.var(x, unbiased=False)
-
-
-@register_test_case(module_factory=lambda: VarBiasedModule())
-def VarBiasedModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3, 4))
-
-
-# ==============================================================================
-
-
-class StdUnbiasedModule(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1], torch.float32, True),
-    ])
-    def forward(self, x):
-        return torch.ops.aten.std(x, unbiased=True)
-
-
-@register_test_case(module_factory=lambda: StdUnbiasedModule())
-def StdUnbiasedModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3, 4))
-
-
-# ==============================================================================
-
-
-class StdBiasedModule(torch.nn.Module):
-
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1], torch.float32, True),
-    ])
-    def forward(self, x):
-        return torch.ops.aten.std(x, unbiased=False)
-
-
-@register_test_case(module_factory=lambda: StdBiasedModule())
-def StdBiasedModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3, 4))
-
-
-# ==============================================================================
-
-
 class HardswishModule(torch.nn.Module):
 
     def __init__(self):
@@ -1609,6 +1480,9 @@ def HardswishModule_basic(module, tu: TestUtils):
     module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]]))
 
 
+# ==============================================================================
+
+
 class HardswishRandomModule(torch.nn.Module):
 
     def __init__(self):
@@ -1693,6 +1567,7 @@ class HardTanhIntModule(torch.nn.Module):
 def HardTanhIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(-5, 5, (100, 100)))
 
+# ==============================================================================
 
 class BincountModule(torch.nn.Module):
 
@@ -1712,6 +1587,7 @@ class BincountModule(torch.nn.Module):
 def BincountModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, (1000, )))
 
+# ==============================================================================
 
 class BincountStaticSizeModule(torch.nn.Module):
 
@@ -1731,6 +1607,7 @@ class BincountStaticSizeModule(torch.nn.Module):
 def BincountStaticSizeModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(100, (200, )))
 
+# ==============================================================================
 
 class BincountMinlengthModule(torch.nn.Module):
 
diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py
index 5727d4a2d..35711201c 100644
--- a/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -220,25 +220,6 @@ def ReduceSumDimIntListKeepDimIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
-class ReduceMeanDtypeModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    @export
-    @annotate_args([
-        None,
-        ([-1, -1, -1], torch.float64, True),
-    ])
-    def forward(self, a):
-        return torch.mean(a, dtype=torch.float32)
-
-
-@register_test_case(module_factory=lambda: ReduceMeanDtypeModule())
-def ReduceMeanDtypeModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(3, 4, 5).to(torch.float64))
-
-# ==============================================================================
-
 class ReduceMaxAlongDim(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/python/torch_mlir_e2e_test/test_suite/stats.py b/python/torch_mlir_e2e_test/test_suite/stats.py
new file mode 100644
index 000000000..05769b652
--- /dev/null
+++ b/python/torch_mlir_e2e_test/test_suite/stats.py
@@ -0,0 +1,253 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+class MeanModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x)
+
+
+@register_test_case(module_factory=lambda: MeanModule())
+def MeanModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class MeanDynamicSizesModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x)
+
+
+@register_test_case(module_factory=lambda: MeanDynamicSizesModule())
+def MeanDynamicSizesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class MeanDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, dtype=torch.float32)
+
+
+@register_test_case(module_factory=lambda: MeanDtypeModule())
+def MeanDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5).to(torch.float64))
+
+# ==============================================================================
+
+class MeanDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, (0, 2))
+
+
+@register_test_case(module_factory=lambda: MeanDimModule())
+def MeanDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 7))
+
+# ==============================================================================
+
+class MeanDimDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, 0, dtype=torch.float32)
+
+
+@register_test_case(module_factory=lambda: MeanDimDtypeModule())
+def MeanDimDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5).to(torch.float64))
+
+# ==============================================================================
+
+class MeanDimKeepdimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, (1, 2), keepdim=True)
+
+
+@register_test_case(module_factory=lambda: MeanDimKeepdimModule())
+def MeanDimKeepdimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class MeanDimAllReduceModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, (0, 1, 2))
+
+
+@register_test_case(module_factory=lambda: MeanDimAllReduceModule())
+def MeanDimAllReduceModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class MeanDimAllReduceKeepdimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, (0, 1, 2), keepdim=True)
+
+
+@register_test_case(module_factory=lambda: MeanDimAllReduceKeepdimModule())
+def MeanDimAllReduceKeepdimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class MeanDimNegativeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, (-1, 1))
+
+
+@register_test_case(module_factory=lambda: MeanDimNegativeModule())
+def MeanDimNegativeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class VarUnbiasedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, unbiased=True)
+
+@register_test_case(module_factory=lambda: VarUnbiasedModule())
+def VarUnbiasedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+# ==============================================================================
+
+class VarBiasedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.var(x, unbiased=False)
+
+@register_test_case(module_factory=lambda: VarBiasedModule())
+def VarBiasedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+# ==============================================================================
+
+class StdUnbiasedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.std(x, unbiased=True)
+
+@register_test_case(module_factory=lambda: StdUnbiasedModule())
+def StdUnbiasedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+# ==============================================================================
+
+class StdBiasedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.std(x, unbiased=False)
+
+@register_test_case(module_factory=lambda: StdBiasedModule())
+def StdBiasedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
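
Note (illustration only, not part of the patch): the decomposition added above rests on the identity aten.mean(x, dims) = aten.sum(x, dims) / productDimSize, where productDimSize is the product of the sizes of the reduced dimensions. A minimal PyTorch sketch of that identity follows; the helper name check_mean_dim_decomposition is made up for this example.

import torch

def check_mean_dim_decomposition(x, dims, keepdim=False):
    # Product of the sizes of the dimensions being reduced
    # (this is what the pattern accumulates as `productDimSize`).
    product_dim_size = 1
    for d in dims:
        product_dim_size *= x.size(d)
    # mean over `dims` should equal sum over `dims` divided by that product.
    decomposed = torch.sum(x, dims, keepdim=keepdim) / product_dim_size
    return torch.allclose(torch.mean(x, dims, keepdim=keepdim), decomposed)

assert check_mean_dim_decomposition(torch.rand(3, 4, 7), (0, 2))
assert check_mean_dim_decomposition(torch.rand(3, 4, 5), (1, 2), keepdim=True)
assert check_mean_dim_decomposition(torch.rand(3, 4, 5), (-1, 1))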