Co-authored-by: Peiming Liu <peiming@google.com>
pull/3093/head
Rob Suderman 2024-04-01 16:34:59 -07:00 committed by GitHub
parent 532d297c46
commit ec4cb8be44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 85 additions and 140 deletions

@ -1 +1 @@
Subproject commit e5ed7b6e2fd368b722b6359556cd0125881e7638 Subproject commit 0030fc4ac74a9ce645adb9d59e108da4d4d11818

2
externals/stablehlo vendored

@ -1 +1 @@
Subproject commit 4ac26f8786d491c5d8376e6e563d1b72af09de75 Subproject commit 271e8634de184fbfafd677d3876170feb6d08c97

View File

@ -388,24 +388,12 @@ STABLEHLO_PASS_SET = {
"ArangeStartNegativeStepIntModule_basic", "ArangeStartNegativeStepIntModule_basic",
"ArangeStartOutDtypeModule_basic", "ArangeStartOutDtypeModule_basic",
"ArangeStartOutModule_basic", "ArangeStartOutModule_basic",
"ArangeStartOutViewModule_basic",
"ArangeStartStepFloatModule_basic", "ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic", "ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic", "ArangeZeroElementOutputModule_basic",
"ArgmaxModule_with_dim", "ArgmaxModule_with_dim",
"AtenComplex64Module_basic", "AtenComplex64Module_basic",
"AtenEyeMModuleCPUDevice_basic",
"AtenEyeMModuleDefaultDtype_basic",
"AtenEyeMModuleFalsePinMemory_basic",
"AtenEyeMModuleFloat2D_basic",
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleCPUDevice_basic",
"AtenEyeModuleDefaultDtype_basic",
"AtenEyeModuleFalsePinMemory_basic",
"AtenEyeModuleFloat2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenFloatScalarModule_basic", "AtenFloatScalarModule_basic",
"AtenInstanceNormModule_basic",
"AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpConstTrueModule_basic",
"AtenIntBoolOpModule_basic", "AtenIntBoolOpModule_basic",
@ -437,8 +425,6 @@ STABLEHLO_PASS_SET = {
"BroadcastListConstructWithMinusOneModule_basic", "BroadcastListConstructWithMinusOneModule_basic",
"BroadcastToSameRankStaticModule_basic", "BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CeilFloatModule_basic", "CeilFloatModule_basic",
"ChunkListUnpackUneven_Module_basic", "ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic", "ChunkListUnpack_Module_basic",
@ -454,7 +440,6 @@ STABLEHLO_PASS_SET = {
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Convolution2DStaticModule_basic", "Convolution2DStaticModule_basic",
"ConvolutionBackwardModule2DStatic_basic", "ConvolutionBackwardModule2DStatic_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionModule2DTransposeStridedStatic_basic", "ConvolutionModule2DTransposeStridedStatic_basic",
"CosineSimilarityStaticBroadcastModule_basic", "CosineSimilarityStaticBroadcastModule_basic",
"CosineSimilarityStaticModule_basic", "CosineSimilarityStaticModule_basic",
@ -466,12 +451,6 @@ STABLEHLO_PASS_SET = {
"DivIntModule_basic", "DivIntModule_basic",
"DropoutEvalFloatModule_basic", "DropoutEvalFloatModule_basic",
"DropoutEvalIntModule_basic", "DropoutEvalIntModule_basic",
"DropoutTrainStaticShapeModule_basic",
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
"ElementwiseAbsFloatModule_basic", "ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic", "ElementwiseAbsIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@ -504,8 +483,8 @@ STABLEHLO_PASS_SET = {
"ElementwiseExpModule_basic", "ElementwiseExpModule_basic",
"ElementwiseFloorIntModule_basic", "ElementwiseFloorIntModule_basic",
"ElementwiseFloorModule_basic", "ElementwiseFloorModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseGeluApproximateTanhModule_basic", "ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseLeakyReluStaticModule_basic", "ElementwiseLeakyReluStaticModule_basic",
"ElementwiseLogModule_basic", "ElementwiseLogModule_basic",
"ElementwiseNanToNumModule_Basic", "ElementwiseNanToNumModule_Basic",
@ -513,9 +492,9 @@ STABLEHLO_PASS_SET = {
"ElementwiseNeIntTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic",
"ElementwiseNegModule_basic", "ElementwiseNegModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseOrTensorStaticShapeModule_basic",
"ElementwisePreluStaticModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic", "ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorStaticModule_basic", "ElementwisePowTensorStaticModule_basic",
"ElementwisePreluStaticModule_basic",
"ElementwiseReciprocalModule_basic", "ElementwiseReciprocalModule_basic",
"ElementwiseReluModule_basic", "ElementwiseReluModule_basic",
"ElementwiseRsqrtModule_basic", "ElementwiseRsqrtModule_basic",
@ -526,8 +505,6 @@ STABLEHLO_PASS_SET = {
"ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeI64ToI8Module_basic",
"ElementwiseToDtypeIdentityModule_basic", "ElementwiseToDtypeIdentityModule_basic",
"ElementwiseUnaryModule_basic", "ElementwiseUnaryModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"EmptyLikeMemoryFormatModule_basic", "EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype", "EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory", "EmptyLikeModule_falsePinMemory",
@ -541,13 +518,14 @@ STABLEHLO_PASS_SET = {
"EmptyStridedModule_basic", "EmptyStridedModule_basic",
"EqIntModule_basic", "EqIntModule_basic",
"ExpandAsIntModule_basic", "ExpandAsIntModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithFloat32Static_basic",
"Fill_TensorFloat64WithFloat32_basic", "Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic", "Fill_TensorFloat64WithFloat64_basic",
"Fill_TensorFloat64WithInt64Static_basic", "Fill_TensorFloat64WithInt64Static_basic",
"Fill_TensorFloat64WithInt64_basic", "Fill_TensorFloat64WithInt64_basic",
"FlattenRank0Module_basic", "FlattenRank0Module_basic",
"FlattenStaticModule_basic",
"FlipModuleStaticShape_basic", "FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic", "FlipNegativeIndexModule_basic",
"FullLikeModuleDefaultDtype_basic", "FullLikeModuleDefaultDtype_basic",
@ -564,29 +542,26 @@ STABLEHLO_PASS_SET = {
"FullModuleFloat3D_basic", "FullModuleFloat3D_basic",
"FullModuleInt2D_basic", "FullModuleInt2D_basic",
"FullModuleInt3D_basic", "FullModuleInt3D_basic",
"GatherStaticModule_basic",
"GeFloatIntModule_basic", "GeFloatIntModule_basic",
"GeFloatModule_basic", "GeFloatModule_basic",
"GeIntModule_basic", "GeIntModule_basic",
"GeluBackwardModule_basic", "GeluBackwardModule_basic",
"GluStaticModule_basic", "GluStaticModule_basic",
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
"GtFloatIntModule_basic", "GtFloatIntModule_basic",
"GtIntModule_basic", "GtIntModule_basic",
"IndexTensorMultiIndexStaticModule_basic", "IndexTensorMultiIndexStaticModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticModule_basic", "IndexTensorStaticModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"IntFloatModule_basic", "IntFloatModule_basic",
"IsFloatingPointFloat_True", "IsFloatingPointFloat_True",
"IsFloatingPointInt_False", "IsFloatingPointInt_False",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"LeakyReluBackwardStaticModule_basic", "LeakyReluBackwardStaticModule_basic",
"LenStrModule_basic", "LenStrModule_basic",
"LiftFreshCopyModule_basic", "LiftFreshCopyModule_basic",
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic", "MaskedFillScalarFloatValueStaticModule_basic",
"MaskedFillScalarIntValueStaticModule_basic", "MaskedFillScalarIntValueStaticModule_basic",
"Matmul4dStatic_basic", "Matmul4dStatic_basic",
@ -595,8 +570,6 @@ STABLEHLO_PASS_SET = {
"Matmul_matvec", "Matmul_matvec",
"Matmul_vecmat", "Matmul_vecmat",
"MaxPool2dStaticModule_basic", "MaxPool2dStaticModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
"MeanDimAllReduceKeepdimModule_basic",
"MeanDimAllReduceModule_basic", "MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic", "MeanDimEmptyDimModule_basic",
"MeanDtypeModule_basic", "MeanDtypeModule_basic",
@ -619,10 +592,6 @@ STABLEHLO_PASS_SET = {
"NarrowVerticalTest2_basic", "NarrowVerticalTest2_basic",
"NarrowVerticalTest_basic", "NarrowVerticalTest_basic",
"NativeDropoutEvalFloatModule_basic", "NativeDropoutEvalFloatModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"NativeGroupNormModule_basic",
"NativeLayerNormModule4D_basic",
"NativeLayerNormModule_basic",
"NeFloatIntModule_basic", "NeFloatIntModule_basic",
"NeIntModule_basic", "NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic", "NewEmptyModuleDefaultDtype_basic",
@ -654,7 +623,6 @@ STABLEHLO_PASS_SET = {
"NewZerosModuleInt2D_basic", "NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic", "NewZerosModuleInt3D_basic",
"NewZerosStaticModuleLayoutStrided_basic", "NewZerosStaticModuleLayoutStrided_basic",
"NormalizeModule_basic",
"NumToTensorFloatModule_basic", "NumToTensorFloatModule_basic",
"NumToTensorIntModule_basic", "NumToTensorIntModule_basic",
"NumelModule_basic", "NumelModule_basic",
@ -682,7 +650,6 @@ STABLEHLO_PASS_SET = {
"PrimMinIntModule_basic", "PrimMinIntModule_basic",
"PrimsConvertElementTypeModule_basic", "PrimsConvertElementTypeModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic", "PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic", "PrimsViewOfZeroRankModule_basic",
"RandIntDtypeModule_basic", "RandIntDtypeModule_basic",
@ -690,39 +657,8 @@ STABLEHLO_PASS_SET = {
"RandIntLowModule_basic", "RandIntLowModule_basic",
"RandIntModule_basic", "RandIntModule_basic",
"RandIntPinMemoryModule_basic", "RandIntPinMemoryModule_basic",
"RandModule_basic",
"ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic",
"ReduceFrobeniusNormModule_basic", "ReduceFrobeniusNormModule_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxAlongDimNegative_basic",
"ReduceMaxAlongDimSignedInt_basic",
"ReduceMaxAlongDim_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceMinFloatModule_basic",
"ReduceMinSignedIntModule_basic",
"ReduceMinUnsignedIntModule_basic",
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic", "ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeAsModule_basic",
"ReshapeExpandModule_basic",
"ReturnThreeTensorFloat32_basic", "ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic", "ReturnTwoTensorF32I64_basic",
"RollModule_basic", "RollModule_basic",
@ -734,8 +670,6 @@ STABLEHLO_PASS_SET = {
"ScalarTensorFloat32Module_basic", "ScalarTensorFloat32Module_basic",
"ScalarTensorInt32Module_basic", "ScalarTensorInt32Module_basic",
"ScalarTensorInt64Module_basic", "ScalarTensorInt64Module_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SelectScattertStaticModule_basic",
"SliceModule_basic", "SliceModule_basic",
"SliceNegIdxModule_basic", "SliceNegIdxModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic", "SliceOutOfLowerBoundStartIndexModule_basic",
@ -761,10 +695,8 @@ STABLEHLO_PASS_SET = {
"SqrtIntConstantModule_basic", "SqrtIntConstantModule_basic",
"SqrtIntModule_basic", "SqrtIntModule_basic",
"SqueezeDimModule_identity", "SqueezeDimModule_identity",
"SqueezeDimModule_static",
"SqueezeDimModule_unitDim", "SqueezeDimModule_unitDim",
"SqueezeModule_allUnitDim", "SqueezeModule_allUnitDim",
"SqueezeModule_static",
"SubFloatModule_basic", "SubFloatModule_basic",
"SubIntModule_basic", "SubIntModule_basic",
"TModuleRank0_basic", "TModuleRank0_basic",
@ -784,17 +716,14 @@ STABLEHLO_PASS_SET = {
"TestF16Return_basic", "TestF16Return_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic", "TestMultipleTensorAndPrimitiveTypesReturn_basic",
"TestMultipleTensorReturn_basic", "TestMultipleTensorReturn_basic",
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"ToCopyBoolDTypeStaticModule_basic", "ToCopyBoolDTypeStaticModule_basic",
"ToDtypeBoolLayoutNoneStaticModule_basic", "ToDtypeBoolLayoutNoneStaticModule_basic",
"ToDtypeLayoutCPUModule_basic", "ToDtypeLayoutCPUModule_basic",
"ToDtypeLayoutNoneModule_basic", "ToDtypeLayoutNoneModule_basic",
"ToDtypeLayoutStridedModule_basic", "ToDtypeLayoutStridedModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"TransposeIntModule_basic", "TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic", "TransposeIntNegDimsModule_basic",
"TriuBroadcastModule_basic",
"TriuModule_basic",
"TupleModule_basic", "TupleModule_basic",
"TypeAsDifferentModule_basic", "TypeAsDifferentModule_basic",
"TypeAsSameModule_basic", "TypeAsSameModule_basic",
@ -806,41 +735,8 @@ STABLEHLO_PASS_SET = {
"TypeConversionI1ToI64Module_basic", "TypeConversionI1ToI64Module_basic",
"TypeConversionI32ToI64Module_basic", "TypeConversionI32ToI64Module_basic",
"TypeConversionI64ToI32Module_basic", "TypeConversionI64ToI32Module_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic",
"UniformNoCorrelationModule_basic",
"UniformStaticShapeModule_basic",
"UnsafeView1DFoldModule_basic", "UnsafeView1DFoldModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
"UnsafeViewExpandModule_basic",
"View1DFoldModule_basic", "View1DFoldModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseOnesMiddleModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandDynamicDimModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
"ViewExpandOnesMiddleModule_basic",
"ViewExpandOnesModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic",
"ViewTwoFiveThreeStaticModule_basic",
"ViewTwoToThreeStaticModule_basic",
"ZeroFloat32Module_basic", "ZeroFloat32Module_basic",
"ZeroInt32Module_basic", "ZeroInt32Module_basic",
"ZeroInt64Module_basic", "ZeroInt64Module_basic",
@ -854,18 +750,38 @@ STABLEHLO_PASS_SET = {
"ZerosModuleFloat3D_basic", "ZerosModuleFloat3D_basic",
"ZerosModuleInt2D_basic", "ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic", "ZerosModuleInt3D_basic",
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
} }
STABLEHLO_CRASHING_SET = { STABLEHLO_CRASHING_SET = {
"AtenEmbeddingBagSumExample_basic", "AtenEmbeddingBagSumExample_basic",
# Something is broken with stablehlo.reduce right now as any
# of these tests can randomly fail. Removing until someone can debug:
"ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxAlongDimNegative_basic",
"ReduceMaxAlongDimSignedInt_basic",
"ReduceMaxAlongDim_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceMinFloatModule_basic",
"ReduceMinSignedIntModule_basic",
"ReduceMinUnsignedIntModule_basic",
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
} }
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development
@ -900,7 +816,6 @@ TOSA_PASS_SET = {
"ArgmaxIntModule_multiple_maxs", "ArgmaxIntModule_multiple_maxs",
"ArgmaxModule_basic", "ArgmaxModule_basic",
"ArgmaxModule_keepDim", "ArgmaxModule_keepDim",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic", "AtenComplex64Module_basic",
"AtenEyeMModuleCPUDevice_basic", "AtenEyeMModuleCPUDevice_basic",
"AtenEyeMModuleDefaultDtype_basic", "AtenEyeMModuleDefaultDtype_basic",
@ -1185,7 +1100,6 @@ TOSA_PASS_SET = {
"PrimsSqueezeModule_basic", "PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic", "PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic", "PrimsViewOfZeroRankModule_basic",
"ReduceAmaxKeepDim_basic",
"ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic", "ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimFloatModule_basic", "ReduceSumDimIntListKeepDimFloatModule_basic",
@ -1196,8 +1110,11 @@ TOSA_PASS_SET = {
"ReduceSumUnsignedIntModule_basic", "ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic", "RepeatModule_basic",
"ResNet18StaticModule_basic", "ResNet18StaticModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeAsModule_basic", "ReshapeAsModule_basic",
"ReshapeCollapseModule_basic", "ReshapeCollapseModule_basic",
"ReshapeExpandModule_basic",
"ReturnThreeTensorFloat32_basic", "ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic", "ReturnTwoTensorF32I64_basic",
"RsubFloatModule_basic", "RsubFloatModule_basic",
@ -1211,8 +1128,6 @@ TOSA_PASS_SET = {
"SiluModule_basic", "SiluModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic", "SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceStaticModule_basic", "SliceStaticModule_basic",
"SoftmaxIntModule_basic",
"SoftmaxIntNegDimModule_basic",
"SplitTensorGetItem_Module_basic", "SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic", "SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic", "SplitTensorListUnpackModule_basic",
@ -1260,13 +1175,19 @@ TOSA_PASS_SET = {
"UnflattenIntStaticModule_basic", "UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic", "UnflattenStaticModule_basic",
"UnsafeView1DFoldModule_basic", "UnsafeView1DFoldModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
"UnsafeViewExpandModule_basic", "UnsafeViewExpandModule_basic",
"View1DFoldModule_basic", "View1DFoldModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseInferredDimModule_basic", "ViewCollapseInferredDimModule_basic",
"ViewCollapseOnesMiddleModule_basic", "ViewCollapseOnesMiddleModule_basic",
"ViewDoubleMergeStaticModule_basic", "ViewDoubleMergeStaticModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandCollapseModule_basic", "ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic", "ViewExpandCollapseWithOnesModule_basic",
"ViewExpandDynamicDimModule_basic",
"ViewExpandInferredDimModule_basic", "ViewExpandInferredDimModule_basic",
"ViewExpandModule_basic", "ViewExpandModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic", "ViewExpandOnesBeforeAndAfterModule_basic",
@ -1275,6 +1196,9 @@ TOSA_PASS_SET = {
"ViewExpandOnesModule_basic", "ViewExpandOnesModule_basic",
"ViewFiveTestStaticModule_basic", "ViewFiveTestStaticModule_basic",
"ViewNegativeStaticModule_basic", "ViewNegativeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"ViewNoChangeStaticModule_basic", "ViewNoChangeStaticModule_basic",
"ViewOffsetBackwardTestStaticModule_basic", "ViewOffsetBackwardTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic", "ViewOffsetTestStaticModule_basic",
@ -1287,8 +1211,6 @@ TOSA_PASS_SET = {
"ZerosModuleInt2D_basic", "ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic", "ZerosModuleInt3D_basic",
"_LogSoftmaxModuleStable_basic", "_LogSoftmaxModuleStable_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"LinspaceModule_basic", "LinspaceModule_basic",
"LinspaceOneSizeModule_basic", "LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic", "LinspaceTwoSizeModule_basic",
@ -1310,10 +1232,14 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic", "TorchPrimLoopForLikeTensorArgModule_basic",
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
"ViewSizeDimFollowedByExpandedOnesModule_basic",
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeFromOtherTensor_basic",
}) - { }) - {
### Test failing in make_fx_tosa but not in tosa ### Test failing in make_fx_tosa but not in tosa
"FlattenDynamicModuleCollapseAll_basic",
# Dynamic shape, has extra unsupported broadcast ops # Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d", "Matmul_3d",
"MatmulStaticBroadcast_basic", "MatmulStaticBroadcast_basic",
@ -1343,6 +1269,18 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
# failed to legalize operation 'torch.operator' # failed to legalize operation 'torch.operator'
"ElementwisePreluModule_basic", "ElementwisePreluModule_basic",
"ElementwisePreluStaticModule_basic", "ElementwisePreluStaticModule_basic",
# Shape Related failures
"ReshapeExpandModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
"ViewCollapseModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandDynamicDimModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
} }
LTC_CRASHING_SET = { LTC_CRASHING_SET = {

View File

@ -328,7 +328,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
) )
if batch_dim > 0: if batch_dim > 0:
batch = ",".join(f"d{d}:dense" for d in range(batch_dim)) batch = ",".join(f"d{d}:batch" for d in range(batch_dim))
lvls = f"{batch},{lvls}" lvls = f"{batch},{lvls}"
if dense_dim > 0: if dense_dim > 0:

View File

@ -160,7 +160,6 @@ def sparse_jit(f, *args, **kwargs):
xargs = [] xargs = []
for a in args: for a in args:
if a.layout is torch.sparse_coo: if a.layout is torch.sparse_coo:
xargs.append(a.values().numpy())
# Construct the additional position array required by MLIR with data # Construct the additional position array required by MLIR with data
# array([0, nnz]). # array([0, nnz]).
xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy()) xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy())
@ -168,14 +167,15 @@ def sparse_jit(f, *args, **kwargs):
# MLIR SoA COO representation. # MLIR SoA COO representation.
for idx in a.indices(): for idx in a.indices():
xargs.append(idx.numpy()) xargs.append(idx.numpy())
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
xargs.append(a.values().numpy()) xargs.append(a.values().numpy())
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
xargs.append(a.crow_indices().numpy()) xargs.append(a.crow_indices().numpy())
xargs.append(a.col_indices().numpy()) xargs.append(a.col_indices().numpy())
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
xargs.append(a.values().numpy()) xargs.append(a.values().numpy())
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
xargs.append(a.ccol_indices().numpy()) xargs.append(a.ccol_indices().numpy())
xargs.append(a.row_indices().numpy()) xargs.append(a.row_indices().numpy())
xargs.append(a.values().numpy())
else: else:
xargs.append(a.numpy()) xargs.append(a.numpy())
# Invoke. # Invoke.
@ -302,7 +302,7 @@ def test_sparse_SpMM():
@run @run
# CHECK-LABEL: test_sparse_eltwise # CHECK-LABEL: test_sparse_eltwise
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> # CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main( # CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32> { # CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32> {
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32> # CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32>
@ -331,6 +331,12 @@ def test_sparse_SpMM():
# CHECK: [-61. -62.] # CHECK: [-61. -62.]
# CHECK: [-63. -64.]{{\]\]}} # CHECK: [-63. -64.]{{\]\]}}
# #
# CHECK: torch.mlir.batch
# CHECK: {{\[\[}}[ -1. -2.]
# CHECK: [ -3. -4.]
# ...
# CHECK: [-61. -62.]
# CHECK: [-63. -64.]{{\]\]}}
def test_sparse_eltwise(): def test_sparse_eltwise():
class EltNet(torch.nn.Module): class EltNet(torch.nn.Module):
def __init__(self): def __init__(self):
@ -345,8 +351,8 @@ def test_sparse_eltwise():
) )
# This yields a **batched** CSR. # This yields a **batched** CSR.
sparse_input = dense_input.to_sparse_csr(dense_dim=0) batch_input = dense_input.to_sparse_csr(dense_dim=0)
m = export_and_import(net, sparse_input) m = export_and_import(net, batch_input)
print(m) print(m)
# This yields a plain CSR with dense **sub**tensor # This yields a plain CSR with dense **sub**tensor
@ -358,11 +364,12 @@ def test_sparse_eltwise():
# #
# TODO: note several issues that need to be fixed # TODO: note several issues that need to be fixed
# (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result # (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result
# (2) for dense_dim=0, this will need a dense(batched) property
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
res1 = net(sparse_input) res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input) res2 = sparse_jit(net, sparse_input)
res3 = sparse_jit(net, batch_input)
print("torch.sparse") print("torch.sparse")
print(res1) print(res1)
print("torch.mlir") print("torch.mlir")
print(res2) print(res2)
print("torch.mlir.batch")
print(res3)