2021-09-30 00:03:40 +08:00
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
2021-07-01 05:13:21 +08:00
# This file describes the sets of tests expected to fail for each config.
# This information is kept in a side table, rather than in-situ on
# the test, as a deliberate layering decision: tests should
# have unique keys to identify them and enable side tables of various kinds
# (this includes down into lower parts of the stack, where a side table
# might be used to keep more elaborate sets of testing configurations).
2022-04-20 03:35:56 +08:00
from torch_mlir_e2e_test . test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
2023-07-13 21:07:54 +08:00
from torch_mlir . _version import torch_version_for_comparison , version
2022-04-20 03:35:56 +08:00
2023-12-08 15:13:42 +08:00
# Report, at import time, the torch version value returned by
# torch_version_for_comparison(). The label is a plain string: an f-prefix
# with no placeholders is redundant (ruff F541). The printed text is unchanged.
print("TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison())
2023-08-18 23:15:54 +08:00
# Tests expected to fail for the linalg-on-tensors config: the shared
# Torch-MLIR lowering xfails plus the linalg-specific entries below.
LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
    # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
    # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
    # NOTE(review): no failure reason is recorded for the entries below.
    "IscloseStaticModule_basic",
    "IscloseStaticModuleTrue_basic",
    "SplitWithSizes_Module_basic",
}
2021-08-11 07:10:31 +08:00
2024-04-19 02:47:19 +08:00
# Linalg-on-tensors tests known to crash, tracked separately from
# LINALG_XFAIL_SET.
LINALG_CRASHING_SET = {
    # The copy targets a destination buffer that is smaller than its
    # source buffer, which crashes.
    "SliceCopyStartGreaterThanDimSize_Module_basic",
}
2022-11-18 20:21:19 +08:00
# Tests expected to fail under the TorchDynamo config, grouped by the
# error each group was observed to produce.
TORCHDYNAMO_XFAIL_SET = {
    #### General TorchDynamo/PyTorch errors
    # torch._dynamo.exc.Unsupported: Tensor.item
    "CumsumModule_basic",
    # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0
    # RuntimeError: Failed running call_function aten.convolution_backward(...
    # https://github.com/pytorch/pytorch/issues/89629
    "ConvolutionBackwardModule2DPadded_basic",
    "ConvolutionBackwardModule2D_basic",
    # Size result mismatch (exposed by downstream canonicalizer
    # on incompatible casts).
    # https://github.com/pytorch/pytorch/issues/119407
    "ConvolutionBackwardModule2DStrided_basic",
    # RuntimeError: Index tensor must have the same number of dimensions as self tensor
    # RuntimeError: Failed running call_function aten.nll_loss_backward(...
    # https://github.com/pytorch/pytorch/issues/89630
    "NllLossModuleBackward1DMeanWeight_basic",
    "NllLossModuleBackward1DMean_basic",
    "NllLossModuleBackward1DSumWeight_basic",
    "NllLossModuleBackward1DSum_basic",
    "NllLossModuleBackward1DWeight_basic",
    "NllLossModuleBackward1D_basic",
    # TypeError: uniform() missing 2 required keyword-only arguments: 'dtype' and 'device'
    # RuntimeError: Failed running call_function aten.uniform(...
    # https://github.com/pytorch/torchdynamo/issues/1954
    "UniformNoCorrelationModule_basic",
    #### Torch-MLIR internal compiler errors
    # These are probably due to slightly different ops being recorded by
    # torchdynamo vs. torchscript.
    # No upstream decompositions.
    # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
    # See also: https://github.com/pytorch/torchdynamo/issues/327
    "AtenEmbeddingBagSumExample_basic",
    # error: unsupported by backend contract: tensor with unknown rank
    # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
    "ElementwisePreluModule_basic",
    # error: torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised: AssertionError: Unregistered operation: torch.aten._prelu_kernel
    "ElementwisePreluStaticModule_basic",
    # ERROR: value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=-1.336e-32, max=+0.9152, mean=+0.4837) is not close to golden value (Tensor with shape=[2, 3, 6, 10], dtype=torch.float32, min=+0.02233, max=+0.9152, mean=+0.4777)
    "UpSampleNearest2dDynamicFactor_basic",
    "ReduceMaxAlongDimUnsignedInt_basic",
    "ReduceMinAlongDimUnsignedInt_basic",
    # ERROR: value (-56) is not equal to golden value (200)
    "AtenIntTensorByteDtypeModule_basic",
    # ERROR: assert isinstance(e, FakeTensor)
    "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
    # ERROR: assert isinstance(e, FakeTensor)
    "RsubInt0d_NumToTensor_Module_basic",
    # ERROR: RuntimeError: Found a custom (non-ATen) operator that either
    # mutates or its inputs: prims::squeeze.. Getting these operators to work
    # with functionalization requires some extra work. For mutable ops you
    # need to register a corresponding out-of-place variant of the op, and you
    # also need to register a Functionalization kernel that performs some
    # boilerplate, telling functionalization to map from the mutable op to the
    # out-of-place op. See a more complete example of how to do this at
    # https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa.
    "PrimsSqueezeModule_basic",
    "PrimsSqueezeEmptyDimensionsModule_basic",
    "SplitDimStaticModule_basic",
    "SplitDimDynamicModule_basic",
    # ERROR: RuntimeError: Found a custom (non-ATen) operator that either
    # mutates or its inputs: prims::view_of.. (same functionalization error
    # as prims::squeeze above; see
    # https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa).
    "PrimsViewOfModule_basic",
    "PrimsViewOfZeroRankModule_basic",
    # See https://github.com/llvm/torch-mlir/pull/2040 and corresponding upstream issue
    # https://github.com/pytorch/pytorch/issues/99752.
    # torch._dynamo.exc.Unsupported: call_function BuiltinVariable(bool) [TensorVariable()] {}
    "TensorToBoolZeroRank_basic",
    "TensorToBool_basic",
    # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
    "AtenSubFloatModule_basic",
    "AtenMulFloatModule_basic",
    "BoolFloatFalseModule_basic",
    "BoolFloatTrueModule_basic",
    "CeilFloatModule_basic",
    "DivFloatModule_basic",
    "GeFloatIntModule_basic",
    "GeFloatModule_basic",
    "GtFloatIntModule_basic",
    "NeFloatIntModule_basic",
    "SubFloatModule_basic",
    "MulFloatModule_basic",
    "TensorToFloatZeroRank_basic",
    "TensorToFloat_basic",
    # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
    # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
    "AddIntModule_basic",
    "AtenIntTensorCharDtypeModule_basic",
    "BoolIntFalseModule_basic",
    "BoolIntTrueModule_basic",
    "DivIntModule_basic",
    "EqIntModule_basic",
    "GeIntModule_basic",
    "GtIntModule_basic",
    "MulIntModule_basic",
    "NeIntModule_basic",
    "SqrtIntModule_basic",
    "SubIntModule_basic",
    "TensorToIntZeroRank_basic",
    "TensorToInt_basic",
    "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
    "ViewCollapseDynamicWithAtenSizeIntModule_basic",
    # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
    # ERROR: torch._dynamo.exc.Unsupported: Tensor.item
    "AtenItemIntOpModule_basic",
    "AtenItemFpOpModule_basic",
    # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)}
    "SortIntListReverse_basic",
    # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {}
    "SortIntList_basic",
    # START tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
    "AtenFloatScalarModule_basic",
    "AtenIntBoolOpModule_basic",
    "QuantizedMLP_basic",
    "QuantizedSingleLayer_basic",
    "QuantizedBatchedInputSingleLayer_basic",
    "QuantizedNoLayer_basic",
    "ScalarImplicitFloatModule_basic",
    "ScalarImplicitIntModule_basic",
    # END tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default
    # START tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
    "BincountMinlengthModule_basic",
    "BincountModule_basic",
    "BincountStaticSizeModule_basic",
    # END tests failing due to: torch._dynamo.exc.Unsupported: dynamic shape operator: aten.bincount.default
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.Bool
    "BoolFloatConstantModule_basic",
    "BoolIntConstantModule_basic",
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
    "ViewSizeFromOtherTensor_basic",
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__
    "ContainsIntList_False",
    "ContainsIntList_True",
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.all
    "AllBoolFalseModule_basic",
    "AllBoolTrueModule_basic",
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.any
    "AnyBoolFalseModule_basic",
    "AnyBoolTrueModule_basic",
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor float call_function aten.sqrt
    "SqrtIntConstantModule_basic",
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
    "BroadcastDynamicDimModule_basic",
    # START tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
    "AtenIntBoolOpConstFalseModule_basic",
    "AtenIntBoolOpConstTrueModule_basic",
    "IntFloatModule_basic",
    "PowIntFloatModule_basic",
    # END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len
    "LenStrModule_basic",
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.numel
    "NumelModule_basic",
    "NumelZeroRankModule_basic",
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.max
    "PrimMaxIntModule_basic",
    # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function prim.min
    "PrimMinIntModule_basic",
    "PrimMinIntDynamicModule_basic",
    # START tests failing due to: empty graph in dynamo
    "IsFloatingPointFloat_True",
    "IsFloatingPointInt_False",
    "TorchPrimLoopForLikeModule_basic",
    "TorchPrimLoopWhileLikeModule_basic",
    "ScalarConstantTupleModule_basic",
    # END tests failing due to: empty graph in dynamo
    # ERROR due to: backend never runs because of empty frame
    "ConstantBoolParameterModule_basic",
    # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "UpSampleNearest2dDynamicSize_basic",
    "UpSampleNearest2dStaticFactor_basic",
    "UpSampleNearest2dStaticSize_basic",
    "UpSampleNearest2d_basic",
    # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "ElementwiseAddScalarFloatModule_basic",
    # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
    # NOTE(review): ElementwiseAddScalar_TensorLiteralInt32_Module_basic is
    # also listed in TORCHDYNAMO_CRASHING_SET; confirm which one is intended.
    "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
    "HBC_basic",
    # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "ElementwiseDivScalarModule_basic",
    # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int'
    "ElementwiseAtenDivIntScalarModule_basic",
    # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
    "ElementwiseSubScalarFloatModule_basic",
    "ElementwiseSubScalarIntModule_basic",
    # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode
    "ElementwiseAtenFloorDivideScalarNegativeModule_basic",
    "ElementwiseAtenFloorDivideScalarModule_basic",
    "ElementwiseDivTensorRoundingModeFloorModule_basic",
    "ElementwiseDivTensorRoundingModeTruncModule_basic",
    "ElementwiseDivTensorRoundingModeFloorStaticModule_basic",
    "ElementwiseDivTensorRoundingModeTruncStaticModule_basic",
    "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic",
    "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic",
    "ElementwiseDivScalarRoundingModeFloorModule_basic",
    "ElementwiseDivScalarRoundingModeTruncModule_basic",
    "ElementwiseDivScalarRoundingModeFloorStaticModule_basic",
    "ElementwiseDivScalarRoundingModeTruncStaticModule_basic",
    "ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic",
    "ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic",
    # Adaptive pooling tests (added alongside torch-mlir #2661, general
    # support for adaptive pooling ops).
    "AdaptiveAvgPool1dStaticLargerOutput_basic",
    "AdaptiveAvgPool1dGeneralDynamic_basic",
    "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
    "AdaptiveAvgPool2dDynamic_basic",
    "AdaptiveAvgPool2dDynamicNoBatch_basic",
    # ERROR: Exception: Unsupported op: get_attr
    "NumToTensorFloatModule_basic",
    "NumToTensorIntModule_basic",
    "TensorFloatModule_basic",
    "TensorIntModule_basic",
    # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.randn.generator
    "RandnGeneratorF64Module_basic",
    "RandnGeneratorModule_basic",
    # ERROR: Exception: Unsupported: return type List[Tensor] in schema for aten.unbind.int
    "UnbindIntListUnpack_Module_basic",
    "UnbindIntGetItem_Module_basic",
    # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {}
    "ScatterValueFloatModule_basic",
    # ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
    "ScatterValueIntModule_basic",
    # AssertionError: Unregistered operation: torch.aten._unsafe_index_put
    "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
    # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
    # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
    # AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu
    "ScaledDotProductAttentionDifferentModule_basic",
    # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only
    "AtenEmbeddingBagStaticModule_basic",
    # Lowering not present for this case
    "ElementwiseToDtypeI64ToUI8Module_basic",
    # torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method add of type object at 0x7f4f8b05a720>(*(FakeTensor(..., size=(3, 4), dtype=torch.int8), 3, 2), **{}): Tensor with dtype torch.int64 is not the expected dtype of torch.int8!
    "ElementwiseAddScalarInt8Module_basic",
    # ERROR: dtype (torch.int64) is not equal to golden dtype (torch.float32)
    "ThresholdBackward2dMixedModule_basic",
    # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
    "ArangeStartOutViewModule_basic",
    # Dynamo does not support tracing quantized tensors
    "ElementwiseDequantizePerChannelModule_basic",
    "ElementwiseDequantizePerTensorModule_basic",
    "ElementwiseQuantizePerTensorModule_basic",
    "ElementwiseQuantizePerTensorUIntModule_basic",
    "AtenMmQuint8_basic",
    "AtenMmQint8_basic",
    "AtenMmQMixedSigni8_basic",
    "AtenMatmulQMixedSigni8Transpose_basic",
    "AtenMatmulQMixedSigni8_basic",
    "AtenMatmulQint8MV_basic",
    "AtenMatmulQint8VV_basic",
    "AtenMatmulQint8VM_basic",
    "AtenMatmulQint8_basic",
    "Conv2dQInt8Module_basic",
    # Dynamo not supporting conv_tbc
    "ConvTbcModule_basic",
    "FloatImplicitModule_basic",
    "IntImplicitModule_basic",
    # Others
    "GridSamplerBasic1_basic",
    "GridSamplerBasic2_basic",
    "FakeQuantizePerTensorAffineModule_basic",
    "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
    "FakeQuantizePerTensorAffineRoundToEvenModule_basic",
}
# TorchDynamo-config tests known to crash, tracked separately from
# TORCHDYNAMO_XFAIL_SET.
TORCHDYNAMO_CRASHING_SET = {
    # No upstream decompositions.
    # %6:4 = torch.operator "aten._embedding_bag_forward_only"(%1, %3, %5, %false, %int0, %false, %none, %false, %int-1) : (!torch.tensor<*,f32>, !torch.tensor<*,si64>, !torch.tensor<*,si64>, !torch.bool, !torch.int, !torch.bool, !torch.none, !torch.bool, !torch.int) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor)
    # See also: https://github.com/pytorch/torchdynamo/issues/327
    "Aten_EmbeddingBagExample_basic",
    "FullModuleInt3D_basic",
    "ThresholdBackward1dIntModule_basic",
    "ThresholdBackward2dIntModule_basic",
    "ThresholdBackward3dIntModule_basic",
    # See https://github.com/llvm/torch-mlir/issues/2050
    "ElementwiseCloneChannelsLastMemoryFormatModule_basic",
    "ElementwiseCloneContiguousModule_basic",
    "ElementwiseCloneModule_basic",
    "ExpandAsFloatModule_basic",
    "ExpandAsIntModule_basic",
    "ExpandModule_basic",
    "MoveDimIntModule_basic",
    "MoveDimIntNegativeIndexModule_basic",
    "NarrowVerticalTest2_basic",
    "NarrowVerticalTest_basic",
    "NumpyTRank2Module_basic",
    "NumpyTRankNDynamicModule_basic",
    "NumpyTRankNStaticModule_basic",
    "PermuteModule_basic",
    "PermuteNegativeIndexModule_basic",
    "SelectIntNegativeDimAndIndexStaticModule_basic",
    "TestMultipleTensorAndPrimitiveTypesReturn_basic",
    "TModuleRank2_basic",
    "ToCopyModule_basic",
    "TransposeIntModule_basic",
    "TransposeIntNegDimsModule_basic",
    "IndexPutImpl2DNoneIndexStaticModule_basic",
    # MaxPool3d tests (added with the aten.max_pool3d torch-to-linalg
    # lowering, torch-mlir #2735).
    "MaxPool3dCeilModeTrueModule_basic",
    "MaxPool3dEmptyStrideStaticModule_basic",
    "MaxPool3dLargeDatadModule_basic",
    "MaxPool3dModuleRandomSimple_basic",
    "MaxPool3dModule_basic",
    "MaxPool3dStaticCeilModeTrueModule_basic",
    "MaxPool3dStaticModule_basic",
    # Looks like incorrect fx graph conversion
    "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
}
2024-04-19 12:29:17 +08:00
# Tests expected to fail for the FX importer config.
FX_IMPORTER_XFAIL_SET = {
    "AllBoolFalseModule_basic",
    "AllBoolTrueModule_basic",
    "AnyBoolFalseModule_basic",
    "AnyBoolTrueModule_basic",
    "ArangeStartOutViewModule_basic",
    "AtenEmbeddingBagStaticModule_basic",
    "AtenEmbeddingBagSumExample_basic",
    "AtenFloatScalarModule_basic",
    "AtenIntBoolOpConstFalseModule_basic",
    "AtenIntBoolOpConstTrueModule_basic",
    "AtenIntBoolOpModule_basic",
    "AtenItemFpOpModule_basic",
    "AtenMatmulQMixedSigni8Transpose_basic",
    "AtenMatmulQMixedSigni8_basic",
    "AtenMatmulQint8MV_basic",
    "AtenMatmulQint8_basic",
    "AtenMatmulQint8VM_basic",
    "AtenMatmulQint8VV_basic",
    "AtenMmQMixedSigni8_basic",
    "AtenMmQint8_basic",
    "AtenMmQuint8_basic",
    "AtenSubFloatModule_basic",
    "BincountMinlengthModule_basic",
    "BincountModule_basic",
    "BincountStaticSizeModule_basic",
    "BoolFloatConstantModule_basic",
    "BoolFloatFalseModule_basic",
    "BoolFloatTrueModule_basic",
    "BoolIntConstantModule_basic",
    "BoolIntFalseModule_basic",
    "BoolIntTrueModule_basic",
    "BroadcastDynamicDimModule_basic",
    "CeilFloatModule_basic",
    "ConstantBoolParameterModule_basic",
    "ContainsIntList_False",
    "ContainsIntList_True",
    "Conv2dQInt8Module_basic",
    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
    "ConvTbcModule_basic",
    "ConvolutionBackwardModule2DPadded_basic",
    "ConvolutionBackwardModule2DStrided_basic",
    "ConvolutionBackwardModule2D_basic",
    "CumsumModule_basic",
    "DivFloatModule_basic",
    "DivIntModule_basic",
    "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
    "ElementwiseDequantizePerChannelModule_basic",
    "ElementwiseDequantizePerTensorModule_basic",
    "ElementwiseQuantizePerTensorModule_basic",
    "ElementwiseQuantizePerTensorUIntModule_basic",
    "ElementwiseToDtypeI64ToUI8Module_basic",
    "EqIntModule_basic",
    "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
    "FakeQuantizePerTensorAffineModule_basic",
    "FakeQuantizePerTensorAffineRoundToEvenModule_basic",
    "FloatImplicitModule_basic",
    "GeFloatIntModule_basic",
    "GeFloatModule_basic",
    "GeIntModule_basic",
    "GtFloatIntModule_basic",
    "GtIntModule_basic",
    "IntFloatModule_basic",
    "IntImplicitModule_basic",
    "IsFloatingPointFloat_True",
    "IsFloatingPointInt_False",
    "LenStrModule_basic",
    "MaxPool3dCeilModeTrueModule_basic",
    "MaxPool3dEmptyStrideStaticModule_basic",
    "MaxPool3dLargeDatadModule_basic",
    "MaxPool3dModuleRandomSimple_basic",
    "MaxPool3dModule_basic",
    "MaxPool3dStaticCeilModeTrueModule_basic",
    "MaxPool3dStaticModule_basic",
    "MulFloatModule_basic",
    "NativeGroupNormBackwardModule_basic",
    "NeFloatIntModule_basic",
    "NeIntModule_basic",
    "NllLossModuleBackward1DMeanWeight_basic",
    "NllLossModuleBackward1DMean_basic",
    "NllLossModuleBackward1DSumWeight_basic",
    "NllLossModuleBackward1DSum_basic",
    "NllLossModuleBackward1DWeight_basic",
    "NllLossModuleBackward1D_basic",
    "NumToTensorFloatModule_basic",
    "NumToTensorIntModule_basic",
    "NumelModule_basic",
    "NumelZeroRankModule_basic",
    "PowIntFloatModule_basic",
    "PrimMaxIntModule_basic",
    "PrimMinIntDynamicModule_basic",
    "PrimMinIntModule_basic",
    "PrimsSqueezeEmptyDimensionsModule_basic",
    "PrimsSqueezeModule_basic",
    "PrimsViewOfModule_basic",
    "PrimsViewOfZeroRankModule_basic",
    "QuantizedBatchedInputSingleLayer_basic",
    "QuantizedMLP_basic",
    "QuantizedNoLayer_basic",
    "QuantizedSingleLayer_basic",
    "ReduceMaxAlongDimUnsignedInt_basic",
    "ReduceMinAlongDimUnsignedInt_basic",
    "RsubInt0d_NumToTensor_Module_basic",
    "ScalarConstantTupleModule_basic",
    "ScalarImplicitFloatModule_basic",
    "SortIntListReverse_basic",
    "SortIntList_basic",
    "SplitDimDynamicModule_basic",
    "SplitDimStaticModule_basic",
    "SqrtIntConstantModule_basic",
    "SqrtIntModule_basic",
    "SubFloatModule_basic",
    "TModuleRank0_basic",
    "TensorToBoolZeroRank_basic",
    "TensorToBool_basic",
    "TensorToFloatZeroRank_basic",
    "TensorToFloat_basic",
    "TestMultipleTensorAndPrimitiveTypesReturn_basic",
    "ThresholdBackward2dMixedModule_basic",
    "TorchPrimLoopForLikeModule_basic",
    "TorchPrimLoopWhileLikeModule_basic",
    "UnbindIntGetItem_Module_basic",
    "UnbindIntListUnpack_Module_basic",
    "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
    "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
    "UpSampleNearest2dDynamicFactor_basic",
    "ViewCollapseDynamicWithAtenSizeIntModule_basic",
    "ViewSizeFromOtherTensor_basic",
}
# FX-importer-config tests known to crash, tracked separately from
# FX_IMPORTER_XFAIL_SET.
FX_IMPORTER_CRASHING_SET = {
    "HBC_basic",
}
2024-04-19 12:29:17 +08:00
FX_IMPORTER_STABLEHLO_XFAIL_SET = {
" AdaptiveAvgPool3dDynamicNoBatch_basic " ,
" AdaptiveAvgPool3dDynamic_basic " ,
" AdaptiveMaxPool1dDynamicNoBatch_basic " ,
" AdaptiveMaxPool1dDynamic_basic " ,
" AdaptiveMaxPool1dStatic_basic " ,
" AdaptiveMaxPool2dDynamicNoBatch_basic " ,
" AdaptiveMaxPool2dDynamicWithIndices_basic " ,
" AdaptiveMaxPool2dDynamic_basic " ,
" AdaptiveMaxPool2dStaticWithIndices_basic " ,
" AdaptiveMaxPool2dStatic_basic " ,
" AdaptiveMaxPool3dDynamicNoBatch_basic " ,
" AdaptiveMaxPool3dDynamicWithIndices_basic " ,
" AdaptiveMaxPool3dDynamic_basic " ,
" AdaptiveMaxPool3dStaticWithIndices_basic " ,
" AdaptiveMaxPool3dStatic_basic " ,
" AllBoolFalseModule_basic " ,
" AllBoolTrueModule_basic " ,
" AnyBoolFalseModule_basic " ,
" AnyBoolTrueModule_basic " ,
" ArangeStartOutViewModule_basic " ,
" ArgminIntModule_basic " ,
" ArgminIntModule_multiple_mins " ,
" ArgminModule_basic " ,
" ArgminModule_keepDim " ,
" ArgminModule_with_dim " ,
" AtenComplexImagModule_basic " ,
" AtenComplexRealModule_basic " ,
" AtenComplexViewModule_basic " ,
" AtenDiagEmbedDefaultDiag_basic " ,
" AtenDiagEmbedDimDiag_basic " ,
" AtenDiagEmbedNegOffsetDiag_basic " ,
" AtenDiagEmbedNonDefault4DDiag_basic " ,
" AtenDiagEmbedOffsetDiag_basic " ,
" AtenDiagEmbedRevDimDiag_basic " ,
" AtenEmbeddingBagStaticModule_basic " ,
" AtenEmbeddingBagSumExample_basic " ,
" AtenFloatScalarModule_basic " ,
" AtenIntBoolOpConstFalseModule_basic " ,
" AtenIntBoolOpConstTrueModule_basic " ,
" AtenIntBoolOpModule_basic " ,
" AtenItemFpOpModule_basic " ,
" AtenMatmulQMixedSigni8Transpose_basic " ,
" AtenMatmulQMixedSigni8_basic " ,
" AtenMatmulQint8MV_basic " ,
" AtenMatmulQint8VM_basic " ,
" AtenMatmulQint8VV_basic " ,
" AtenMatmulQint8_basic " ,
" AtenMmQMixedSigni8_basic " ,
" AtenMmQint8_basic " ,
" AtenMmQuint8_basic " ,
" AtenRealView128Module_basic " ,
" AtenRealView64Module_basic " ,
" AtenSubFloatModule_basic " ,
" AtenTopKModule_basic " ,
" AtenTopKSmallestModule_basic " ,
" AtenTrilModule_basic " ,
" AtenTrilWithNegDiagonalModule_basic " ,
" AtenTrilWithPosDiagonalModule_basic " ,
" Aten_EmbeddingBagExample_basic " ,
" AvgPool2dDivisorOverrideModule_basic " ,
" BernoulliTensorModule_basic " ,
" BincountMinlengthModule_basic " ,
" BincountModule_basic " ,
" BincountStaticSizeModule_basic " ,
" BoolFloatConstantModule_basic " ,
" BoolFloatFalseModule_basic " ,
" BoolFloatTrueModule_basic " ,
" BoolIntConstantModule_basic " ,
" BoolIntFalseModule_basic " ,
" BoolIntTrueModule_basic " ,
" BroadcastDynamicDimModule_basic " ,
" CeilFloatModule_basic " ,
" ConstantBoolParameterModule_basic " ,
" ConstantPad2dStaticModule_basic " ,
" ConstantPadNdModule_basic " ,
" ConstantPadNdPartialStaticModule_basic " ,
" ConstantPadNdStaticModule_basic " ,
" ContainsIntList_False " ,
" ContainsIntList_True " ,
" Conv2dQInt8Module_basic " ,
" ConvTbcModule_basic " ,
" ConvolutionBackwardModule2DPadded_basic " ,
" ConvolutionBackwardModule2DStrided_basic " ,
" ConvolutionBackwardModule2D_basic " ,
" CumsumModule_basic " ,
" DiagonalModule_basic " ,
" DiagonalModule_nonsquare " ,
" DiagonalModule_transposed " ,
" DiagonalModule_with_dims " ,
" DiagonalModule_with_dims_and_offset " ,
" DiagonalModule_with_negative_dims " ,
" DiagonalModule_with_offset " ,
" DivFloatModule_basic " ,
" DivIntModule_basic " ,
" ElementwiseAcoshIntModule_basic " ,
" ElementwiseAcoshModule_basic " ,
" ElementwiseAddScalar_NumToTensorFloat_Module_basic " ,
" ElementwiseAsinhIntModule_basic " ,
" ElementwiseAsinhModule_basic " ,
" ElementwiseAtan2FloatIntModule_basic " ,
" ElementwiseAtan2TensorFloatModule_basic " ,
" ElementwiseAtan2TensorIntModule_basic " ,
" ElementwiseAtanhIntModule_basic " ,
" ElementwiseAtanhModule_basic " ,
" ElementwiseBitwiseLeftShiftInt32Module_basic " ,
" ElementwiseBitwiseLeftShiftInt64Module_basic " ,
" ElementwiseBitwiseLeftShiftInt8Module_basic " ,
" ElementwiseBitwiseRightShiftInt32Module_basic " ,
" ElementwiseBitwiseRightShiftInt64Module_basic " ,
" ElementwiseBitwiseRightShiftInt8Module_basic " ,
" ElementwiseCoshIntModule_basic " ,
" ElementwiseCoshModule_basic " ,
" ElementwiseDequantizePerChannelModule_basic " ,
" ElementwiseDequantizePerTensorModule_basic " ,
" ElementwiseErfIntModule_basic " ,
" ElementwiseLogitModule_basic " ,
" ElementwiseMulTensorComplexModule_basic " ,
" ElementwisePowScalarModule_basic " ,
" ElementwiseQuantizePerTensorModule_basic " ,
" ElementwiseQuantizePerTensorUIntModule_basic " ,
" ElementwiseReciprocalIntModule_basic " ,
" ElementwiseTanIntModule_basic " ,
" ElementwiseTanModule_basic " ,
" ElementwiseTernaryModule_basic " ,
" ElementwiseToDtypeI64ToUI8Module_basic " ,
" EmptyModule_uint8 " ,
" EqIntModule_basic " ,
" FakeQuantizePerTensorAffineDynamicShapeModule_basic " ,
" FakeQuantizePerTensorAffineModule_basic " ,
" FakeQuantizePerTensorAffineRoundToEvenModule_basic " ,
" Fill_TensorFloat32WithFloat32_basic " ,
" Fill_TensorFloat32WithFloat64_basic " ,
" Fill_TensorFloat32WithInt64_basic " ,
" FloatImplicitModule_basic " ,
" GeFloatIntModule_basic " ,
" GeFloatModule_basic " ,
" GeIntModule_basic " ,
" GtFloatIntModule_basic " ,
" GtIntModule_basic " ,
" HBC_basic " ,
" HardtanhBackward_basic " ,
" IndexPut1DFloatAccumulateModule_basic " ,
" IndexPut1DFloatNonAccumulateModule_basic " ,
" IndexPut1DIntAccumulateModule_basic " ,
" IndexPut1DIntNonAccumulateModule_basic " ,
" IndexPut2DFloatAccumulateModule_basic " ,
" IndexPut2DFloatNonAccumulateModule_basic " ,
" IndexPut2DIntAccumulateModule_basic " ,
" IndexPut2DIntNonAccumulateModule_basic " ,
" IndexPut3DFloatAccumulateModule_basic " ,
" IndexPut3DFloatNonAccumulateModule_basic " ,
" IndexPut3DIntAccumulateModule_basic " ,
" IndexPut3DIntNonAccumulateModule_basic " ,
" IndexPutHackedTwin1DFloatAccumulateModule_basic " ,
" IndexPutHackedTwin1DFloatNonAccumulateModule_basic " ,
" IndexPutHackedTwin1DIntAccumulateModule_basic " ,
" IndexPutHackedTwin1DIntNonAccumulateModule_basic " ,
" IndexPutHackedTwin2DFloatAccumulateModule_basic " ,
" IndexPutHackedTwin2DFloatNonAccumulateModule_basic " ,
" IndexPutHackedTwin2DIntAccumulateModule_basic " ,
" IndexPutHackedTwin2DIntNonAccumulateModule_basic " ,
" IndexPutHackedTwin3DFloatAccumulateModule_basic " ,
" IndexPutHackedTwin3DFloatNonAccumulateModule_basic " ,
" IndexPutHackedTwin3DIntAccumulateModule_basic " ,
" IndexPutHackedTwin3DIntNonAccumulateModule_basic " ,
" IndexPutImpl1DFloatAccumulateModule_basic " ,
" IndexPutImpl1DFloatNonAccumulateModule_basic " ,
" IndexPutImpl1DIntAccumulateModule_basic " ,
" IndexPutImpl1DIntNonAccumulateModule_basic " ,
" IndexPutImpl2DFloatAccumulateModule_basic " ,
" IndexPutImpl2DFloatNonAccumulateModule_basic " ,
" IndexPutImpl2DImplicitModule_basic " ,
" IndexPutImpl2DIndexModule_basic " ,
" IndexPutImpl2DNoneIndexStaticModule_basic " ,
" IndexPutImpl3DFloatAccumulateModule_basic " ,
" IndexPutImpl3DFloatNonAccumulateModule_basic " ,
" IndexPutImplIndexWithNoneModule_basic " ,
2024-04-19 17:08:29 +08:00
" IndexSelectRank0IdxModule_basic " ,
2024-04-19 12:29:17 +08:00
" IndexTensorNegativeIndexModule_basic " ,
" IntFloatModule_basic " ,
" IntImplicitModule_basic " ,
" IsFloatingPointFloat_True " ,
" IsFloatingPointInt_False " ,
" LenStrModule_basic " ,
" MaxPool2dCeilModeTrueModule_basic " ,
" MaxPool2dEmptyStrideStaticModule_basic " ,
" MaxPool2dStaticCeilModeTrueModule_basic " ,
" MaxPool2dWithIndicesBackwardDynamic3DModule_basic " ,
" MaxPool2dWithIndicesBackwardDynamic4DModule_basic " ,
" MaxPool2dWithIndicesBackwardStatic3DModule_basic " ,
" MaxPool2dWithIndicesBackwardStatic4DModule_basic " ,
" MaxPool3dCeilModeTrueModule_basic " ,
" MaxPool3dEmptyStrideStaticModule_basic " ,
" MaxPool3dLargeDatadModule_basic " ,
" MaxPool3dModuleRandomSimple_basic " ,
" MaxPool3dModule_basic " ,
" MaxPool3dStaticCeilModeTrueModule_basic " ,
" MaxPool3dStaticModule_basic " ,
" MeanDimNoneDimModule_basic " ,
" MseLossMeanReductionModule_basic " ,
" MseLossSumReductionWithDifferentElemTypeModule_basic " ,
" MulFloatModule_basic " ,
" NativeGroupNormBackwardModule_basic " ,
" NeFloatIntModule_basic " ,
" NeIntModule_basic " ,
" NllLossModuleBackward1DMeanWeight_basic " ,
" NllLossModuleBackward1DMean_basic " ,
" NllLossModuleBackward1DSumWeight_basic " ,
" NllLossModuleBackward1DSum_basic " ,
" NllLossModuleBackward1DWeight_basic " ,
" NllLossModuleBackward1D_basic " ,
" NllLossModuleBackwardMeanWeight_basic " ,
" NllLossModuleBackwardMean_basic " ,
" NllLossModuleBackwardSumWeight_basic " ,
" NllLossModuleBackwardSum_basic " ,
" NllLossModuleBackwardWeight_basic " ,
" NllLossModuleBackward_basic " ,
" NllLossModuleBackward_ignore_index " ,
" NormScalarComplexModule_basic " ,
" NormScalarModule_basic " ,
" NormalFunctionalModule_basic " ,
" NumToTensorFloatModule_basic " ,
" NumToTensorIntModule_basic " ,
" NumelModule_basic " ,
" NumelZeroRankModule_basic " ,
" PadModule_basic " ,
" PadWithNoneValModule_basic " ,
" PixelShuffleModuleFullDynamic_basic " ,
" PixelShuffleModuleSpatiallyDynamic_basic " ,
" PixelShuffleModuleSpatiallyStatic_basic " ,
" PixelShuffleModuleStaticRank3Int64_basic " ,
" PixelShuffleModuleStaticRank4Float32_basic " ,
" PowIntFloatModule_basic " ,
" PrimMaxIntModule_basic " ,
" PrimMinIntDynamicModule_basic " ,
" PrimMinIntModule_basic " ,
" PrimsSqueezeEmptyDimensionsModule_basic " ,
" PrimsSqueezeModule_basic " ,
" PrimsViewOfModule_basic " ,
" PrimsViewOfZeroRankModule_basic " ,
" QuantizedBatchedInputSingleLayer_basic " ,
" QuantizedMLP_basic " ,
" QuantizedNoLayer_basic " ,
" QuantizedSingleLayer_basic " ,
" RandnDtypeDeviceModule_basic " ,
" RandnGeneratorF64Module_basic " ,
" RandnGeneratorModule_basic " ,
" RandnLikeDtypeModule_basic " ,
" RandnLikeModule_basic " ,
" RandnModule_basic " ,
" ReduceAllDimBool_basic " ,
" ReduceAllDimEmpty_basic " ,
" ReduceAllDimFloat_basic " ,
" ReduceAllDimInt_basic " ,
" ReduceMaxAlongDimUnsignedInt_basic " ,
" ReduceMinAlongDimNegative_basic " ,
" ReduceMinAlongDimSignedInt_basic " ,
" ReduceMinAlongDimUnsignedInt_basic " ,
" ReduceMinAlongDim_basic " ,
" ReduceMinKeepDimReturnBoth_basic " ,
" ReduceMinKeepDim_basic " ,
" ReduceProdDimIntFloatModule_basic " ,
" ReflectionPad1dModule2dInput_Right " ,
" ReflectionPad1dModule2dInput_basic " ,
" ReflectionPad1dModule3dInput_Left " ,
" ReflectionPad1dModule3dInput_basic " ,
" ReflectionPad2dModule_Bottom " ,
" ReflectionPad2dModule_Left " ,
" ReflectionPad2dModule_Right " ,
" ReflectionPad2dModule_Top " ,
" ReflectionPad2dModule_basic " ,
" ReplicationPad2dModule_basic " ,
" ReplicationPad2dModule_bottom0 " ,
" ReplicationPad2dModule_left0 " ,
" ReplicationPad2dModule_right0 " ,
" ReplicationPad2dModule_top0 " ,
" RsubInt0d_NumToTensor_Module_basic " ,
" ScalarConstantTupleModule_basic " ,
" ScalarImplicitFloatModule_basic " ,
" ScaledDotProductAttentionDifferentModule_basic " ,
" ScatterReduceFloatMaxModule " ,
" ScatterReduceFloatMaxModuleIncludeSelf " ,
" ScatterReduceFloatMeanModule " ,
" ScatterReduceFloatMeanModuleIncludeSelf " ,
" ScatterReduceFloatMinModule " ,
" ScatterReduceFloatMinModuleIncludeSelf " ,
" ScatterReduceFloatProdModule " ,
" ScatterReduceFloatProdModuleIncludeSelf " ,
" ScatterReduceFloatSumModule " ,
" ScatterReduceFloatSumModuleIncludeSelf " ,
" ScatterReduceIntMaxModule " ,
" ScatterReduceIntMaxModuleIncludeSelf " ,
" ScatterReduceIntMeanModule " ,
" ScatterReduceIntMeanModuleIncludeSelf " ,
" ScatterReduceIntMinModule " ,
" ScatterReduceIntMinModuleIncludeSelf " ,
" ScatterReduceIntProdModule " ,
" ScatterReduceIntProdModuleIncludeSelf " ,
" ScatterReduceIntSumModule " ,
" ScatterReduceIntSumModuleIncludeSelf " ,
" ScatterSrcModule_basic " ,
" ScatterSrcStaticModule_basic " ,
" ScatterValueFloatModule_basic " ,
" ScatterValueIntModule_basic " ,
" SliceOutOfLowerBoundEndIndexModule_basic " ,
" SortIntListReverse_basic " ,
" SortIntList_basic " ,
" SortTensorDescending_basic " ,
" SortTensorInteger_basic " ,
" SortTensorNegativeDimension_basic " ,
" SortTensorSpecificDimension_basic " ,
" SortTensor_basic " ,
" SplitDimDynamicModule_basic " ,
" SplitDimStaticModule_basic " ,
" SqrtIntConstantModule_basic " ,
" SqrtIntModule_basic " ,
" SubFloatModule_basic " ,
" TModuleRank0_basic " ,
" TensorToBoolZeroRank_basic " ,
" TensorToBool_basic " ,
" TensorToFloatZeroRank_basic " ,
" TensorToFloat_basic " ,
" TensorToInt_basic " ,
" TestMultipleTensorAndPrimitiveTypesReturn_basic " ,
" Threshold1dFloatModule_basic " ,
" Threshold1dIntI32Module_basic " ,
" Threshold1dIntModule_basic " ,
" Threshold2dFloatModule_basic " ,
" Threshold2dIntModule_basic " ,
" Threshold3dFloatModule_basic " ,
" Threshold3dIntModule_basic " ,
" ThresholdBackward1dFloatModule_basic " ,
" ThresholdBackward1dIntModule_basic " ,
" ThresholdBackward1dMixedModule_basic " ,
" ThresholdBackward2dFloatModule_basic " ,
" ThresholdBackward2dIntModule_basic " ,
" ThresholdBackward2dMixedModule_basic " ,
" ThresholdBackward3dFloatModule_basic " ,
" ThresholdBackward3dIntModule_basic " ,
" ThresholdBackward3dMixedModule_basic " ,
" TorchPrimLoopForLikeModule_basic " ,
" TorchPrimLoopWhileLikeModule_basic " ,
" TraceModule_basic " ,
" TraceModule_empty " ,
" TraceModule_nonsquare " ,
" TraceSignedIntModule_basic " ,
" TraceUnsignedIntModule_basic " ,
" TraceUnsignedIntModule_empty " ,
" UnbindIntGetItem_Module_basic " ,
" UnbindIntListUnpack_Module_basic " ,
" UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic " ,
" UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic " ,
" UpSampleNearest2dBackwardScalesNone_basic " ,
" UpSampleNearest2dBackward_basic " ,
" VarMeanBiasedModule_basic " ,
" VarMeanCorrectionNoneModule_basic " ,
" VarMeanUnbiasedModule_basic " ,
" ViewCollapseDynamicWithAtenSizeIntModule_basic " ,
" ViewSizeFromOtherTensor_basic " ,
}
FX_IMPORTER_STABLEHLO_CRASHING_SET = {
" BatchNorm1DModule_basic " ,
" BatchNorm2DModule_basic " ,
" BatchNorm3DModule_basic " ,
" ResNet18Module_basic " ,
" ResNet18StaticModule_basic " ,
" MobilenetV3Module_basic " ,
" Conv2dBiasNoPaddingModule_basic " ,
}
2023-02-02 21:29:47 +08:00
STABLEHLO_PASS_SET = {
2024-02-16 01:08:48 +08:00
" AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic " ,
" AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic " ,
" AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic " ,
" AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic " ,
2024-04-11 17:02:59 +08:00
" AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic " ,
2023-07-29 21:55:49 +08:00
" AddIntModule_basic " ,
2023-06-21 01:14:09 +08:00
" AliasModule_basic " ,
2023-05-26 22:57:57 +08:00
" AllBoolFalseModule_basic " ,
" AllBoolTrueModule_basic " ,
" AnyBoolFalseModule_basic " ,
" AnyBoolTrueModule_basic " ,
2022-09-20 22:31:24 +08:00
" ArangeDtypeFloatModule_basic " ,
" ArangeDtypeIntModule_basic " ,
" ArangeFalsePinMemoryModule_basic " ,
" ArangeFloatModule_basic " ,
" ArangeIntModule_basic " ,
" ArangeNegativeStartFloatModule_basic " ,
" ArangeNegativeStartIntModule_basic " ,
" ArangeStartFloatModule_basic " ,
" ArangeStartIntModule_basic " ,
" ArangeStartNegativeStepFloatModule_basic " ,
" ArangeStartNegativeStepIntModule_basic " ,
2024-02-16 01:08:48 +08:00
" ArangeStartOutDtypeModule_basic " ,
" ArangeStartOutModule_basic " ,
2022-09-20 22:31:24 +08:00
" ArangeStartStepFloatModule_basic " ,
" ArangeStartStepIntModule_basic " ,
" ArangeZeroElementOutputModule_basic " ,
2024-02-16 01:08:48 +08:00
" ArgmaxModule_with_dim " ,
" AtenComplex64Module_basic " ,
" AtenFloatScalarModule_basic " ,
" AtenIntBoolOpConstFalseModule_basic " ,
" AtenIntBoolOpConstTrueModule_basic " ,
" AtenIntBoolOpModule_basic " ,
" AtenIntTensorByteDtypeModule_basic " ,
" AtenIntTensorCharDtypeModule_basic " ,
" AtenItemFpOpModule_basic " ,
" AtenItemIntOpModule_basic " ,
" AtenMmFloatTypes_basic " ,
" AtenMmIntTypes_basic " ,
2024-03-12 08:58:20 +08:00
" AtenRoundFloatHalfToEvenModule_basic " ,
" AtenRoundFloatModule_basic " ,
2024-02-16 01:08:48 +08:00
" AtenRoundIntModule_basic " ,
" AtenSubFloatModule_basic " ,
" AtenToDeviceModule_basic " ,
2024-04-09 11:06:53 +08:00
" Aten_CastFloatModule_basic " ,
2024-04-17 21:58:32 +08:00
" Aten_CastLongModule_basic " ,
2024-02-16 01:08:48 +08:00
" AvgPool1dStaticModule_basic " ,
" AvgPool2dStaticModule_basic " ,
" BaddbmmBroadcast1DInputModule_basic " ,
" BaddbmmBroadcast2DInputModule_basic " ,
" BaddbmmStaticModule_basic " ,
" BoolFloatConstantModule_basic " ,
" BoolFloatFalseModule_basic " ,
" BoolFloatTrueModule_basic " ,
" BoolIntConstantModule_basic " ,
" BoolIntFalseModule_basic " ,
" BoolIntTrueModule_basic " ,
" BoolTensorReturnFalseModule_basic " ,
" BoolTensorReturnMixedModule_basic " ,
" BoolTensorReturnTrueModule_basic " ,
" BroadcastListConstructWithMinusOneModule_basic " ,
2022-11-21 21:50:35 +08:00
" BroadcastToSameRankStaticModule_basic " ,
" BroadcastZeroRankInputStaticModule_basic " ,
2024-02-16 01:08:48 +08:00
" CeilFloatModule_basic " ,
" ChunkListUnpackUneven_Module_basic " ,
" ChunkListUnpack_Module_basic " ,
" CloneModule_basic " ,
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decomposition RepeatInterleaveSelfInt with following ops:
```python
def my_repeat_interleave(input, repeats, dim=None):
if dim is None:
# Flatten the input and then repeat
return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
else:
# Calculate the shape after repeat
expanded_shape = list(input.shape)
expanded_shape[dim] *= repeats
# Repeat the tensor along the specified dimension
repeat_shape = [1] * (input.dim() + 1)
repeat_shape[dim + 1] = repeats
input = input.unsqueeze(-1)
# Tile and then reshape
tiled = torch.tile(input, repeat_shape)
# Rearrange and reshape
repeated = tiled.reshape(*expanded_shape)
return repeated
```
I passed the tests of stablehlo and linalg. When testing onnx, strange
things happened.
In torch-mlir's CI **torch_nightly** and my own
environment(torch==2.4.0.dev20240318+cpu), it can **pass the pass**.
In torch-mlir's CI **torch_stable**, it **failed**.
The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2)
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
Unexpected outcome summary: (onnx)
****** Failed tests - 1 tests
FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
@ trace item #0 - call to "forward"
@ output of call to "forward"
ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```
@rsuderman
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
" CollapseAllDimensionsModule_basic " ,
" CollapseStaticModule_basic " ,
2024-02-16 01:08:48 +08:00
" ConstantBoolParameterModule_basic " ,
" ContainsIntList_False " ,
" ContainsIntList_True " ,
" ContiguousModule_basic " ,
" Conv2dWithPaddingDilationStrideStaticModule_basic " ,
" Conv2dWithPaddingDilationStrideStaticModule_depthwise " ,
" Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier " ,
" Conv2dWithPaddingDilationStrideStaticModule_grouped " ,
" Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier " ,
" Convolution2DStaticModule_basic " ,
" ConvolutionBackwardModule2DStatic_basic " ,
" ConvolutionModule2DTransposeStridedStatic_basic " ,
" CosineSimilarityStaticBroadcastModule_basic " ,
" CosineSimilarityStaticModule_basic " ,
" CumsumInputDtypeInt32Module_basic " ,
2023-01-30 13:38:27 +08:00
" CumsumStaticModule_basic " ,
" CumsumStaticNegativeDimModule_basic " ,
2023-04-18 23:59:14 +08:00
" DetachModule_basic " ,
2024-02-16 01:08:48 +08:00
" DivFloatModule_basic " ,
" DivIntModule_basic " ,
" DropoutEvalFloatModule_basic " ,
" DropoutEvalIntModule_basic " ,
" ElementwiseAbsFloatModule_basic " ,
" ElementwiseAbsIntModule_basic " ,
" ElementwiseAddScalar_NumToTensorFloat_Module_basic " ,
" ElementwiseAddScalar_TensorLiteralInt32_Module_basic " ,
" ElementwiseAtenIsinfOpModule_basic " ,
" ElementwiseAtenIsneginfOpModule_basic " ,
" ElementwiseAtenIsposinfOpModule_basic " ,
2023-01-04 10:11:25 +08:00
" ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic " ,
" ElementwiseAtenLogicalNotOpModule_basic " ,
" ElementwiseAtenLogicalNotOpPromoteModule_basic " ,
" ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic " ,
" ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic " ,
2022-11-24 14:28:34 +08:00
" ElementwiseAtenWhereSelfModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseBinaryStaticShapeModule_basic " ,
2023-01-12 06:40:03 +08:00
" ElementwiseBitwiseAndStaticShapeModule_basic " ,
" ElementwiseBitwiseNotInt32Module_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseBitwiseNotInt64Module_basic " ,
2023-01-12 06:40:03 +08:00
" ElementwiseBitwiseOrStaticShapeModule_basic " ,
" ElementwiseBitwiseXorStaticShapeModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseCeilModule_basic " ,
2022-09-16 15:09:21 +08:00
" ElementwiseClampMaxModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseClampMinModule_basic " ,
2024-04-19 10:55:27 +08:00
" ElementwiseClampMinTensorFloatModule_basic " ,
" ElementwiseClampMinTensorIntModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseClampModule_basic " ,
2024-04-19 10:55:27 +08:00
" ElementwiseClampTensorFloatModule_basic " ,
" ElementwiseClampTensorIntModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseClampTensorInt8Module_basic " ,
" ElementwiseCloneChannelsLastMemoryFormatModule_basic " ,
" ElementwiseCloneContiguousModule_basic " ,
" ElementwiseCloneModule_basic " ,
" ElementwiseCosModule_basic " ,
2024-04-16 04:45:10 +08:00
" ElementwiseDivTensorRoundingModeFloorStaticModule_basic " ,
" ElementwiseDivTensorRoundingModeTruncStaticModule_basic " ,
" ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic " ,
" ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic " ,
" ElementwiseDivScalarRoundingModeFloorStaticModule_basic " ,
" ElementwiseDivScalarRoundingModeTruncStaticModule_basic " ,
" ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic " ,
" ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseErfModule_basic " ,
2022-09-08 10:15:36 +08:00
" ElementwiseExpModule_basic " ,
2024-04-22 10:20:49 +08:00
" ElementwiseExpm1IntModule_basic " ,
" ElementwiseExpm1Module_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseFloorIntModule_basic " ,
" ElementwiseFloorModule_basic " ,
2024-04-21 08:39:36 +08:00
" ElementwiseFmodTensor_Float_basic " ,
" ElementwiseFmodTensor_Int_Float_basic " ,
" ElementwiseFmodTensor_Int_basic " ,
2024-02-27 21:56:01 +08:00
" ElementwiseGeluApproximateTanhModule_basic " ,
2024-04-02 07:34:59 +08:00
" ElementwiseGeluModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseLeakyReluStaticModule_basic " ,
2022-09-08 10:15:36 +08:00
" ElementwiseLogModule_basic " ,
2024-04-23 19:06:55 +08:00
" ElementwiseLog10Module_basic " ,
" ElementwiseLog2Module_basic " ,
" ElementwiseLog10IntModule_basic " ,
" ElementwiseLog2IntModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseNanToNumModule_Basic " ,
" ElementwiseNeFloatTensorStaticModule_basic " ,
" ElementwiseNeIntTensorStaticModule_basic " ,
2022-09-08 10:15:36 +08:00
" ElementwiseNegModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseOrTensorStaticShapeModule_basic " ,
2024-04-08 20:24:17 +08:00
" ElementwiseAndScalarStaticShapeModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwisePowTensorBroadcastStaticModule_basic " ,
" ElementwisePowTensorStaticModule_basic " ,
2024-04-02 07:34:59 +08:00
" ElementwisePreluStaticModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseReciprocalModule_basic " ,
" ElementwiseReluModule_basic " ,
2024-04-21 00:03:37 +08:00
" ElementwiseRemainderTensorModule_Float_basic " ,
" ElementwiseRemainderTensorModule_Int_Float_basic " ,
" ElementwiseRemainderTensorModule_Int_basic " ,
2022-12-22 10:13:59 +08:00
" ElementwiseRsqrtModule_basic " ,
" ElementwiseSigmoidModule_basic " ,
2023-02-07 03:14:26 +08:00
" ElementwiseSinModule_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseSqrtModule_basic " ,
2022-09-23 10:24:36 +08:00
" ElementwiseToDtypeF32ToI64Module_basic " ,
2024-02-16 01:08:48 +08:00
" ElementwiseToDtypeI64ToI8Module_basic " ,
" ElementwiseToDtypeIdentityModule_basic " ,
" ElementwiseUnaryModule_basic " ,
2023-07-27 18:35:25 +08:00
" EmptyLikeMemoryFormatModule_basic " ,
" EmptyLikeModule_defaultDtype " ,
" EmptyLikeModule_falsePinMemory " ,
" EmptyLikeModule_float " ,
" EmptyLikeModule_int " ,
2024-02-16 01:08:48 +08:00
" EmptyModule_contiguous " ,
" EmptyModule_defaultDtype " ,
" EmptyModule_falsePinMemory " ,
" EmptyModule_float " ,
" EmptyModule_int " ,
" EmptyStridedModule_basic " ,
" EqIntModule_basic " ,
2022-09-23 10:24:36 +08:00
" ExpandAsIntModule_basic " ,
2024-04-02 07:34:59 +08:00
" FakeQuantizePerTensorAffineModule_basic " ,
" FakeQuantizePerTensorAffineRoundToEvenModule_basic " ,
2024-02-16 01:08:48 +08:00
" Fill_TensorFloat64WithFloat32Static_basic " ,
2023-07-27 18:35:25 +08:00
" Fill_TensorFloat64WithFloat32_basic " ,
" Fill_TensorFloat64WithFloat64_basic " ,
2023-05-12 07:41:46 +08:00
" Fill_TensorFloat64WithInt64Static_basic " ,
2024-02-16 01:08:48 +08:00
" Fill_TensorFloat64WithInt64_basic " ,
" FlattenRank0Module_basic " ,
2023-06-15 10:27:34 +08:00
" FlipModuleStaticShape_basic " ,
" FlipNegativeIndexModule_basic " ,
2022-09-23 10:24:36 +08:00
" FullLikeModuleDefaultDtype_basic " ,
" FullLikeModuleFalsePinMemory_basic " ,
" FullLikeModuleFloat2D_basic " ,
" FullLikeModuleFloat3DStatic_basic " ,
" FullLikeModuleFloat3D_basic " ,
" FullLikeModuleInt2DStatic_basic " ,
" FullLikeModuleInt2D_basic " ,
" FullLikeModuleInt3D_basic " ,
" FullModuleDefaultDtype_basic " ,
" FullModuleFalsePinMemory_basic " ,
" FullModuleFloat2D_basic " ,
" FullModuleFloat3D_basic " ,
" FullModuleInt2D_basic " ,
" FullModuleInt3D_basic " ,
2024-02-16 01:08:48 +08:00
" GeFloatIntModule_basic " ,
" GeFloatModule_basic " ,
" GeIntModule_basic " ,
2022-12-21 20:09:43 +08:00
" GeluBackwardModule_basic " ,
2024-02-16 01:08:48 +08:00
" GluStaticModule_basic " ,
" GtFloatIntModule_basic " ,
" GtIntModule_basic " ,
2023-05-25 02:13:57 +08:00
" IndexTensorMultiIndexStaticModule_basic " ,
2024-02-16 01:08:48 +08:00
" IndexTensorStaticModule_basic " ,
" IntFloatModule_basic " ,
" IsFloatingPointFloat_True " ,
" IsFloatingPointInt_False " ,
2023-01-04 00:30:16 +08:00
" LeakyReluBackwardStaticModule_basic " ,
2024-02-16 01:08:48 +08:00
" LenStrModule_basic " ,
" LiftFreshCopyModule_basic " ,
2024-04-02 07:34:59 +08:00
" LinspaceDtypeModule_basic " ,
" LinspaceEmptyModule_basic " ,
" LinspaceModule_basic " ,
" LinspaceOneSizeModule_basic " ,
" LinspaceTwoSizeModule_basic " ,
2024-02-16 01:08:48 +08:00
" MaskedFillScalarFloatValueStaticModule_basic " ,
" MaskedFillScalarIntValueStaticModule_basic " ,
" Matmul4dStatic_basic " ,
" Matmul_2d " ,
" Matmul_dot " ,
" Matmul_matvec " ,
" Matmul_vecmat " ,
[Stablehlo] Enhance broadcast pattern in matmul Ops (#3161)
To pass test "MatmulStaticBroadcast_basic" in stablehlo:
```python
class MatmulStaticBroadcast(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([4, 1, 6, 7], torch.float32, True),
([8, 1, 5, 7, 6], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.matmul(lhs, rhs)
@register_test_case(module_factory=lambda: MatmulStaticBroadcast())
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
```
2024-04-16 10:10:36 +08:00
" MatmulStaticBroadcast_basic " ,
2024-02-16 01:08:48 +08:00
" MaxPool2dStaticModule_basic " ,
" MeanDimAllReduceModule_basic " ,
2022-11-23 15:02:41 +08:00
" MeanDimEmptyDimModule_basic " ,
2022-09-08 10:15:36 +08:00
" MeanDtypeModule_basic " ,
2022-11-23 15:02:41 +08:00
" MeanDynamicSizesModule_basic " ,
" MeanModule_basic " ,
2024-02-16 01:08:48 +08:00
" Mlp2LayerModuleNoBias_basic " ,
" MmDagModule_basic " ,
" MmModule_basic " ,
" MmModule_chained " ,
2022-09-08 10:15:36 +08:00
" MmTanhModule_basic " ,
2024-02-16 01:08:48 +08:00
" MoveDimIntModule_basic " ,
" MoveDimIntNegativeIndexModule_basic " ,
" MulFloatModule_basic " ,
" MulIntModule_basic " ,
2022-11-21 21:50:35 +08:00
" Mv_basic " ,
2024-02-16 01:08:48 +08:00
" NarrowHorizontalTest2_basic " ,
" NarrowHorizontalTest_basic " ,
" NarrowTensorHorizontalModule_basic " ,
" NarrowTensorVerticalModule_basic " ,
" NarrowVerticalTest2_basic " ,
" NarrowVerticalTest_basic " ,
" NativeDropoutEvalFloatModule_basic " ,
" NeFloatIntModule_basic " ,
" NeIntModule_basic " ,
2023-05-19 10:07:35 +08:00
" NewEmptyModuleDefaultDtype_basic " ,
" NewEmptyModuleFalsePinMemory_basic " ,
" NewEmptyModuleFloat2D_basic " ,
" NewEmptyModuleFloat3D_basic " ,
" NewEmptyModuleInt2D_basic " ,
" NewEmptyModuleInt3D_basic " ,
" NewEmptyModuleLayoutIntDtype_basic " ,
" NewEmptyModuleNonDefaultFloatDtype_basic " ,
" NewEmptyModuleNonDefaultIntDtype_basic " ,
" NewEmptyStridedModuleDefaultDtype_basic " ,
2024-02-16 01:08:48 +08:00
" NewFullModuleDefaultDtype_basic " ,
" NewFullModuleFalsePinMemory_basic " ,
" NewFullModuleFloat3DStatic_basic " ,
" NewFullModuleFloat3D_basic " ,
" NewFullModuleInt2D_basic " ,
" NewFullModuleInt3D_basic " ,
2022-08-23 16:47:21 +08:00
" NewOnesModuleDefaultDtype_basic " ,
2024-02-16 01:08:48 +08:00
" NewOnesModuleFalsePinMemory_basic " ,
2022-08-23 16:47:21 +08:00
" NewOnesModuleFloat2D_basic " ,
" NewOnesModuleFloat3D_basic " ,
2024-02-16 01:08:48 +08:00
" NewOnesModuleInt2D_basic " ,
" NewOnesModuleInt3D_basic " ,
" NewZerosModuleDefaultDtype_basic " ,
" NewZerosModuleFalsePinMemory_basic " ,
" NewZerosModuleFloat2D_basic " ,
" NewZerosModuleFloat3D_basic " ,
" NewZerosModuleInt2D_basic " ,
" NewZerosModuleInt3D_basic " ,
2023-03-30 22:08:20 +08:00
" NewZerosStaticModuleLayoutStrided_basic " ,
2024-02-16 01:08:48 +08:00
" NumToTensorFloatModule_basic " ,
" NumToTensorIntModule_basic " ,
" NumelModule_basic " ,
" NumelZeroRankModule_basic " ,
" NumpyTRank0Module_basic " ,
" NumpyTRank1Module_basic " ,
" NumpyTRank2Module_basic " ,
" NumpyTRankNDynamicModule_basic " ,
" NumpyTRankNStaticModule_basic " ,
" OnesLikeModule_defaultDtype " ,
" OnesLikeModule_falsePinMemory " ,
" OnesLikeModule_float " ,
" OnesLikeModule_int " ,
" OnesModuleCPUDevice_basic " ,
" OnesModuleDefaultDtype_basic " ,
" OnesModuleFalsePinMemory_basic " ,
" OnesModuleFloat_basic " ,
" OnesModuleInt_basic " ,
" Permute0RankModule_basic " ,
" PermuteModule_basic " ,
" PermuteNegativeIndexModule_basic " ,
" PowIntFloatModule_basic " ,
[Torch] Fix PrimListUnpackOp::getCanonicalizationPatterns (#3140)
Fix the case PrimListUnpackOp's result num is not equal to PrimList
length.
See the following example:
```python
def forward(self, x):
if len(x.shape) == 5:
b0, t, c0, h0, w0 = x.shape
b, c, h, w = torch.mul(b0, t), c0, h0, w0
else:
b1, c1, h1, w1 = x.shape
b, c, h, w = b1, c1, h1, w1
res = torch.reshape(x, [b, c, h, w])
return res
```
Without this fix, the following error message will occur:
```
/root/torch-mlir/externals/llvm-project/mlir/lib/IR/PatternMatch.cpp:118: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```
2024-04-11 19:48:49 +08:00
" PrimListUnpackNumMismatchModule_basic " ,
2024-02-16 01:08:48 +08:00
" PrimMaxIntModule_basic " ,
" PrimMinIntDynamicModule_basic " ,
" PrimMinIntModule_basic " ,
" PrimsConvertElementTypeModule_basic " ,
2024-04-22 10:45:01 +08:00
" PrimsIotaModule_basic " ,
2024-02-16 01:08:48 +08:00
" PrimsSqueezeEmptyDimensionsModule_basic " ,
" PrimsViewOfModule_basic " ,
" PrimsViewOfZeroRankModule_basic " ,
" RandIntDtypeModule_basic " ,
" RandIntLowDtypeModule_basic " ,
" RandIntLowModule_basic " ,
" RandIntModule_basic " ,
" RandIntPinMemoryModule_basic " ,
2024-04-02 12:40:00 +08:00
" ReduceAmaxMultiDim_basic " ,
" ReduceAmaxOutOfOrderDim_basic " ,
" ReduceAmaxSingleDim_basic " ,
2024-02-16 01:08:48 +08:00
" ReduceFrobeniusNormModule_basic " ,
2024-04-02 12:40:00 +08:00
" ReduceMaxAllDims_basic " ,
" ReduceMaxAlongDimNegative_basic " ,
" ReduceMaxAlongDimSignedInt_basic " ,
" ReduceMaxAlongDim_basic " ,
" ReduceMaxFloatModule_basic " ,
" ReduceMaxSignedIntModule_basic " ,
" ReduceMaxUnsignedIntModule_basic " ,
" ReduceMinFloatModule_basic " ,
" ReduceMinSignedIntModule_basic " ,
" ReduceMinUnsignedIntModule_basic " ,
" ReduceSumDimIntListDtypeFloatModule_basic " ,
" ReduceSumDimIntListDtypeIntModule_basic " ,
" ReduceSumDimIntListElementTypeBoolModule_basic " ,
" ReduceSumDimIntListEmptyDimModule_basic " ,
" ReduceSumDimIntListFloatModule_basic " ,
" ReduceSumDimIntListIntModule_basic " ,
" ReduceSumDtypeFloatModule_basic " ,
" ReduceSumDtypeIntModule_basic " ,
" ReduceSumElementTypeBoolModule_basic " ,
" ReduceSumFloatModule_basic " ,
" ReduceSumSignedIntModule_basic " ,
2022-08-23 16:47:21 +08:00
" ReduceSumUnsignedIntModule_basic " ,
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decompose RepeatInterleaveSelfInt into the following ops:
```python
def my_repeat_interleave(input, repeats, dim=None):
if dim is None:
# Flatten the input and then repeat
return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
else:
# Calculate the shape after repeat
expanded_shape = list(input.shape)
expanded_shape[dim] *= repeats
# Repeat the tensor along the specified dimension
repeat_shape = [1] * (input.dim() + 1)
repeat_shape[dim + 1] = repeats
input = input.unsqueeze(-1)
# Tile and then reshape
tiled = torch.tile(input, repeat_shape)
# Rearrange and reshape
repeated = tiled.reshape(*expanded_shape)
return repeated
```
The StableHLO and Linalg tests pass. When testing ONNX, however, something
strange happened.
In torch-mlir's CI **torch_nightly** and in my own
environment (torch==2.4.0.dev20240318+cpu), the test **passes**.
In torch-mlir's CI **torch_stable**, it **fails**.
The failing test case is `RepeatInterleaveSelfIntNoDimModule_basic`; the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2)
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
Unexpected outcome summary: (onnx)
****** Failed tests - 1 tests
FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
@ trace item #0 - call to "forward"
@ output of call to "forward"
ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```
@rsuderman
Could you please help me figure out what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
" RepeatInterleaveSelfIntModule_basic " ,
" RepeatInterleaveSelfIntNoDimModule_basic " ,
2024-02-16 01:08:48 +08:00
" ReturnThreeTensorFloat32_basic " ,
" ReturnTwoTensorF32I64_basic " ,
" RollModule_basic " ,
" RsubInt0d_NumToTensor_Module_basic " ,
" ScalarConstantTupleModule_basic " ,
" ScalarImplicitFloatModule_basic " ,
" ScalarImplicitIntModule_basic " ,
" ScalarTensorDefaultDtypeModule_basic " ,
" ScalarTensorFloat32Module_basic " ,
" ScalarTensorInt32Module_basic " ,
" ScalarTensorInt64Module_basic " ,
" SliceModule_basic " ,
" SliceNegIdxModule_basic " ,
" SliceOutOfLowerBoundStartIndexModule_basic " ,
" SliceOutOfUpperBoundIndexModule_basic " ,
" SliceOutOfUpperBoundIndexStaticModule_basic " ,
" SliceScatterModule_basic " ,
" SliceScatterNegativeDimModule_basic " ,
" SliceScatterNegativeEndModule_basic " ,
" SliceScatterStaticModule_basic " ,
" SliceScatterStepVariationModule_basic " ,
" SliceScatterZeroDimModule_basic " ,
" SliceSizeTwoStepModule_basic " ,
" SliceStartEqEndModule_basic " ,
" SliceStaticModule_basic " ,
" SliceWholeTensorModule_basic " ,
" SortIntListReverse_basic " ,
" SortIntList_basic " ,
" SplitTensorGetItem_Module_basic " ,
" SplitTensorLastSmallerModule_basic " ,
" SplitTensorListUnpackModule_basic " ,
" SplitTensorNegativeDimModule_basic " ,
" SplitWithSizesListUnpackModule_basic " ,
" SqrtIntConstantModule_basic " ,
" SqrtIntModule_basic " ,
" SqueezeDimModule_identity " ,
" SqueezeDimModule_unitDim " ,
" SqueezeModule_allUnitDim " ,
" SubFloatModule_basic " ,
" SubIntModule_basic " ,
" TModuleRank0_basic " ,
" TModuleRank1_basic " ,
2022-08-23 16:47:21 +08:00
" TModuleRank2_basic " ,
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL: func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
%none = torch.constant.none
%false = torch.constant.bool false
%float1.000000e01 = torch.constant.float 1.000000e+01
%67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
return %67 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
%none = torch.constant.none
%false = torch.constant.bool false
%int45 = torch.constant.int 45
%67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
return %67 : !torch.vtensor<[],si32>
}
```
2024-04-19 22:17:06 +08:00
" TensorFloatModule_basic " ,
2024-02-16 01:08:48 +08:00
" TensorIntModule_basic " ,
2022-08-23 16:47:21 +08:00
" TensorLiteralModule_basic " ,
" TensorOpaqueLiteralModule_basic " ,
2024-02-16 01:08:48 +08:00
" TensorToBoolZeroRank_basic " ,
" TensorToFloatZeroRank_basic " ,
" TensorToIntZeroRank_basic " ,
" TensorsConcatModule_basic " ,
" TensorsConcatNegativeDimModule_basic " ,
" TensorsConcatNegativeDimStaticModule_basic " ,
" TensorsConcatPromoteDTypeModule_basic " ,
" TensorsConcatStaticModule_basic " ,
" TestF16Return_basic " ,
" TestMultipleTensorAndPrimitiveTypesReturn_basic " ,
" TestMultipleTensorReturn_basic " ,
" ToCopyBoolDTypeStaticModule_basic " ,
2022-11-21 21:50:35 +08:00
" ToDtypeBoolLayoutNoneStaticModule_basic " ,
2024-02-16 01:08:48 +08:00
" ToDtypeLayoutCPUModule_basic " ,
2022-09-23 10:24:36 +08:00
" ToDtypeLayoutNoneModule_basic " ,
" ToDtypeLayoutStridedModule_basic " ,
2024-04-02 07:34:59 +08:00
" TorchPrimLoopForLikeTensorArgModule_basic " ,
2024-02-16 01:08:48 +08:00
" TransposeIntModule_basic " ,
" TransposeIntNegDimsModule_basic " ,
" TupleModule_basic " ,
2023-07-20 09:51:58 +08:00
" TypeAsDifferentModule_basic " ,
2024-02-16 01:08:48 +08:00
" TypeAsSameModule_basic " ,
2022-09-23 10:24:36 +08:00
" TypeConversionF32ToF64Module_basic " ,
" TypeConversionF64ToF32Module_basic " ,
" TypeConversionI1ToF32Module_basic " ,
" TypeConversionI1ToF64Module_basic " ,
" TypeConversionI1ToI32Module_basic " ,
" TypeConversionI1ToI64Module_basic " ,
" TypeConversionI32ToI64Module_basic " ,
" TypeConversionI64ToI32Module_basic " ,
2024-02-16 01:08:48 +08:00
" UnsafeView1DFoldModule_basic " ,
" View1DFoldModule_basic " ,
" ZeroFloat32Module_basic " ,
" ZeroInt32Module_basic " ,
" ZeroInt64Module_basic " ,
" ZerosLikeModule_defaultDtype " ,
" ZerosLikeModule_falsePinMemory " ,
" ZerosLikeModule_float " ,
" ZerosLikeModule_int " ,
" ZerosModuleDefaultDtype_basic " ,
" ZerosModuleFalsePinMemory_basic " ,
" ZerosModuleFloat2D_basic " ,
" ZerosModuleFloat3D_basic " ,
" ZerosModuleInt2D_basic " ,
" ZerosModuleInt3D_basic " ,
2024-04-02 22:47:24 +08:00
" AtenEmbeddingBagStaticModule_basic " ,
" AtenEyeMModuleCPUDevice_basic " ,
" AtenEyeMModuleDefaultDtype_basic " ,
" AtenEyeMModuleFalsePinMemory_basic " ,
" AtenEyeMModuleFloat2D_basic " ,
" AtenEyeMModuleInt2D_basic " ,
" AtenEyeModuleCPUDevice_basic " ,
" AtenEyeModuleDefaultDtype_basic " ,
" AtenEyeModuleFalsePinMemory_basic " ,
" AtenEyeModuleFloat2D_basic " ,
" AtenEyeModuleInt2D_basic " ,
" AtenInstanceNormModule_basic " ,
" AtenLinalgCrossBroadcast_basic " ,
" AtenLinalgCrossCustomDim_basic " ,
" AtenLinalgCrossFloat_basic " ,
" AtenLinalgCrossInt_basic " ,
" AtenLinalgCrossNegativeDim_basic " ,
" BucketizeTensorStaticFloatModule_basic " ,
" BucketizeTensorStaticModule_basic " ,
" DropoutTrainStaticShapeModule_basic " ,
" ElementwiseWhereScalarOtherStaticModule_basic " ,
" ElementwiseWhereScalarSelfStaticModule_basic " ,
" EmbeddingModule1DIndices_basic " ,
" EmbeddingModuleF16_basic " ,
" EmbeddingModuleI32Static_basic " ,
" EmbeddingModuleI32_basic " ,
" EmbeddingModuleI64_basic " ,
" GatherStaticModule_basic " ,
" IndexSelectDynamicIndexSizeModule_basic " ,
" IndexSelectNegativeDimModule_basic " ,
" IndexSelectSingleIdxModule_basic " ,
" IndexSelectTwoIdxModule_basic " ,
" IndexSelectWholeDimensionModule_basic " ,
" IndexSelectWholeTensorModule_basic " ,
" IndexTensorStaticContiguousWithNoneModule_basic " ,
" IndexTensorStaticNonContiguousWithNoneModule_basic " ,
" LayerNormLastDimModule_basic " ,
" LayerNormModule_basic " ,
" LayerNormNormalizeOverAllDimsModule_basic " ,
" MaxPool2dWithIndicesStaticModule_basic " ,
" MeanDimAllReduceKeepdimModule_basic " ,
" NativeDropoutTrainStaticShapeModule_basic " ,
" NativeLayerNormModule4D_basic " ,
" NativeLayerNormModule_basic " ,
" NormalizeModule_basic " ,
" PrimsSqueezeModule_basic " ,
" RandModule_basic " ,
" ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic " ,
" SelectIntNegativeDimAndIndexStaticModule_basic " ,
" SelectScattertStaticModule_basic " ,
" SqueezeDimModule_static " ,
" SqueezeModule_static " ,
" TriuBroadcastModule_basic " ,
" TriuModule_basic " ,
" UnbindIntGetItem_Module_basic " ,
" UnbindIntListUnpack_Module_basic " ,
" UniformStaticShapeModule_basic " ,
2024-04-07 10:48:11 +08:00
" ArangeStartOutViewModule_basic " ,
" ConvolutionBackwardModule2DStrided_basic " ,
" EinsumStaticContractRhsModule_basic " ,
" EinsumStaticFourDimensionModule_basic " ,
" EinsumStaticModule_basic " ,
" EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic " ,
" EinsumStaticWithEllipsisSlicingModule_basic " ,
" FlattenStaticModule_basic " ,
" GroupNormModule_basic " ,
" GroupNormNoWeightAndBiasModule_basic " ,
" NativeGroupNormModule_basic " ,
" RepeatModule_basic " ,
" ReshapeAliasCollapseModule_basic " ,
" ReshapeAliasExpandModule_basic " ,
" ReshapeAsModule_basic " ,
" ReshapeExpandModule_basic " ,
" TileBigDimsSizeModule_basic " ,
" TileSmallDimsSizeModule_basic " ,
" UnflattenIntNegativeOneDimStaticModule_basic " ,
" UnflattenIntNegativeOneSizeStaticModule_basic " ,
" UnflattenIntStaticModule_basic " ,
" UnflattenStaticModule_basic " ,
" UniformNoCorrelationModule_basic " ,
" UnsafeViewCollapseModule_basic " ,
" UnsafeViewDynamicExpandModule_basic " ,
" UnsafeViewExpandModule_basic " ,
" ViewCollapseInferredDimModule_basic " ,
" ViewCollapseModule_basic " ,
" ViewCollapseOnesMiddleModule_basic " ,
" ViewDynamicExpandCollapseModule_basic " ,
" ViewDynamicExpandModule_basic " ,
" ViewExpandCollapseModule_basic " ,
" ViewExpandCollapseWithOnesModule_basic " ,
" ViewExpandDynamicDimModule_basic " ,
" ViewExpandInferredDimModule_basic " ,
" ViewExpandModule_basic " ,
" ViewExpandOnesBeforeAndAfterModule_basic " ,
" ViewExpandOnesMiddleModule_basic " ,
" ViewExpandOnesModule_basic " ,
" ViewNegativeStaticModule_basic " ,
" ViewNoChange1dModule_basic " ,
" ViewNoChange2dModule_basic " ,
" ViewNoChange3dModule_basic " ,
" ViewNoChangeStaticModule_basic " ,
" ViewOffsetBackwardTestStaticModule_basic " ,
" ViewOffsetTestStaticModule_basic " ,
" ViewTwoFiveThreeStaticModule_basic " ,
" ViewTwoToThreeStaticModule_basic " ,
2024-04-07 17:01:58 +08:00
" ElementwiseLog1pModule_basic " ,
2024-04-08 20:05:42 +08:00
" ElementwiseSgnModule_basic " ,
" ElementwiseSignIntModule_basic " ,
2024-04-23 16:24:53 +08:00
" ElementwiseAcosModule_basic " ,
" ElementwiseAsinModule_basic " ,
" ElementwiseAtanTensorFloatModule_basic " ,
2024-04-23 17:57:12 +08:00
" ElementwiseAcosIntModule_basic " ,
" ElementwiseAsinIntModule_basic " ,
" ElementwiseAtanTensorIntModule_basic " ,
" ElementwiseCosIntModule_basic " ,
" ElementwiseExpIntModule_basic " ,
" ElementwiseLogIntModule_basic " ,
" ElementwiseRsqrtIntModule_basic " ,
" ElementwiseSigmoidIntModule_basic " ,
" ElementwiseSinIntModule_basic " ,
" ElementwiseSqrtIntModule_basic " ,
" ElementwiseUnaryIntModule_basic " ,
2022-08-23 16:47:21 +08:00
}
2024-02-16 01:08:48 +08:00
# Tests known to crash (rather than fail cleanly) under the StableHLO backend,
# kept separate from the xfail/pass sets in the same way as LINALG_CRASHING_SET.
# NOTE(review): presumably the e2e runner skips these entirely instead of
# running them as expected failures — confirm against the runner's usage.
STABLEHLO_CRASHING_SET = {
" AtenEmbeddingBagSumExample_basic " ,
2023-07-29 21:55:49 +08:00
}
2021-10-08 10:07:03 +08:00
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
2021-10-26 02:43:21 +08:00
TOSA_PASS_SET = {
2024-04-08 20:05:42 +08:00
" ElementwiseSgnModule_basic " ,
" ElementwiseSignIntModule_basic " ,
2023-12-13 06:22:25 +08:00
" AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic " ,
" AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic " ,
2024-04-11 17:02:59 +08:00
" AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic " ,
2021-11-25 06:01:48 +08:00
" AddCDivModule_basic " ,
2023-12-13 06:22:25 +08:00
" AddCDiv_Module_basic " ,
" AddCMulModule_basic " ,
" AddCMul_Module_basic " ,
" Add_Module_basic " ,
" AliasModule_basic " ,
" ArangeDtypeFloatModule_basic " ,
" ArangeIntModule_basic " ,
" ArangeNegativeStartIntModule_basic " ,
" ArangeStartIntModule_basic " ,
" ArangeStartNegativeStepIntModule_basic " ,
" ArangeStartOutModule_basic " ,
" ArangeStartOutViewModule_basic " ,
" ArangeStartStepIntModule_basic " ,
" ArangeZeroElementOutputModule_basic " ,
2024-02-28 05:40:55 +08:00
" ArangeDtypeIntModule_basic " ,
" ArangeFalsePinMemoryModule_basic " ,
" ArangeFloatModule_basic " ,
" ArangeNegativeStartFloatModule_basic " ,
" ArangeStartFloatModule_basic " ,
" ArangeStartNegativeStepFloatModule_basic " ,
" ArangeStartOutDtypeModule_basic " ,
" ArangeStartStepFloatModule_basic " ,
2024-03-20 06:19:29 +08:00
" ArgmaxIntModule_basic " ,
" ArgmaxIntModule_multiple_maxs " ,
" ArgmaxModule_basic " ,
2023-12-13 06:22:25 +08:00
" ArgmaxModule_keepDim " ,
" AtenComplex64Module_basic " ,
" AtenEyeMModuleCPUDevice_basic " ,
" AtenEyeMModuleDefaultDtype_basic " ,
" AtenEyeMModuleFalsePinMemory_basic " ,
" AtenEyeMModuleFloat2D_basic " ,
" AtenEyeModuleCPUDevice_basic " ,
" AtenEyeModuleDefaultDtype_basic " ,
" AtenEyeModuleFalsePinMemory_basic " ,
" AtenEyeModuleFloat2D_basic " ,
" AtenRoundIntModule_basic " ,
2024-02-19 22:23:48 +08:00
" AtenInstanceNormModule_basic " ,
2022-08-11 07:24:02 +08:00
" AtenToDeviceModule_basic " ,
2024-04-09 11:06:53 +08:00
" Aten_CastFloatModule_basic " ,
2023-12-13 06:22:25 +08:00
" BaddbmmBroadcast1DInputModule_basic " ,
" BaddbmmBroadcast2DInputModule_basic " ,
" BaddbmmDynamicModule_basic " ,
" BaddbmmStaticModule_basic " ,
" BaddbmmWithAlphaBetaModule_basic " ,
" BaddbmmWithAlphaModule_basic " ,
" BaddbmmWithBetaModule_basic " ,
" BatchNorm1DModule_basic " ,
" BatchNorm1DStaticShapeModule_basic " ,
" BatchNorm1DWith2DInputModule_basic " ,
" BatchNorm2DModule_basic " ,
" BatchNorm3DModule_basic " ,
2023-09-11 20:58:59 +08:00
" BmmFloatModule_basic " ,
2023-12-13 06:22:25 +08:00
" BoolTensorHandleSignless_basic " ,
" BoolTensorReturnFalseModule_basic " ,
" BoolTensorReturnMixedModule_basic " ,
" BoolTensorReturnTrueModule_basic " ,
" BroadcastListConstructWithMinusOneModule_basic " ,
" BroadcastToSameRankStaticModule_basic " ,
" BroadcastZeroRankInputStaticModule_basic " ,
" BucketizeTensorStaticFloatModule_basic " ,
" BucketizeTensorStaticModule_basic " ,
2024-01-31 09:43:21 +08:00
" CloneModule_basic " ,
2023-12-13 06:22:25 +08:00
" ChunkListUnpackUneven_Module_basic " ,
" ChunkListUnpack_Module_basic " ,
" ConstantBoolParameterModule_basic " ,
" ConstantPad2dStaticModule_basic " ,
" ConstantPadNdModule_basic " ,
" ConstantPadNdPartialStaticModule_basic " ,
" ConstantPadNdStaticModule_basic " ,
" ContiguousModule_basic " ,
" Conv2dBiasNoPaddingModule_basic " ,
" Conv2dNoPaddingModule_basic " ,
" Conv2dWithPaddingDilationStrideModule_basic " ,
" Conv2dWithPaddingDilationStrideStaticModule_basic " ,
" Conv2dWithPaddingDilationStrideStaticModule_depthwise " ,
" Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier " ,
" Conv2dWithPaddingModule_basic " ,
" Convolution2DStaticModule_basic " ,
2024-01-19 04:32:23 +08:00
" CosineSimilarityStaticModule_basic " ,
2023-12-13 06:22:25 +08:00
" DetachModule_basic " ,
" DropoutEvalFloatModule_basic " ,
" DropoutEvalIntModule_basic " ,
" DropoutModule_basic " ,
2023-12-10 12:30:37 +08:00
" EinsumStaticContractRhsModule_basic " ,
2023-12-13 06:22:25 +08:00
" EinsumStaticFourDimensionModule_basic " ,
" EinsumStaticModule_basic " ,
2024-03-28 03:42:10 +08:00
" EinsumStaticWithEllipsisSlicingModule_basic " ,
" EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic " ,
2024-02-09 06:53:40 +08:00
" ElementwiseAbsFloatModule_basic " ,
" ElementwiseAbsIntModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseAddModule_basic " ,
" ElementwiseAddScalarFloatModule_basic " ,
" ElementwiseAddScalarInt64Module_basic " ,
" ElementwiseAddScalarInt8Module_basic " ,
" ElementwiseAddScalarIntModule_basic " ,
" ElementwiseAddScalar_TensorLiteralInt32_Module_basic " ,
" ElementwiseAtenDivIntScalarModule_basic " ,
2023-12-29 09:20:32 +08:00
" ElementwiseAtenIsinfOpModule_basic " ,
2024-01-16 14:29:34 +08:00
" ElementwiseAtenIsneginfOpModule_basic " ,
" ElementwiseAtenIsposinfOpModule_basic " ,
2024-01-12 02:36:48 +08:00
" ElementwiseAtenLogicalOrOpBrodcastModule_basic " ,
" ElementwiseAtenLogicalOrOpDiffArgs1Module_basic " ,
" ElementwiseAtenLogicalOrOpDiffArgs2Module_basic " ,
" ElementwiseAtenLogicalOrOpDiffArgs3Module_basic " ,
" ElementwiseAtenLogicalOrOpModule_basic " ,
" ElementwiseAtenLogicalOrOpNegativeModule_basic " ,
" ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic " ,
" ElementwiseAtenLogicalOrOpRandomFloatModule_basic " ,
" ElementwiseAtenLogicalOrOpRandomModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseAtenWhereSelfModule_basic " ,
" ElementwiseBinaryModule_basic " ,
" ElementwiseBinaryStaticShapeModule_basic " ,
2023-01-12 06:40:03 +08:00
" ElementwiseBitwiseAndModule_basic " ,
" ElementwiseBitwiseAndStaticShapeModule_basic " ,
" ElementwiseBitwiseNotInt32Module_basic " ,
" ElementwiseBitwiseNotInt64Module_basic " ,
" ElementwiseBitwiseOrModule_basic " ,
" ElementwiseBitwiseOrStaticShapeModule_basic " ,
" ElementwiseBitwiseXorModule_basic " ,
" ElementwiseBitwiseXorStaticShapeModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseCeilModule_basic " ,
2024-01-12 02:36:48 +08:00
" ElementwiseClampMaxModule_basic " ,
" ElementwiseClampMinModule_basic " ,
" ElementwiseClampModule_basic " ,
2024-02-29 06:13:26 +08:00
" ElementwiseClampTensorInt8Module_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseCloneChannelsLastMemoryFormatModule_basic " ,
" ElementwiseCloneContiguousModule_basic " ,
" ElementwiseCloneModule_basic " ,
" ElementwiseDivScalarModule_basic " ,
2024-02-27 13:32:05 +08:00
" ElementwiseDivTensorIntegerModule_basic " ,
" ElementwiseDivTensorUnsignedIntegerModule_basic " ,
2024-04-16 04:45:10 +08:00
" ElementwiseDivScalarIntegerModule_basic " ,
" ElementwiseDivScalarUnsignedIntegerModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseEluModule_basic " ,
" ElementwiseEluNonDefaultModule_basic " ,
" ElementwiseEqBoolScalarModule_basic " ,
" ElementwiseEqDiffWidthScalarModule_basic " ,
" ElementwiseEqFloatScalarModule_basic " ,
" ElementwiseEqFloatTensorModule_basic " ,
" ElementwiseEqIntScalarModule_basic " ,
" ElementwiseEqIntTensorModule_basic " ,
" ElementwiseExpModule_basic " ,
" ElementwiseFlattenBroadcastModule_basic " ,
" ElementwiseFloorIntModule_basic " ,
" ElementwiseFloorModule_basic " ,
2023-06-11 02:45:35 +08:00
" ElementwiseGeFloatIntScalarModule_basic " ,
" ElementwiseGeFloatScalarModule_basic " ,
" ElementwiseGeIntScalarModule_basic " ,
" ElementwiseGeMixedIntScalarModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseGeluModule_basic " ,
2022-01-21 02:58:30 +08:00
" ElementwiseGtFloatScalarModule_basic " ,
" ElementwiseGtFloatTensorModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseGtIntScalarModule_basic " ,
2022-01-21 02:58:30 +08:00
" ElementwiseGtIntTensorModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseGtMixed2ScalarModule_basic " ,
" ElementwiseIsinfModule_basic " ,
2024-01-16 14:29:34 +08:00
" ElementwiseAtenIsneginfOpModule_basic " ,
" ElementwiseAtenIsposinfOpModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseIsnanModule_basic " ,
" ElementwiseLeFloatTensorModule_basic " ,
" ElementwiseLeIntTensorModule_basic " ,
" ElementwiseLeakyReluModule_basic " ,
" ElementwiseLeakyReluModule_basic " ,
" ElementwiseLeakyReluStaticModule_basic " ,
2024-02-01 01:39:38 +08:00
" ElementwiseLerpScalarIntModule_basic " ,
" ElementwiseLerpScalarFloatModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseLog2Module_basic " ,
" ElementwiseLogModule_basic " ,
2022-01-21 02:58:30 +08:00
" ElementwiseLtDiffWidthScalarModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseLtFloatScalarModule_basic " ,
2022-01-21 02:58:30 +08:00
" ElementwiseLtFloatTensorModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseLtIntScalarModule_basic " ,
2022-01-21 02:58:30 +08:00
" ElementwiseLtIntTensorModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseMaxOtherIntModule_basic " ,
" ElementwiseMaxOtherModule_basic " ,
" ElementwiseMaximumIntModule_basic " ,
" ElementwiseMaximumModule_basic " ,
" ElementwiseMinOtherIntModule_basic " ,
" ElementwiseMinOtherModule_basic " ,
" ElementwiseMinimumIntModule_basic " ,
" ElementwiseMinimumModule_basic " ,
" ElementwiseMulScalarModule_basic " ,
" ElementwiseMulScalarModule_float " ,
" ElementwiseMulScalarModule_float " ,
" ElementwiseMulScalarModule_int " ,
" ElementwiseMulTensorIntModule_basic " ,
2023-06-07 10:06:27 +08:00
" ElementwiseNeFloatScalarModule_basic " ,
" ElementwiseNeFloatTensorModule_basic " ,
" ElementwiseNeFloatTensorStaticModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseNeIntScalarModule_basic " ,
2023-06-07 10:06:27 +08:00
" ElementwiseNeIntTensorModule_basic " ,
" ElementwiseNeIntTensorStaticModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseNegModule_basic " ,
" ElementwiseOrTensorModule_basic " ,
" ElementwiseOrTensorStaticShapeModule_basic " ,
" ElementwisePowModule_basic " ,
2024-03-29 08:05:00 +08:00
" ElementwisePreluModule_basic " ,
" ElementwisePreluStaticModule_basic " ,
2022-01-21 02:58:30 +08:00
" ElementwiseReciprocalModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseRelu6Module_basic " ,
" ElementwiseReluModule_basic " ,
" ElementwiseRemainderScalarModule_Float_basic " ,
" ElementwiseRemainderScalarModule_Int_Float_basic " ,
" ElementwiseRemainderScalarModule_Int_basic " ,
" ElementwiseRemainderScalarModule_Int_basic " ,
" ElementwiseRsqrtModule_basic " ,
2023-12-14 12:28:08 +08:00
" ElementwiseSeluModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseSigmoidModule_basic " ,
" ElementwiseSignModule_basic " ,
" ElementwiseSqrtIntModule_basic " ,
" ElementwiseSqrtModule_basic " ,
" ElementwiseSubScalarFloatModule_basic " ,
2022-02-12 04:30:02 +08:00
" ElementwiseSubScalarIntModule_basic " ,
2023-12-13 06:22:25 +08:00
" ElementwiseSubTensorInt8Module_basic " ,
" ElementwiseToDtypeIdentityModule_basic " ,
" ElementwiseUnaryModule_basic " ,
2022-11-15 01:09:15 +08:00
" ElementwiseUnsqueezeBroadcastModule_basic " ,
2023-04-19 04:36:57 +08:00
" ElementwiseWhereScalarModule_basic " ,
2024-01-16 14:29:34 +08:00
" ElementwiseNanToNumModule_Basic " ,
2023-12-13 06:22:25 +08:00
" EmbeddingModule1DIndices_basic " ,
" EmbeddingModuleI32Static_basic " ,
" FlattenRank0Module_basic " ,
" FlattenStaticModule_basic " ,
2024-03-20 06:19:29 +08:00
" FlattenDynamicModuleCollapseAll_basic " ,
2023-04-19 04:36:57 +08:00
" FullLikeModuleFloat3DStatic_basic " ,
2023-12-13 06:22:25 +08:00
" FullLikeModuleInt2DStatic_basic " ,
2023-04-19 04:36:57 +08:00
" FullModuleDefaultDtype_basic " ,
2023-12-13 06:22:25 +08:00
" FullModuleFloat2D_basic " ,
2023-04-19 04:36:57 +08:00
" FullModuleFloat3D_basic " ,
2023-12-13 06:22:25 +08:00
" FullModuleInt3D_basic " ,
" GatherStaticModule_basic " ,
" GeluBackwardModule_basic " ,
" GluStaticModule_basic " ,
" HardTanhIntModule_basic " ,
" HardTanhModule_basic " ,
" HardsigmoidModule_basic " ,
" HardsigmoidRandomModule_basic " ,
" HardswishModule_basic " ,
" HardswishRandomModule_basic " ,
" HardtanhBackward_basic " ,
" IndexPutImpl2DNoneIndexStaticModule_basic " ,
" IndexTensorMultiIndexStaticModule_basic " ,
" IndexTensorStaticModule_basic " ,
" IscloseStaticModuleTrue_basic " ,
" IscloseStaticModule_basic " ,
" LayerNormNormalizeOverAllDimsModule_basic " ,
" LeakyReluBackwardModule_basic " ,
" LeakyReluBackwardStaticModule_basic " ,
" LiftFreshCopyModule_basic " ,
2024-01-19 04:32:23 +08:00
" LinalgVectorNormKeepDimModule_basic " ,
" LinalgVectorNormModule_basic " ,
2024-03-06 08:31:01 +08:00
" LinalgNormKeepDimModule_basic " ,
2023-12-13 06:22:25 +08:00
" MaskedFillScalarDefaultModule_basic " ,
" MaskedFillScalarIntValueModule_basic " ,
" MaskedFillScalarIntValueStaticModule_basic " ,
" MaskedFillTensorIntValueStaticModule_basic " ,
" Matmul4dStatic_basic " ,
" Matmul_3d " ,
" Matmul_dot " ,
2024-02-29 01:46:58 +08:00
" MatmulStaticBroadcast_basic " ,
2023-12-13 06:22:25 +08:00
" MaxPool2dEmptyStrideStaticModule_basic " ,
" MaxPool2dStaticCeilModeTrueModule_basic " ,
" MaxPool2dStaticModule_basic " ,
" MeanModule_basic " ,
" MmDagModule_basic " ,
" MoveDimIntModule_basic " ,
" MoveDimIntModule_basic " ,
" MoveDimIntNegativeIndexModule_basic " ,
" MseLossNoReductionModule_basic " ,
" NativeLayerNormModule4D_basic " ,
2023-09-12 22:29:08 +08:00
" NewFullModuleDefaultDtype_basic " ,
" NewFullModuleFalsePinMemory_basic " ,
" NewFullModuleFloat2D_basic " ,
" NewFullModuleFloat3DStatic_basic " ,
" NewFullModuleFloat3D_basic " ,
" NewFullModuleInt2DStatic_basic " ,
2023-12-13 06:22:25 +08:00
" NewOnesModuleDefaultDtype_basic " ,
" NewOnesModuleFalsePinMemory_basic " ,
" NewOnesModuleFloat2D_basic " ,
" NewOnesModuleFloat3D_basic " ,
" NewOnesModuleInt2D_basic " ,
" NewOnesModuleInt3D_basic " ,
" NewZerosModuleDefaultDtype_basic " ,
" NewZerosModuleFalsePinMemory_basic " ,
" NewZerosModuleFloat2D_basic " ,
" NewZerosModuleFloat3D_basic " ,
" NewZerosModuleInt2D_basic " ,
" NewZerosModuleInt3D_basic " ,
" NewZerosStaticModuleLayoutStrided_basic " ,
2024-01-19 04:32:23 +08:00
" NormalizeModule_basic " ,
" NormScalarOptDimKeepDimModule_basic " ,
" NormScalarOptDimModule_basic " ,
2023-04-19 04:36:57 +08:00
" NumToTensorFloatModule_basic " ,
2023-12-13 06:22:25 +08:00
" NumToTensorIntModule_basic " ,
" NumpyTRank0Module_basic " ,
" NumpyTRank1Module_basic " ,
" NumpyTRank2Module_basic " ,
" NumpyTRankNDynamicModule_basic " ,
" NumpyTRankNStaticModule_basic " ,
" OnesModuleCPUDevice_basic " ,
" OnesModuleDefaultDtype_basic " ,
" OnesModuleFalsePinMemory_basic " ,
" OnesModuleFloat_basic " ,
" OnesModuleInt_basic " ,
" PadModule_basic " ,
" PadWithNoneValModule_basic " ,
" Permute0RankModule_basic " ,
" PermuteModule_basic " ,
" PermuteNegativeIndexModule_basic " ,
[Torch] Fix PrimListUnpackOp::getCanonicalizationPatterns (#3140)
Fix the case where the number of PrimListUnpackOp results is not equal
to the PrimList length.
See the following example:
```python
def forward(self, x):
if len(x.shape) == 5:
b0, t, c0, h0, w0 = x.shape
b, c, h, w = torch.mul(b0, t), c0, h0, w0
else:
b1, c1, h1, w1 = x.shape
b, c, h, w = b1, c1, h1, w1
res = torch.reshape(x, [b, c, h, w])
return res
```
Without this fix, the following error message will occur:
```
/root/torch-mlir/externals/llvm-project/mlir/lib/IR/PatternMatch.cpp:118: virtual void mlir::RewriterBase::replaceOp(mlir::Operation *, mlir::ValueRange): Assertion `op->getNumResults() == newValues.size() && "incorrect # of replacement values"' failed.
```
2024-04-11 19:48:49 +08:00
" PrimListUnpackNumMismatchModule_basic " ,
2024-04-22 10:45:01 +08:00
" PrimsIotaModule_basic " ,
2023-12-13 06:22:25 +08:00
" PrimsSqueezeEmptyDimensionsModule_basic " ,
" PrimsSqueezeModule_basic " ,
" PrimsViewOfModule_basic " ,
" PrimsViewOfZeroRankModule_basic " ,
2022-10-18 12:22:53 +08:00
" ReduceSumDimIntListFloatModule_basic " ,
" ReduceSumDimIntListIntModule_basic " ,
" ReduceSumDimIntListKeepDimFloatModule_basic " ,
" ReduceSumDimIntListKeepDimIntModule_basic " ,
2023-12-13 06:22:25 +08:00
" ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic " ,
2022-10-18 12:22:53 +08:00
" ReduceSumFloatModule_basic " ,
" ReduceSumSignedIntModule_basic " ,
" ReduceSumUnsignedIntModule_basic " ,
2023-03-15 23:42:15 +08:00
" RepeatModule_basic " ,
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decompose RepeatInterleaveSelfInt into the following ops:
```python
def my_repeat_interleave(input, repeats, dim=None):
if dim is None:
# Flatten the input and then repeat
return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
else:
# Calculate the shape after repeat
expanded_shape = list(input.shape)
expanded_shape[dim] *= repeats
# Repeat the tensor along the specified dimension
repeat_shape = [1] * (input.dim() + 1)
repeat_shape[dim + 1] = repeats
input = input.unsqueeze(-1)
# Tile and then reshape
tiled = torch.tile(input, repeat_shape)
# Rearrange and reshape
repeated = tiled.reshape(*expanded_shape)
return repeated
```
The StableHLO and Linalg tests pass. When testing ONNX, however, something
strange happened.
In torch-mlir's CI **torch_nightly** and in my own
environment (torch==2.4.0.dev20240318+cpu), the test **passes**.
In torch-mlir's CI **torch_stable**, it **fails**.
The failing test case is `RepeatInterleaveSelfIntNoDimModule_basic`; the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2)
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
Unexpected outcome summary: (onnx)
****** Failed tests - 1 tests
FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
@ trace item #0 - call to "forward"
@ output of call to "forward"
ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```
@rsuderman
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
" RepeatInterleaveSelfIntNoDimModule_basic " ,
2023-12-13 06:22:25 +08:00
" ResNet18StaticModule_basic " ,
2024-04-02 07:34:59 +08:00
" ReshapeAliasCollapseModule_basic " ,
" ReshapeAliasExpandModule_basic " ,
2023-12-13 06:22:25 +08:00
" ReshapeAsModule_basic " ,
" ReshapeCollapseModule_basic " ,
2024-04-02 07:34:59 +08:00
" ReshapeExpandModule_basic " ,
2023-12-13 06:22:25 +08:00
" ReturnThreeTensorFloat32_basic " ,
" ReturnTwoTensorF32I64_basic " ,
" RsubFloatModule_basic " ,
" RsubFloatModule_noalpha_basic " ,
" RsubInt0d_NumToTensor_Module_basic " ,
2023-06-01 11:38:50 +08:00
" ScalarTensorDefaultDtypeModule_basic " ,
" ScalarTensorFloat32Module_basic " ,
" ScalarTensorInt32Module_basic " ,
" ScalarTensorInt64Module_basic " ,
2023-12-13 06:22:25 +08:00
" SelectIntNegativeDimAndIndexStaticModule_basic " ,
" SiluModule_basic " ,
" SliceOutOfUpperBoundIndexStaticModule_basic " ,
" SliceStaticModule_basic " ,
2023-05-24 03:43:33 +08:00
" SplitTensorGetItem_Module_basic " ,
2023-12-13 06:22:25 +08:00
" SplitTensorLastSmallerModule_basic " ,
2023-06-07 01:38:04 +08:00
" SplitTensorListUnpackModule_basic " ,
2023-07-20 15:53:54 +08:00
" SplitTensorNegativeDimModule_basic " ,
2023-09-04 09:59:26 +08:00
" SplitWithSizesListUnpackModule_basic " ,
2023-12-13 06:22:25 +08:00
" SquareModule_basic " ,
" SqueezeDimModule_identity " ,
" SqueezeDimModule_static " ,
" SqueezeDimModule_unitDim " ,
" SqueezeModule_allUnitDim " ,
" SqueezeModule_broadcast " ,
" SqueezeModule_noUnitDim " ,
" SqueezeModule_static " ,
" TModuleRank0_basic " ,
" TModuleRank1_basic " ,
" TModuleRank2_basic " ,
" TanhBackward_basic " ,
[Torch] Add folder for AtenIntOp, AtenFloatOp (#3189)
See unit test below:
```
// CHECK-LABEL: func.func @torch.aten.tensor.float(
// CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor<f32>) : !torch.vtensor<[],f32>
func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> {
%none = torch.constant.none
%false = torch.constant.bool false
%float1.000000e01 = torch.constant.float 1.000000e+01
%67 = torch.aten.tensor.float %float1.000000e01, %none, %none, %false : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],f32>
return %67 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func.func @torch.aten.tensor.int(
// CHECK-NEXT: torch.vtensor.literal(dense<45> : tensor<si32>) : !torch.vtensor<[],si32>
func.func @torch.aten.tensor.int() -> !torch.vtensor<[],si32> {
%none = torch.constant.none
%false = torch.constant.bool false
%int45 = torch.constant.int 45
%67 = torch.aten.tensor.int %int45, %none, %none, %false : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si32>
return %67 : !torch.vtensor<[],si32>
}
```
2024-04-19 22:17:06 +08:00
" TensorFloatModule_basic " ,
" TensorIntModule_basic " ,
2023-12-13 06:22:25 +08:00
" TensorLiteralModule_basic " ,
" TensorOpaqueLiteralModule_basic " ,
" TensorsConcatNegativeDimStaticModule_basic " ,
" TensorsConcatStaticModule_basic " ,
" TestF16Return_basic " ,
" TestMultipleTensorReturn_basic " ,
" Threshold1dFloatModule_basic " ,
" Threshold1dIntI32Module_basic " ,
" Threshold2dFloatModule_basic " ,
" Threshold3dFloatModule_basic " ,
" TileBigDimsSizeModule_basic " ,
" TileSmallDimsSizeModule_basic " ,
" ToCopyBoolDTypeStaticModule_basic " ,
" ToDtypeBoolLayoutNoneStaticModule_basic " ,
" TransposeIntModule_basic " ,
" TransposeIntNegDimsModule_basic " ,
2023-07-18 22:32:26 +08:00
" TupleModule_basic " ,
2023-12-13 06:22:25 +08:00
" TypeAsSameModule_basic " ,
" TypePromotionAlphaWiderModule_basic " ,
" TypePromotionDifferentCategoryModule_basic " ,
" TypePromotionSameCategoryDifferentWidthModule_basic " ,
" TypePromotionSameCategoryZeroRankWider_basic " ,
" TypePromotionZeroRankHigherCategoryModule_basic " ,
" UnbindIntGetItem_Module_basic " ,
" UnbindIntListUnpack_Module_basic " ,
" UnflattenIntNegativeOneDimStaticModule_basic " ,
" UnflattenIntNegativeOneSizeStaticModule_basic " ,
" UnflattenIntStaticModule_basic " ,
" UnflattenStaticModule_basic " ,
" UnsafeView1DFoldModule_basic " ,
2024-04-02 07:34:59 +08:00
" UnsafeViewCollapseModule_basic " ,
" UnsafeViewDynamicExpandModule_basic " ,
2023-12-13 06:22:25 +08:00
" UnsafeViewExpandModule_basic " ,
" View1DFoldModule_basic " ,
2024-04-02 07:34:59 +08:00
" ViewCollapseModule_basic " ,
2023-12-13 06:22:25 +08:00
" ViewCollapseInferredDimModule_basic " ,
" ViewCollapseOnesMiddleModule_basic " ,
" ViewDoubleMergeStaticModule_basic " ,
2024-04-02 07:34:59 +08:00
" ViewDynamicExpandCollapseModule_basic " ,
" ViewDynamicExpandModule_basic " ,
2023-12-13 06:22:25 +08:00
" ViewExpandCollapseModule_basic " ,
" ViewExpandCollapseWithOnesModule_basic " ,
2024-04-02 07:34:59 +08:00
" ViewExpandDynamicDimModule_basic " ,
2023-12-13 06:22:25 +08:00
" ViewExpandInferredDimModule_basic " ,
" ViewExpandModule_basic " ,
" ViewExpandOnesBeforeAndAfterModule_basic " ,
" ViewExpandOnesMiddleModule_basic " ,
" ViewExpandOnesMiddleOppModule_basic " ,
" ViewExpandOnesModule_basic " ,
" ViewFiveTestStaticModule_basic " ,
" ViewNegativeStaticModule_basic " ,
2024-04-02 07:34:59 +08:00
" ViewNoChange1dModule_basic " ,
" ViewNoChange2dModule_basic " ,
" ViewNoChange3dModule_basic " ,
2023-12-13 06:22:25 +08:00
" ViewNoChangeStaticModule_basic " ,
" ViewOffsetBackwardTestStaticModule_basic " ,
" ViewOffsetTestStaticModule_basic " ,
" ViewTwoFiveThreeStaticModule_basic " ,
" ViewTwoToThreeStaticModule_basic " ,
" ZerosModuleDefaultDtype_basic " ,
" ZerosModuleFalsePinMemory_basic " ,
" ZerosModuleFloat2D_basic " ,
" ZerosModuleFloat3D_basic " ,
" ZerosModuleInt2D_basic " ,
" ZerosModuleInt3D_basic " ,
" _LogSoftmaxModuleStable_basic " ,
2024-03-14 08:28:33 +08:00
" LinspaceModule_basic " ,
" LinspaceOneSizeModule_basic " ,
" LinspaceTwoSizeModule_basic " ,
2024-03-21 02:04:02 +08:00
" TorchPrimLoopForLikeTensorArgModule_basic "
2021-10-26 02:43:21 +08:00
}
2022-06-10 03:56:01 +08:00
2023-07-13 21:07:54 +08:00
# MAKE_FX_TOSA_PASS_SET is built from TOSA_PASS_SET: it adds the tests that
# additionally pass in make_fx_tosa, then removes the tests that fail in
# make_fx_tosa even though they pass in plain tosa.
# Each test name appears once per literal; duplicate set items are no-ops and
# only obscure diffs (flake8-bugbear B033).
MAKE_FX_TOSA_PASS_SET = ( TOSA_PASS_SET | {
    ### Tests additionally passing in make_fx_tosa
    " AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic " ,
    " AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic " ,
    " AdaptiveAvgPool1dStaticEvenMultiple_basic " ,
    " CosineSimilarityModule_basic " ,
    " NativeGroupNormBackwardModule_basic " ,
    " ReduceFrobeniusNormKeepDimModule_basic " ,
    " ReduceFrobeniusNormModule_basic " ,
    " SliceWholeTensorModule_basic " ,
    " TensorFloatModule_basic " ,
    " TensorIntModule_basic " ,
    " RepeatInterleaveSelfIntModule_basic " ,
    " TorchPrimLoopForLikeTensorArgModule_basic " ,
    " ViewSizeDimFollowedByCollapsedOnesModule_basic " ,
    " ViewSizeDimFollowedByExpandedOnesModule_basic " ,
    " ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic " ,
    " ViewSizeDimLedByCollapsedOnesModule_basic " ,
    " ViewSizeFromOtherTensor_basic " ,
} ) - {
    ### Tests failing in make_fx_tosa but not in tosa
    # Dynamic shape, has extra unsupported broadcast ops
    " Matmul_3d " ,
    " MatmulStaticBroadcast_basic " ,
    # failed to legalize operation 'torch.aten.max_pool2d_with_indices'
    " MaxPool2dEmptyStrideStaticModule_basic " ,
    " MaxPool2dStaticCeilModeTrueModule_basic " ,
    " MaxPool2dStaticModule_basic " ,
    " ResNet18StaticModule_basic " ,
    # Unimplemented operator 'aten._index_put_impl_.hacked_twin'
    " IndexPutImpl1DFloatNonAccumulateModule_basic " ,
    " IndexPutImpl1DIntNonAccumulateModule_basic " ,
    # RuntimeError: The size of tensor a (7) must match the size of tensor b (3) at non-singleton dimension 1
    " Add_Module_basic " ,
    # failed to legalize operation 'torch.aten.to.dtype' that was explicitly marked illegal
    " AtenEyeModuleInt2D_basic " ,
    " AtenEyeMModuleInt2D_basic " ,
    # NOTE(review): no failure reason is recorded for the following entries;
    # document the cause here once it is known.
    " Conv2dBiasNoPaddingModule_basic " ,
    " Conv2dNoPaddingModule_basic " ,
    " Conv2dWithPaddingDilationStrideModule_basic " ,
    " Conv2dWithPaddingModule_basic " ,
    " AtenInstanceNormModule_basic " ,
    # failed to legalize operation 'torch.operator'
    " ElementwisePreluModule_basic " ,
    " ElementwisePreluStaticModule_basic " ,
    # Shape Related failures
    " PrimListUnpackNumMismatchModule_basic " ,
    " ReshapeExpandModule_basic " ,
    " UnsafeViewCollapseModule_basic " ,
    " UnsafeViewDynamicExpandModule_basic " ,
    " ViewCollapseModule_basic " ,
    " ViewDynamicExpandCollapseModule_basic " ,
    " ViewDynamicExpandModule_basic " ,
    " ViewExpandDynamicDimModule_basic " ,
    " ViewNoChange1dModule_basic " ,
    " ViewNoChange2dModule_basic " ,
    " ViewNoChange3dModule_basic " ,
}
2023-07-13 21:07:54 +08:00
2023-08-30 18:29:39 +08:00
# Tests listed as crashing (as opposed to the expected failures in
# LTC_XFAIL_SET) under the lazy tensor (LTC) config.
LTC_CRASHING_SET = {
    # HBC_basic -- TODO: rework the test so that every input is moved to the
    # lazy device first. Until then it fails with:
    #   Check failed: lazy_tensor Input tensor is not a lazy tensor: CPUBoolType.
    " HBC_basic " ,
}
2022-06-10 03:56:01 +08:00
# Tests expected to fail under the lazy tensor (LTC) config.
# Every entry must end with a comma: Python implicitly concatenates adjacent
# string literals, so a missing comma silently merges two test names into one
# bogus entry and neither real name ends up in the set.
LTC_XFAIL_SET = {
    " TorchPrimLoopForLikeTensorArgModule_basic " ,
    " CollapseAllDimensionsModule_basic " ,
    " CollapseRank1DynamicModule_basic " ,
    " CollapseStaticModule_basic " ,
    " CollapsePartialDynamicModule_basic " ,
    " CollapseFullDynamicModule_basic " ,
    " SplitDimStaticModule_basic " ,
    " SplitDimDynamicModule_basic " ,
    " PixelShuffleModuleStaticRank3Int64_basic " ,
    " PixelShuffleModuleStaticRank4Float32_basic " ,
    " PixelShuffleModuleFullDynamic_basic " ,
    " PixelShuffleModuleSpatiallyDynamic_basic " ,
    " PixelShuffleModuleSpatiallyStatic_basic " ,
    " ConvTbcModule_basic " ,
    " _Convolution2DAllFalseModule_basic " ,
    " _Convolution2DBenchmarkModule_basic " ,
    " _Convolution2DCudnnModule_basic " ,
    " _Convolution2DDeterministicModule_basic " ,
    " _Convolution2DTF32Module_basic " ,
    " _ConvolutionDeprecated2DAllFalseModule_basic " ,
    " _ConvolutionDeprecated2DBenchmarkModule_basic " ,
    " _ConvolutionDeprecated2DCudnnModule_basic " ,
    " _ConvolutionDeprecated2DDeterministicModule_basic " ,
    " MaxPool3dEmptyStrideStaticModule_basic " ,
    " AddIntModule_basic " ,
    " ArangeStartOutViewModule_basic " ,
    " AtenIntBoolOpModule_basic " ,
    " BernoulliTensorModule_basic " ,
    " BincountMinlengthModule_basic " ,
    " BincountModule_basic " ,
    " BincountStaticSizeModule_basic " ,
    " BoolFloatFalseModule_basic " ,
    " BoolFloatTrueModule_basic " ,
    " BoolIntFalseModule_basic " ,
    " BoolIntTrueModule_basic " ,
    " CeilFloatModule_basic " ,
    " DivFloatModule_basic " ,
    " EqIntModule_basic " ,
    " ExponentialModule_basic " ,
    " GeFloatIntModule_basic " ,
    " GeFloatModule_basic " ,
    " GeIntModule_basic " ,
    " GtFloatIntModule_basic " ,
    " GtIntModule_basic " ,
    " IndexPutImpl1DFloatAccumulateModule_basic " ,
    " IndexPutImpl1DFloatNonAccumulateModule_basic " ,
    " IndexPutImpl1DIntAccumulateModule_basic " ,
    " IndexPutImpl1DIntNonAccumulateModule_basic " ,
    " IndexPutImpl2DFloatAccumulateModule_basic " ,
    " IndexPutImpl2DFloatNonAccumulateModule_basic " ,
    " IndexPutImpl2DIndexModule_basic " ,
    " IndexPutImpl2DNoneIndexStaticModule_basic " ,
    " IndexPutImpl3DFloatAccumulateModule_basic " ,
    " IndexPutImpl3DFloatNonAccumulateModule_basic " ,
    " IndexPutImplIndexWithNoneModule_basic " ,
    " Matmul_dot " ,
    " MulIntModule_basic " ,
    " DivIntModule_basic " ,
    " NeFloatIntModule_basic " ,
    " NeIntModule_basic " ,
    " QuantizedMLP_basic " ,
    " QuantizedSingleLayer_basic " ,
    " QuantizedBatchedInputSingleLayer_basic " ,
    " ScalarImplicitFloatModule_basic " ,
    " ScalarImplicitIntModule_basic " ,
    " SliceEndSleStartModule_basic " ,
    " SliceOutOfUpperBoundIndexModule_basic " ,
    " SliceOutOfUpperBoundIndexStaticModule_basic " ,
    " SliceStartEqEndModule_basic " ,
    " SqrtIntModule_basic " ,
    " SubFloatModule_basic " ,
    " MulFloatModule_basic " ,
    " SubIntModule_basic " ,
    " TensorsStackPromoteDTypeModule_basic " ,
    " TensorToBoolZeroRank_basic " ,
    " TensorToBool_basic " ,
    " TensorToFloatZeroRank_basic " ,
    " TensorToFloat_basic " ,
    " TensorToIntZeroRank_basic " ,
    " TensorToInt_basic " ,
    " UniformModule_basic " ,
    " UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic " ,
    " ViewCollapseDynamicWithAtenSizeIntModule_basic " ,
    " AtenEmbeddingBagSumExample_basic " ,
    " Aten_EmbeddingBagExample_basic " ,
    " ElementwiseLogitModule_basic " ,
    " ElementwiseRemainderScalarModule_Int_Float_basic " ,
    " ElementwiseRemainderScalarModule_Bool_basic " ,
    " ElementwiseLerpScalarIntModule_basic " ,
    " ElementwiseLerpScalarFloatModule_basic " ,
    " AtenIntTensorByteDtypeModule_basic " ,
    " AtenIntTensorCharDtypeModule_basic " ,
    " UpSampleNearest2dBackwardVec_basic " ,
    " UpSampleNearest2dBackwardOutputSizeNone_basic " ,
    " ConvolutionBackwardModule2D_basic " ,
    " ConvolutionBackwardModule2DPadded_basic " ,
    " VarMeanCorrectionModule_basic " ,
    " VarMeanCorrectionNoneModule_basic " ,
    " ElementwisePreluModule_basic " ,
    " VarMeanBiasedModule_basic " ,
    " VarMeanUnbiasedModule_basic " ,
    " RandnLikeModule_basic " ,
    " RandnLikeDtypeModule_basic " ,
    " NormalFunctionalModule_basic " ,
    " BernoulliFloatModule_basic " ,
    " BernoulliModule_basic " ,
    " BernoulliPModule_basic " ,
    " DropoutTrainModule_basic " ,
    " DropoutTrainStaticShapeModule_basic " ,
    " NativeDropoutTrainModule_basic " ,
    " NativeDropoutTrainStaticShapeModule_basic " ,
    " StdCorrectionKeepDimModule_basic " ,
    " StdCorrectionNoneModule_basic " ,
    " VarCorrectionKeepDimModule_basic " ,
    " VarCorrectionNoneModule_basic " ,
    " AtenFloatScalarModule_basic " ,
    " PrimsSqueezeModule_basic " ,
    " PrimsSqueezeEmptyDimensionsModule_basic " ,
    " PrimsViewOfModule_basic " ,
    " PrimsViewOfZeroRankModule_basic " ,
    " OneHotModule_basic " ,
    " VarMeanDimModule_basic " ,
    " VarMeanDimBiasedModule_basic " ,
    " AtenComplexImagModule_basic " ,
    " AtenComplexRealModule_basic " ,
    " AtenComplexViewModule_basic " ,
    " AtenRealView128Module_basic " ,
    " AtenRealView64Module_basic " ,
    " ScatterValueFloatModule_basic " ,
    " ScatterValueIntModule_basic " ,
    " UniformStaticShapeModule_basic " ,
    " AtenEmbeddingBagStaticModule_basic " ,
    " EmptyStridedModule_basic " ,
    " EmptyStridedSizeIntStrideModule_basic " ,
    " ElementwiseBitwiseAndScalarInt64Module_basic " ,
    " ElementwiseBitwiseAndScalarInt32Module_basic " ,
    " ElementwiseBitwiseAndScalarInt8Module_basic " ,
    " Conv2dQInt8Module_basic " ,
}
2024-02-16 02:17:13 +08:00
ONNX_XFAIL_SET = {
2024-02-28 14:48:07 +08:00
# Failure - cast error
" PermuteNegativeIndexModule_basic " ,
2024-04-19 05:58:13 +08:00
# Failure - expand multiple dynamic dims
" EmbeddingModuleF16_basic " ,
" EmbeddingModuleI32_basic " ,
" EmbeddingModuleI64_basic " ,
" IndexTensorHackedTwinModule3dInput_basic " ,
" IndexTensorHackedTwinModule_basic " ,
" IndexTensorModule3dInput_basic " ,
" IndexTensorModule_basic " ,
" IndexTensorMultiInputContiguousOneDimDynamic_basic " ,
" IndexTensorMultiInputNonContiguousOneDimDynamic_basic " ,
" IndexTensorSelectDimModule_basic " ,
2024-02-28 14:48:07 +08:00
# Failure - incorrect numerics
2024-04-19 05:58:13 +08:00
" AvgPool2dDivisorOverrideModule_basic " ,
" BroadcastDynamicDimModule_basic " ,
2024-02-28 14:48:07 +08:00
" ElementwiseAtan2TensorIntModule_basic " ,
2024-04-19 05:58:13 +08:00
" ElementwiseAtenFloorDivideScalarNegativeModule_basic " ,
" ElementwiseAtenFloorDivideTensorNegativeModule_basic " ,
2024-02-28 14:48:07 +08:00
" ElementwiseLog10IntModule_basic " ,
" ElementwiseLog2IntModule_basic " ,
" ElementwiseSeluModule_basic " ,
" FlipModuleStaticShape_basic " ,
" FlipNegativeIndexModule_basic " ,
" HardsigmoidModule_basic " ,
" HardsigmoidRandomModule_basic " ,
" PixelShuffleModuleStaticRank4Float32_basic " ,
2024-04-19 05:58:13 +08:00
" ReflectionPad1dModule2dInput_Right " ,
" ReflectionPad1dModule2dInput_basic " ,
" ReflectionPad1dModule3dInput_Left " ,
" ReflectionPad1dModule3dInput_basic " ,
" ReflectionPad2dModule_Bottom " ,
" ReflectionPad2dModule_Left " ,
" ReflectionPad2dModule_Right " ,
" ReflectionPad2dModule_Top " ,
" ReflectionPad2dModule_basic " ,
" ReplicationPad2dModule_basic " ,
" ReplicationPad2dModule_bottom0 " ,
" ReplicationPad2dModule_left0 " ,
" ReplicationPad2dModule_right0 " ,
" ReplicationPad2dModule_top0 " ,
2024-02-28 14:48:07 +08:00
" SliceCopyEndGreaterThanDimSize_Module_basic " ,
" SliceCopyNegative_Module_basic " ,
" SliceCopyNonZeroDim_Module_basic " ,
" SliceCopy_Module_basic " ,
2024-04-19 05:58:13 +08:00
" StdCorrectionLargeInputModule_basic " ,
2024-02-28 14:48:07 +08:00
" TupleModule_basic " ,
2024-04-19 05:58:13 +08:00
" VarCorrectionLargeInputModule_basic " ,
2024-02-28 14:48:07 +08:00
# Failure - incorrect shape
" ArangeStartOutDtypeModule_basic " ,
" ArangeStartOutViewModule_basic " ,
" MoveDimIntNegativeIndexModule_basic " ,
2024-04-19 05:58:13 +08:00
" ReduceL3NormKeepDimModule_basic " ,
2024-02-28 14:48:07 +08:00
" ViewSizeFromOtherTensor_basic " ,
2024-04-19 05:58:13 +08:00
2024-02-16 02:17:13 +08:00
# Failure - onnx_export
" AdaptiveAvgPool1dGeneralDynamic_basic " ,
" AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic " ,
" AdaptiveAvgPool1dStaticLargerOutput_basic " ,
2024-04-19 05:58:13 +08:00
" AdaptiveAvgPool2dDynamicNoBatch_basic " ,
" AdaptiveAvgPool2dDynamic_basic " ,
2024-02-16 02:17:13 +08:00
" AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic " ,
2024-04-11 17:02:59 +08:00
" AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic " ,
2024-04-19 05:58:13 +08:00
" AdaptiveAvgPool3dDynamicNoBatch_basic " ,
" AdaptiveAvgPool3dDynamic_basic " ,
" AdaptiveMaxPool1dDynamicNoBatch_basic " ,
" AdaptiveMaxPool1dDynamic_basic " ,
" AdaptiveMaxPool1dStatic_basic " ,
" AdaptiveMaxPool2dDynamicNoBatch_basic " ,
2024-02-16 02:17:13 +08:00
" AdaptiveMaxPool2dDynamicWithIndices_basic " ,
" AdaptiveMaxPool2dDynamic_basic " ,
" AdaptiveMaxPool2dStaticWithIndices_basic " ,
" AdaptiveMaxPool2dStatic_basic " ,
2024-03-23 02:05:20 +08:00
" AdaptiveMaxPool3dDynamicNoBatch_basic " ,
2024-04-19 05:58:13 +08:00
" AdaptiveMaxPool3dDynamicWithIndices_basic " ,
" AdaptiveMaxPool3dDynamic_basic " ,
" AdaptiveMaxPool3dStaticWithIndices_basic " ,
" AdaptiveMaxPool3dStatic_basic " ,
2024-02-16 02:17:13 +08:00
" AddCDivModule_basic " ,
" AddIntModule_basic " ,
" Add_Module_basic " ,
" AllBoolFalseModule_basic " ,
" AllBoolTrueModule_basic " ,
" AnyBoolFalseModule_basic " ,
" AnyBoolTrueModule_basic " ,
" AtenComplex64Module_basic " ,
" AtenComplexImagModule_basic " ,
" AtenComplexRealModule_basic " ,
" AtenComplexViewModule_basic " ,
2024-04-19 05:58:13 +08:00
" AtenDiagEmbedDefaultDiag_basic " ,
" AtenDiagEmbedDimDiag_basic " ,
" AtenDiagEmbedNegOffsetDiag_basic " ,
" AtenDiagEmbedNonDefault4DDiag_basic " ,
" AtenDiagEmbedOffsetDiag_basic " ,
" AtenDiagEmbedRevDimDiag_basic " ,
2024-02-16 02:17:13 +08:00
" AtenEmbeddingBagStaticModule_basic " ,
" AtenEmbeddingBagSumExample_basic " ,
" AtenFloatScalarModule_basic " ,
" AtenIntBoolOpConstFalseModule_basic " ,
" AtenIntBoolOpConstTrueModule_basic " ,
" AtenIntBoolOpModule_basic " ,
" AtenIntTensorByteDtypeModule_basic " ,
" AtenIntTensorCharDtypeModule_basic " ,
" AtenItemFpOpModule_basic " ,
" AtenItemIntOpModule_basic " ,
2024-04-19 05:58:13 +08:00
" AtenLinalgCrossDynamic_basic " ,
2024-04-16 07:06:47 +08:00
" AtenMatmulQMixedSigni8Transpose_basic " ,
" AtenMatmulQMixedSigni8_basic " ,
" AtenMatmulQint8MV_basic " ,
2024-04-17 00:28:28 +08:00
" AtenMatmulQint8VM_basic " ,
2024-04-19 05:58:13 +08:00
" AtenMatmulQint8VV_basic " ,
2024-04-16 07:06:47 +08:00
" AtenMatmulQint8_basic " ,
2024-04-19 05:58:13 +08:00
" AtenMmQMixedSigni8_basic " ,
" AtenMmQint8_basic " ,
" AtenMmQuint8_basic " ,
2024-02-16 02:17:13 +08:00
" AtenRealView128Module_basic " ,
" AtenRealView64Module_basic " ,
" AtenSubFloatModule_basic " ,
" AtenTopKModule_basic " ,
" AtenTopKSmallestModule_basic " ,
" Aten_EmbeddingBagExample_basic " ,
" AvgPool2dWithoutPadModule_basic " ,
" BatchMlpLayerModule_basic " ,
" BincountMinlengthModule_basic " ,
" BincountModule_basic " ,
" BincountStaticSizeModule_basic " ,
" BoolFloatConstantModule_basic " ,
" BoolFloatFalseModule_basic " ,
" BoolFloatTrueModule_basic " ,
" BoolIntConstantModule_basic " ,
" BoolIntFalseModule_basic " ,
" BoolIntTrueModule_basic " ,
" CeilFloatModule_basic " ,
" ChunkListUnpackDynamic_Module_basic " ,
" ChunkListUnpackUnevenDynamic_Module_basic " ,
" CollapseAllDimensionsModule_basic " ,
" CollapseFullDynamicModule_basic " ,
" CollapsePartialDynamicModule_basic " ,
" CollapseRank1DynamicModule_basic " ,
" CollapseStaticModule_basic " ,
" ConstantBoolParameterModule_basic " ,
" ContainsIntList_False " ,
" ContainsIntList_True " ,
" Conv1dModule_basic " ,
" Conv2dBiasNoPaddingModule_basic " ,
" Conv2dModule_basic " ,
" Conv2dNoPaddingModule_basic " ,
" Conv2dQInt8Module_basic " ,
" Conv2dWithPaddingDilationStrideModule_basic " ,
" Conv2dWithPaddingModule_basic " ,
" Conv3dModule_basic " ,
" ConvTbcModule_basic " ,
" Conv_Transpose2dModule_basic " ,
" Convolution2DModule_basic " ,
" Convolution2DStridedModule_basic " ,
" ConvolutionBackwardModule2DPadded_basic " ,
" ConvolutionBackwardModule2DStatic_basic " ,
" ConvolutionBackwardModule2DStrided_basic " ,
" ConvolutionBackwardModule2D_basic " ,
" ConvolutionModule2DGroups_basic " ,
" ConvolutionModule2DTransposeNonUnitOutputPadding_basic " ,
" ConvolutionModule2DTransposeStrided_basic " ,
" ConvolutionModule2DTranspose_basic " ,
" DivFloatModule_basic " ,
" DivIntModule_basic " ,
" ElementwiseAcoshIntModule_basic " ,
" ElementwiseAcoshModule_basic " ,
2024-04-08 20:24:17 +08:00
" ElementwiseAndScalarModule_basic " ,
" ElementwiseAndScalarStaticShapeModule_basic " ,
2024-02-16 02:17:13 +08:00
" ElementwiseAsinhIntModule_basic " ,
" ElementwiseAsinhModule_basic " ,
" ElementwiseAtanhIntModule_basic " ,
" ElementwiseAtanhModule_basic " ,
" ElementwiseAtenIsneginfOpModule_basic " ,
" ElementwiseAtenIsposinfOpModule_basic " ,
" ElementwiseBitwiseAndModule_basic " ,
" ElementwiseBitwiseAndScalarInt32Module_basic " ,
" ElementwiseBitwiseAndScalarInt64Module_basic " ,
" ElementwiseBitwiseAndScalarInt8Module_basic " ,
" ElementwiseBitwiseAndStaticShapeModule_basic " ,
" ElementwiseBitwiseLeftShiftInt32Module_basic " ,
" ElementwiseBitwiseLeftShiftInt64Module_basic " ,
" ElementwiseBitwiseLeftShiftInt8Module_basic " ,
" ElementwiseBitwiseNotInt32Module_basic " ,
" ElementwiseBitwiseNotInt64Module_basic " ,
" ElementwiseBitwiseOrModule_basic " ,
" ElementwiseBitwiseOrStaticShapeModule_basic " ,
" ElementwiseBitwiseRightShiftInt32Module_basic " ,
" ElementwiseBitwiseRightShiftInt64Module_basic " ,
" ElementwiseBitwiseRightShiftInt8Module_basic " ,
" ElementwiseBitwiseXorModule_basic " ,
" ElementwiseBitwiseXorStaticShapeModule_basic " ,
" ElementwiseCoshIntModule_basic " ,
" ElementwiseCoshModule_basic " ,
" ElementwiseDequantizePerChannelModule_basic " ,
" ElementwiseDequantizePerTensorModule_basic " ,
2024-04-19 05:58:13 +08:00
" ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic " ,
" ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic " ,
2024-02-16 02:17:13 +08:00
" ElementwiseEluNonDefaultModule_basic " ,
" ElementwiseExpm1IntModule_basic " ,
" ElementwiseExpm1Module_basic " ,
2024-04-19 05:58:13 +08:00
" ElementwiseFmodTensor_Int_basic " ,
2024-02-16 02:17:13 +08:00
" ElementwiseMulTensorComplexModule_basic " ,
" ElementwiseOrTensorModule_basic " ,
" ElementwiseOrTensorStaticShapeModule_basic " ,
" ElementwiseQuantizePerTensorModule_basic " ,
2024-03-21 04:37:47 +08:00
" ElementwiseQuantizePerTensorUIntModule_basic " ,
2024-02-16 02:17:13 +08:00
" ElementwiseRemainderTensorModule_Int_basic " ,
2024-04-19 05:58:13 +08:00
" ElementwiseSgnModule_basic " ,
2024-02-16 02:17:13 +08:00
" EmptyStridedModule_basic " ,
" EmptyStridedSizeIntStrideModule_basic " ,
" EqIntModule_basic " ,
" ExponentialModule_basic " ,
2024-02-28 14:48:07 +08:00
" FloatImplicitModule_basic " ,
2024-02-16 02:17:13 +08:00
" GeFloatIntModule_basic " ,
" GeFloatModule_basic " ,
" GeIntModule_basic " ,
" GeluBackwardModule_basic " ,
" GtFloatIntModule_basic " ,
" GtIntModule_basic " ,
" HardtanhBackward_basic " ,
" IndexPutImpl1DFloatAccumulateModule_basic " ,
" IndexPutImpl1DFloatNonAccumulateModule_basic " ,
" IndexPutImpl1DIntAccumulateModule_basic " ,
" IndexPutImpl1DIntNonAccumulateModule_basic " ,
" IndexPutImpl2DFloatAccumulateModule_basic " ,
" IndexPutImpl2DFloatNonAccumulateModule_basic " ,
2024-04-19 05:58:13 +08:00
" IndexPutImpl2DImplicitModule_basic " ,
2024-02-16 02:17:13 +08:00
" IndexPutImpl2DIndexModule_basic " ,
" IndexPutImpl2DNoneIndexStaticModule_basic " ,
" IndexPutImpl3DFloatAccumulateModule_basic " ,
" IndexPutImpl3DFloatNonAccumulateModule_basic " ,
" IndexPutImplIndexWithNoneModule_basic " ,
" IntFloatModule_basic " ,
2024-02-28 14:48:07 +08:00
" IntImplicitModule_basic " ,
2024-02-16 02:17:13 +08:00
" IouOfModule_basic " ,
" IsFloatingPointFloat_True " ,
" IsFloatingPointInt_False " ,
" IscloseStaticModuleTrue_basic " ,
" IscloseStaticModule_basic " ,
" LeakyReluBackwardModule_basic " ,
" LeakyReluBackwardStaticModule_basic " ,
" LenStrModule_basic " ,
" LiftFreshCopyModule_basic " ,
2024-04-19 05:58:13 +08:00
" LinalgNormKeepDimComplexModule_basic " ,
" LinalgVectorNormComplexModule_basic " ,
2024-02-16 02:17:13 +08:00
" LogSoftmaxBackwardModule_basic " ,
" MaxPool2dCeilModeTrueModule_basic " ,
" MaxPool2dModule_basic " ,
" MaxPool2dWithIndicesAllOnesModule_basic " ,
" MaxPool2dWithIndicesBackwardDynamic3DModule_basic " ,
" MaxPool2dWithIndicesBackwardDynamic4DModule_basic " ,
" MaxPool2dWithIndicesBackwardStatic3DModule_basic " ,
" MaxPool2dWithIndicesBackwardStatic4DModule_basic " ,
" MaxPool2dWithIndicesCeilModeTrueModule_basic " ,
" MaxPool2dWithIndicesFullSizeKernelModule_basic " ,
" MaxPool2dWithIndicesModule_basic " ,
" MaxPool2dWithIndicesNonDefaultDilationModule_basic " ,
" MaxPool2dWithIndicesNonDefaultParamsModule_basic " ,
" MaxPool2dWithIndicesNonDefaultStrideModule_basic " ,
" MaxPool3dCeilModeTrueModule_basic " ,
" MaxPool3dLargeDatadModule_basic " ,
" MaxPool3dModuleRandomSimple_basic " ,
" MaxPool3dModule_basic " ,
" MeanDimEmptyDimModule_basic " ,
" Mlp1LayerModule_basic " ,
" Mlp2LayerModuleNoBias_basic " ,
" Mlp2LayerModule_basic " ,
" MulFloatModule_basic " ,
" MulIntModule_basic " ,
" NarrowHorizontalTest2_basic " ,
" NarrowHorizontalTest_basic " ,
" NarrowTensorHorizontalModule_basic " ,
" NarrowTensorVerticalModule_basic " ,
" NarrowVerticalTest2_basic " ,
" NarrowVerticalTest_basic " ,
" NativeBatchNorm1DModule_basic " ,
" NativeBatchNorm2DModule_basic " ,
" NativeBatchNorm3DModule_basic " ,
" NativeBatchNormNoneWeightModule_basic " ,
" NativeDropoutEvalFloatModule_basic " ,
" NativeGroupNormBackwardModule_basic " ,
" NativeGroupNormModule_basic " ,
" NativeLayerNormDynamicModule_basic " ,
" NeFloatIntModule_basic " ,
" NeIntModule_basic " ,
" NewEmptyStridedModuleDefaultDtype_basic " ,
" NllLossModuleBackward1DMeanWeight_basic " ,
" NllLossModuleBackward1DMean_basic " ,
" NllLossModuleBackward1DSumWeight_basic " ,
" NllLossModuleBackward1DSum_basic " ,
" NllLossModuleBackward1DWeight_basic " ,
" NllLossModuleBackward1D_basic " ,
" NllLossModuleBackwardMeanWeight_basic " ,
" NllLossModuleBackwardMean_basic " ,
" NllLossModuleBackwardSumWeight_basic " ,
" NllLossModuleBackwardSum_basic " ,
" NllLossModuleBackwardWeight_basic " ,
" NllLossModuleBackward_basic " ,
" NllLossModuleBackward_ignore_index " ,
" NllLossModule_1D_basic " ,
" NllLossModule_basic " ,
" NllLossModule_ignore_index_out_of_bounds_basic " ,
" NllLossModule_mean_basic " ,
" NllLossModule_sum_basic " ,
2024-04-02 16:33:30 +08:00
" NormScalarComplexModule_basic " ,
2024-04-19 05:58:13 +08:00
" NormScalarModule_basic " ,
" NormScalarOptDimKeepDimComplexModule_basic " ,
2024-02-16 02:17:13 +08:00
" NormScalarOptDimKeepDimModule_basic " ,
" NormScalarOptDimModule_basic " ,
" NormalFunctionalModule_basic " ,
" NumToTensorFloatModule_basic " ,
" NumToTensorIntModule_basic " ,
" NumelModule_basic " ,
" NumelZeroRankModule_basic " ,
" PixelShuffleModuleFullDynamic_basic " ,
" PixelShuffleModuleSpatiallyDynamic_basic " ,
" PixelShuffleModuleSpatiallyStatic_basic " ,
" PixelShuffleModuleStaticRank3Int64_basic " ,
" PowIntFloatModule_basic " ,
" PrimMaxIntModule_basic " ,
" PrimMinIntDynamicModule_basic " ,
" PrimMinIntModule_basic " ,
" PrimsConvertElementTypeModule_basic " ,
" PrimsSqueezeEmptyDimensionsModule_basic " ,
" PrimsSqueezeModule_basic " ,
" PrimsViewOfModule_basic " ,
" PrimsViewOfZeroRankModule_basic " ,
" RandIntDtypeModule_basic " ,
" RandIntModule_basic " ,
" RandIntPinMemoryModule_basic " ,
2024-04-19 05:58:13 +08:00
" ReduceFrobeniusNormComplexModule_basic " ,
" ReduceL1NormComplexModule_basic " ,
" ReduceL2NormComplexModule_basic " ,
" ReduceL3NormKeepDimComplexModule_basic " ,
2024-02-16 02:17:13 +08:00
" ReshapeAliasCollapseModule_basic " ,
" ReshapeAliasExpandModule_basic " ,
" ReshapeExpandModule_basic " ,
" ScalarConstantTupleModule_basic " ,
" ScalarImplicitFloatModule_basic " ,
" ScalarImplicitIntModule_basic " ,
" ScatterReduceFloatMaxModule " ,
" ScatterReduceFloatMeanModule " ,
" ScatterReduceFloatMeanModuleIncludeSelf " ,
" ScatterReduceFloatMinModule " ,
" ScatterReduceFloatProdModule " ,
" ScatterReduceFloatSumModule " ,
" ScatterReduceIntMaxModule " ,
" ScatterReduceIntMeanModule " ,
" ScatterReduceIntMeanModuleIncludeSelf " ,
" ScatterReduceIntMinModule " ,
" ScatterReduceIntProdModule " ,
" ScatterReduceIntSumModule " ,
" SelectScattertModule_basic " ,
" SelectScattertStaticModule_basic " ,
" SliceEndSleStartModule_basic " ,
" SliceOutOfUpperBoundIndexModule_basic " ,
" SliceScatterModule_basic " ,
" SliceScatterNegativeDimModule_basic " ,
" SliceScatterNegativeEndModule_basic " ,
" SliceScatterStaticModule_basic " ,
" SliceScatterStepVariationModule_basic " ,
" SliceScatterZeroDimModule_basic " ,
" SliceStartEqEndModule_basic " ,
" SoftmaxBackwardModule_basic " ,
" SortIntListReverse_basic " ,
" SortIntList_basic " ,
" SplitDimDynamicModule_basic " ,
" SplitDimStaticModule_basic " ,
" SqrtIntConstantModule_basic " ,
" SqrtIntModule_basic " ,
" StdCorrectionEmptyDimModule_basic " ,
" StdDimEmptyDimModule_basic " ,
" SubFloatModule_basic " ,
" SubIntModule_basic " ,
" TanhBackward_basic " ,
" TensorToBoolZeroRank_basic " ,
" TensorToBool_basic " ,
" TensorToFloatZeroRank_basic " ,
" TensorToFloat_basic " ,
" TensorToIntZeroRank_basic " ,
" TensorToInt_basic " ,
" TestMultipleTensorAndPrimitiveTypesReturn_basic " ,
" Threshold1dFloatModule_basic " ,
" Threshold1dIntI32Module_basic " ,
" Threshold1dIntModule_basic " ,
" Threshold2dFloatModule_basic " ,
" Threshold2dIntModule_basic " ,
" Threshold3dFloatModule_basic " ,
" Threshold3dIntModule_basic " ,
" ThresholdBackward1dFloatModule_basic " ,
" ThresholdBackward1dIntModule_basic " ,
" ThresholdBackward1dMixedModule_basic " ,
" ThresholdBackward2dFloatModule_basic " ,
" ThresholdBackward2dIntModule_basic " ,
" ThresholdBackward2dMixedModule_basic " ,
" ThresholdBackward3dFloatModule_basic " ,
" ThresholdBackward3dIntModule_basic " ,
" ThresholdBackward3dMixedModule_basic " ,
" ToCopyBoolDTypeStaticModule_basic " ,
" ToCopyModule_basic " ,
" ToCopyWithDTypeFalsePinMemoryModule_basic " ,
" ToCopyWithDTypeModule_basic " ,
" TorchPrimLoopForLikeModule_basic " ,
" TorchPrimLoopWhileLikeModule_basic " ,
" TraceModule_basic " ,
" TraceModule_empty " ,
" TraceModule_nonsquare " ,
" TraceSignedIntModule_basic " ,
" TraceUnsignedIntModule_basic " ,
" TraceUnsignedIntModule_empty " ,
" UniformModule_basic " ,
" UniformNoCorrelationModule_basic " ,
" UniformStaticShapeModule_basic " ,
" UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic " ,
" UnsafeView1DFoldModule_basic " ,
" UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic " ,
" UnsafeViewCollapseModule_basic " ,
" UnsafeViewDynamicExpandModule_basic " ,
" UnsafeViewDynamicExpandWithAtenSizeIntModule_basic " ,
" UnsafeViewExpandModule_basic " ,
" UpSampleNearest2dBackwardScalesNone_basic " ,
" UpSampleNearest2dBackward_basic " ,
" UpSampleNearest2dDynamicFactor_basic " ,
" UpSampleNearest2dStaticFactor_basic " ,
" UpSampleNearest2d_basic " ,
" VarCorrectionEmptyDimModule_basic " ,
" VarDimEmptyDimModule_basic " ,
" ViewCollapseDynamicWithAtenSizeIntModule_basic " ,
" ViewCollapseModule_basic " ,
" ViewDynamicExpandCollapseModule_basic " ,
" ViewDynamicExpandCollapseWithAtenIntModule_basic " ,
" ViewDynamicExpandModule_basic " ,
" ViewDynamicExpandWithAtenSizeIntModule_basic " ,
" ViewExpandDynamicDimModule_basic " ,
" ViewNoChange1dModule_basic " ,
" ViewNoChange2dModule_basic " ,
" ViewNoChange3dModule_basic " ,
" _Convolution2DAllFalseModule_basic " ,
" _Convolution2DBenchmarkModule_basic " ,
" _Convolution2DCudnnModule_basic " ,
" _Convolution2DDeterministicModule_basic " ,
" _Convolution2DTF32Module_basic " ,
" _ConvolutionDeprecated2DAllFalseModule_basic " ,
" _ConvolutionDeprecated2DBenchmarkModule_basic " ,
" _ConvolutionDeprecated2DCudnnModule_basic " ,
" _ConvolutionDeprecated2DDeterministicModule_basic " ,
" _SoftmaxModule_basic " ,
2024-04-19 05:58:13 +08:00
# Failure - onnx_lowering: onnx.AveragePool
" AdaptiveAvgPool1dGeneralDynamicNoBatches_basic " ,
# Failure - onnx_lowering: onnx.If
2024-02-16 02:17:13 +08:00
" DiagonalModule_basic " ,
" DiagonalModule_nonsquare " ,
" DiagonalModule_transposed " ,
" DiagonalModule_with_dims " ,
" DiagonalModule_with_dims_and_offset " ,
" DiagonalModule_with_negative_dims " ,
" DiagonalModule_with_offset " ,
" TileBigDimsSizeModule_basic " ,
" TileSmallDimsSizeModule_basic " ,
2024-04-19 05:58:13 +08:00
2024-02-28 14:48:07 +08:00
# Failure - onnx_lowering: onnx.MaxPool
2024-02-16 02:17:13 +08:00
" MaxPool2dWithIndicesAllNegativeValuesModule_basic " ,
" MaxPool2dWithIndicesNonDefaultPaddingModule_basic " ,
" MaxPool2dWithIndicesStaticModule_basic " ,
2024-04-19 05:58:13 +08:00
2024-02-28 14:48:07 +08:00
# Failure - onnx_lowering: onnx.OneHot
2024-02-16 02:17:13 +08:00
" OneHotModule_basic " ,
2024-04-19 05:58:13 +08:00
2024-04-23 00:58:07 +08:00
# ERROR: dtype (torch.float32) is not equal to golden dtype (torch.float64)
2024-02-28 14:48:07 +08:00
" RandnDtypeDeviceModule_basic " ,
" RandnGeneratorF64Module_basic " ,
" RandnGeneratorModule_basic " ,
" RandnModule_basic " ,
" RandnLikeModule_basic " ,
" BernoulliFloatModule_basic " ,
" BernoulliPModule_basic " ,
" BernoulliTensorModule_basic " ,
2024-04-19 05:58:13 +08:00
2024-02-28 14:48:07 +08:00
# Failure - onnx_lowering: onnx.ReduceProd
" ReduceProdDimIntFloatModule_basic " ,
2024-04-19 05:58:13 +08:00
2024-02-28 14:48:07 +08:00
# Failure - onnx_lowering: onnx.Resize
" UpSampleNearest2dDynamicSize_basic " ,
" UpSampleNearest2dStaticSize_basic " ,
2024-04-19 05:58:13 +08:00
2024-02-28 14:48:07 +08:00
# Failure - onnx_lowering: onnx.ScatterElements
2024-04-19 05:58:13 +08:00
" ScatterReduceFloatMaxModuleIncludeSelf " ,
" ScatterReduceFloatMinModuleIncludeSelf " ,
" ScatterReduceIntMaxModuleIncludeSelf " ,
" ScatterReduceIntMinModuleIncludeSelf " ,
2024-02-28 14:48:07 +08:00
" ScatterValueFloatModule_basic " ,
2024-04-19 05:58:13 +08:00
2024-02-28 14:48:07 +08:00
# Failure - onnx_lowering: onnx.ScatterND
" IndexPut1DFloatAccumulateModule_basic " ,
" IndexPut1DFloatNonAccumulateModule_basic " ,
" IndexPut1DIntAccumulateModule_basic " ,
" IndexPut1DIntNonAccumulateModule_basic " ,
" IndexPut2DFloatAccumulateModule_basic " ,
" IndexPut2DFloatNonAccumulateModule_basic " ,
" IndexPut2DIntAccumulateModule_basic " ,
" IndexPut2DIntNonAccumulateModule_basic " ,
" IndexPut3DFloatAccumulateModule_basic " ,
" IndexPut3DFloatNonAccumulateModule_basic " ,
" IndexPut3DIntAccumulateModule_basic " ,
" IndexPut3DIntNonAccumulateModule_basic " ,
" IndexPutHackedTwin1DFloatAccumulateModule_basic " ,
" IndexPutHackedTwin1DFloatNonAccumulateModule_basic " ,
" IndexPutHackedTwin1DIntAccumulateModule_basic " ,
" IndexPutHackedTwin1DIntNonAccumulateModule_basic " ,
" IndexPutHackedTwin2DFloatAccumulateModule_basic " ,
" IndexPutHackedTwin2DFloatNonAccumulateModule_basic " ,
" IndexPutHackedTwin2DIntAccumulateModule_basic " ,
" IndexPutHackedTwin2DIntNonAccumulateModule_basic " ,
" IndexPutHackedTwin3DFloatAccumulateModule_basic " ,
" IndexPutHackedTwin3DFloatNonAccumulateModule_basic " ,
" IndexPutHackedTwin3DIntAccumulateModule_basic " ,
" IndexPutHackedTwin3DIntNonAccumulateModule_basic " ,
2024-04-19 05:58:13 +08:00
2024-02-28 14:48:07 +08:00
# Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss
" CrossEntropyLossModule_basic " ,
" CrossEntropyLossNoReductionModule_basic " ,
2024-04-22 10:45:01 +08:00
# RuntimeError: unsupported input type: Device
" PrimsIotaModule_basic " ,
2024-04-22 16:52:42 +08:00
2024-02-16 02:17:13 +08:00
# Failure - unknown
2024-04-19 05:58:13 +08:00
" BernoulliModule_basic " ,
2024-02-16 02:17:13 +08:00
" Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier " ,
" CopyWithDifferentDTypesAndSizesModule_basic " ,
" CopyWithDifferentDTypesModule_basic " ,
" CosineSimilarityStaticBroadcastModule_basic " ,
" CumsumInputDtypeInt32Module_basic " ,
2024-04-19 05:58:13 +08:00
" DropoutTrainModule_basic " ,
" DropoutTrainStaticShapeModule_basic " ,
2024-02-28 14:48:07 +08:00
" ElementwiseAcosIntModule_basic " ,
" ElementwiseAsinIntModule_basic " ,
" ElementwiseAtanTensorIntModule_basic " ,
" ElementwiseCosIntModule_basic " ,
2024-04-16 04:45:10 +08:00
" ElementwiseDivTensorRoundingModeTruncModule_basic " ,
" ElementwiseDivTensorRoundingModeTruncStaticModule_basic " ,
" ElementwiseErfIntModule_basic " ,
2024-02-28 14:48:07 +08:00
" ElementwiseExpIntModule_basic " ,
" ElementwiseLogIntModule_basic " ,
" ElementwiseSigmoidIntModule_basic " ,
" ElementwiseSinIntModule_basic " ,
" ElementwiseTanIntModule_basic " ,
2024-04-19 05:58:13 +08:00
" ElementwiseToDtypeI64ToUI8Module_basic " ,
2024-02-28 14:48:07 +08:00
" ElementwiseUnaryIntModule_basic " ,
2024-02-16 02:17:13 +08:00
" MaskedFillTensorFloatValueModule_basic " ,
2024-04-19 05:58:13 +08:00
" NativeDropoutTrainModule_basic " ,
" NativeDropoutTrainStaticShapeModule_basic " ,
" ReduceMaxAlongDimUnsignedInt_basic " ,
2024-02-16 02:17:13 +08:00
" ReduceMinAlongDimUnsignedInt_basic " ,
}
2024-02-21 01:30:30 +08:00
2024-04-16 14:54:46 +08:00
# Tests that only fail on PyTorch builds at or above 2.4.0.dev (i.e. newer
# nightlies); they are added to the ONNX xfail set only for those versions.
if torch_version_for_comparison() >= version.parse("2.4.0.dev"):
    ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
        # ERROR: Found dtype (torch.float64) but expected (torch.float32)
        "ReduceL1NormWithDTypeModule_basic",
    }
[Torch] Add decomposition of RepeatInterleaveSelfInt Op (#3075)
Decompose RepeatInterleaveSelfInt into the following ops:
```python
def my_repeat_interleave(input, repeats, dim=None):
if dim is None:
# Flatten the input and then repeat
return input.flatten().unsqueeze(-1).tile((1, repeats)).flatten()
else:
# Calculate the shape after repeat
expanded_shape = list(input.shape)
expanded_shape[dim] *= repeats
# Repeat the tensor along the specified dimension
repeat_shape = [1] * (input.dim() + 1)
repeat_shape[dim + 1] = repeats
input = input.unsqueeze(-1)
# Tile and then reshape
tiled = torch.tile(input, repeat_shape)
# Rearrange and reshape
repeated = tiled.reshape(*expanded_shape)
return repeated
```
The stablehlo and linalg tests pass. When testing onnx, however, something
strange happened.
In torch-mlir's CI **torch_nightly** and in my own
environment (torch==2.4.0.dev20240318+cpu), the test **passes**.
In torch-mlir's CI **torch_stable**, it **fails**.
The test case is `RepeatInterleaveSelfIntNoDimModule_basic`, the result
shape should be [120].
```python
class RepeatInterleaveSelfIntNoDimModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4, 5], torch.float32, True),
])
def forward(self, x):
return x.repeat_interleave(2)
@register_test_case(module_factory=lambda: RepeatInterleaveSelfIntNoDimModule())
def RepeatInterleaveSelfIntNoDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
```
The error log is as follows:
```
Unexpected outcome summary: (onnx)
****** Failed tests - 1 tests
FAIL - "RepeatInterleaveSelfIntNoDimModule_basic"
@ trace item #0 - call to "forward"
@ output of call to "forward"
ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
```
@rsuderman
Would you please help me check what's wrong with my PR? Thanks a lot.
2024-04-18 06:27:51 +08:00
# Tests that only fail on PyTorch builds older than 2.3.0.dev; they are added
# to the ONNX xfail set only for those versions.
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
    ONNX_XFAIL_SET = ONNX_XFAIL_SET | {
        # ERROR: shape (torch.Size([6, 4, 5])) is not equal to golden shape (torch.Size([120]))
        "RepeatInterleaveSelfIntNoDimModule_basic",
    }
2024-04-16 14:54:46 +08:00
2024-03-15 08:53:29 +08:00
# Tests that crash under the ONNX config, kept separate from the xfail set.
# NOTE(review): presumably the runner skips these instead of reporting them
# as failures -- confirm against the e2e test driver.
ONNX_CRASHING_SET = {
    "FakeQuantizePerTensorAffineModule_basic",
    "FakeQuantizePerTensorAffineDynamicShapeModule_basic",
    "ElementwisePreluModule_basic",
    # Added with "[MLIR][TORCH] Support parallel dimensions expand/collapse"
    # (#3051): aten.view with a single unknown dimension in both the input
    # and output shapes, which convert-torch-to-linalg lowers to
    # tensor.collapse_shape / tensor.expand_shape.
    "ViewDynamicExpandCollapseWithParallelUnknownDimModule_basic",
    "ScatterReduceFloatProdModuleIncludeSelf",
    "ScatterReduceFloatSumModuleIncludeSelf",
    "ScatterReduceIntProdModuleIncludeSelf",
    "ScatterReduceIntSumModuleIncludeSelf",
}