diff --git a/externals/llvm-project b/externals/llvm-project index e5ed7b6e2..0030fc4ac 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit e5ed7b6e2fd368b722b6359556cd0125881e7638 +Subproject commit 0030fc4ac74a9ce645adb9d59e108da4d4d11818 diff --git a/externals/stablehlo b/externals/stablehlo index 4ac26f878..271e8634d 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit 4ac26f8786d491c5d8376e6e563d1b72af09de75 +Subproject commit 271e8634de184fbfafd677d3876170feb6d08c97 diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2e4dc0d09..6dba81a64 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -388,24 +388,12 @@ STABLEHLO_PASS_SET = { "ArangeStartNegativeStepIntModule_basic", "ArangeStartOutDtypeModule_basic", "ArangeStartOutModule_basic", - "ArangeStartOutViewModule_basic", "ArangeStartStepFloatModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", "ArgmaxModule_with_dim", "AtenComplex64Module_basic", - "AtenEyeMModuleCPUDevice_basic", - "AtenEyeMModuleDefaultDtype_basic", - "AtenEyeMModuleFalsePinMemory_basic", - "AtenEyeMModuleFloat2D_basic", - "AtenEyeMModuleInt2D_basic", - "AtenEyeModuleCPUDevice_basic", - "AtenEyeModuleDefaultDtype_basic", - "AtenEyeModuleFalsePinMemory_basic", - "AtenEyeModuleFloat2D_basic", - "AtenEyeModuleInt2D_basic", "AtenFloatScalarModule_basic", - "AtenInstanceNormModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpModule_basic", @@ -437,8 +425,6 @@ STABLEHLO_PASS_SET = { "BroadcastListConstructWithMinusOneModule_basic", "BroadcastToSameRankStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", "CeilFloatModule_basic", "ChunkListUnpackUneven_Module_basic", "ChunkListUnpack_Module_basic", @@ -454,7 +440,6 @@ STABLEHLO_PASS_SET = { "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Convolution2DStaticModule_basic", "ConvolutionBackwardModule2DStatic_basic", - "ConvolutionBackwardModule2DStrided_basic", "ConvolutionModule2DTransposeStridedStatic_basic", "CosineSimilarityStaticBroadcastModule_basic", "CosineSimilarityStaticModule_basic", @@ -466,12 +451,6 @@ STABLEHLO_PASS_SET = { "DivIntModule_basic", "DropoutEvalFloatModule_basic", "DropoutEvalIntModule_basic", - "DropoutTrainStaticShapeModule_basic", - "EinsumStaticContractRhsModule_basic", - "EinsumStaticFourDimensionModule_basic", - "EinsumStaticModule_basic", - "EinsumStaticWithEllipsisSlicingModule_basic", - "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic", "ElementwiseAbsFloatModule_basic", "ElementwiseAbsIntModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", @@ -504,8 +483,8 @@ STABLEHLO_PASS_SET = { "ElementwiseExpModule_basic", "ElementwiseFloorIntModule_basic", "ElementwiseFloorModule_basic", - "ElementwiseGeluModule_basic", "ElementwiseGeluApproximateTanhModule_basic", + "ElementwiseGeluModule_basic", "ElementwiseLeakyReluStaticModule_basic", "ElementwiseLogModule_basic", "ElementwiseNanToNumModule_Basic", @@ -513,9 +492,9 @@ STABLEHLO_PASS_SET = { "ElementwiseNeIntTensorStaticModule_basic", "ElementwiseNegModule_basic", "ElementwiseOrTensorStaticShapeModule_basic", - "ElementwisePreluStaticModule_basic", "ElementwisePowTensorBroadcastStaticModule_basic", "ElementwisePowTensorStaticModule_basic", + "ElementwisePreluStaticModule_basic", "ElementwiseReciprocalModule_basic", "ElementwiseReluModule_basic", "ElementwiseRsqrtModule_basic", @@ -526,8 +505,6 @@ STABLEHLO_PASS_SET = { "ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeIdentityModule_basic", "ElementwiseUnaryModule_basic", - "ElementwiseWhereScalarOtherStaticModule_basic", - "ElementwiseWhereScalarSelfStaticModule_basic", "EmptyLikeMemoryFormatModule_basic", "EmptyLikeModule_defaultDtype", "EmptyLikeModule_falsePinMemory", @@ -541,13 +518,14 @@ STABLEHLO_PASS_SET = { "EmptyStridedModule_basic", "EqIntModule_basic", "ExpandAsIntModule_basic", + "FakeQuantizePerTensorAffineModule_basic", + "FakeQuantizePerTensorAffineRoundToEvenModule_basic", "Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithFloat32_basic", "Fill_TensorFloat64WithFloat64_basic", "Fill_TensorFloat64WithInt64Static_basic", "Fill_TensorFloat64WithInt64_basic", "FlattenRank0Module_basic", - "FlattenStaticModule_basic", "FlipModuleStaticShape_basic", "FlipNegativeIndexModule_basic", "FullLikeModuleDefaultDtype_basic", @@ -564,29 +542,26 @@ STABLEHLO_PASS_SET = { "FullModuleFloat3D_basic", "FullModuleInt2D_basic", "FullModuleInt3D_basic", - "GatherStaticModule_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GeIntModule_basic", "GeluBackwardModule_basic", "GluStaticModule_basic", - "GroupNormModule_basic", - "GroupNormNoWeightAndBiasModule_basic", "GtFloatIntModule_basic", "GtIntModule_basic", "IndexTensorMultiIndexStaticModule_basic", - "IndexTensorStaticContiguousWithNoneModule_basic", "IndexTensorStaticModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", "IntFloatModule_basic", "IsFloatingPointFloat_True", "IsFloatingPointInt_False", - "LayerNormLastDimModule_basic", - "LayerNormModule_basic", - "LayerNormNormalizeOverAllDimsModule_basic", "LeakyReluBackwardStaticModule_basic", "LenStrModule_basic", "LiftFreshCopyModule_basic", + "LinspaceDtypeModule_basic", + "LinspaceEmptyModule_basic", + "LinspaceModule_basic", + "LinspaceOneSizeModule_basic", + "LinspaceTwoSizeModule_basic", "MaskedFillScalarFloatValueStaticModule_basic", "MaskedFillScalarIntValueStaticModule_basic", "Matmul4dStatic_basic", @@ -595,8 +570,6 @@ STABLEHLO_PASS_SET = { "Matmul_matvec", "Matmul_vecmat", "MaxPool2dStaticModule_basic", - "MaxPool2dWithIndicesStaticModule_basic", - "MeanDimAllReduceKeepdimModule_basic", "MeanDimAllReduceModule_basic", "MeanDimEmptyDimModule_basic", "MeanDtypeModule_basic", @@ -619,10 +592,6 @@ STABLEHLO_PASS_SET = { "NarrowVerticalTest2_basic", "NarrowVerticalTest_basic", "NativeDropoutEvalFloatModule_basic", - "NativeDropoutTrainStaticShapeModule_basic", - "NativeGroupNormModule_basic", - "NativeLayerNormModule4D_basic", - "NativeLayerNormModule_basic", "NeFloatIntModule_basic", "NeIntModule_basic", "NewEmptyModuleDefaultDtype_basic", @@ -654,7 +623,6 @@ STABLEHLO_PASS_SET = { "NewZerosModuleInt2D_basic", "NewZerosModuleInt3D_basic", "NewZerosStaticModuleLayoutStrided_basic", - "NormalizeModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", "NumelModule_basic", @@ -682,7 +650,6 @@ STABLEHLO_PASS_SET = { "PrimMinIntModule_basic", "PrimsConvertElementTypeModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", - "PrimsSqueezeModule_basic", "PrimsViewOfModule_basic", "PrimsViewOfZeroRankModule_basic", "RandIntDtypeModule_basic", @@ -690,39 +657,8 @@ STABLEHLO_PASS_SET = { "RandIntLowModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", - "RandModule_basic", - "ReduceAmaxMultiDim_basic", - "ReduceAmaxOutOfOrderDim_basic", - "ReduceAmaxSingleDim_basic", "ReduceFrobeniusNormModule_basic", - "ReduceMaxAllDims_basic", - "ReduceMaxAlongDimNegative_basic", - "ReduceMaxAlongDimSignedInt_basic", - "ReduceMaxAlongDim_basic", - "ReduceMaxFloatModule_basic", - "ReduceMaxSignedIntModule_basic", - "ReduceMaxUnsignedIntModule_basic", - "ReduceMinFloatModule_basic", - "ReduceMinSignedIntModule_basic", - "ReduceMinUnsignedIntModule_basic", - "ReduceSumDimIntListDtypeFloatModule_basic", - "ReduceSumDimIntListDtypeIntModule_basic", - "ReduceSumDimIntListElementTypeBoolModule_basic", - "ReduceSumDimIntListEmptyDimModule_basic", - "ReduceSumDimIntListFloatModule_basic", - "ReduceSumDimIntListIntModule_basic", - "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", - "ReduceSumDtypeFloatModule_basic", - "ReduceSumDtypeIntModule_basic", - "ReduceSumElementTypeBoolModule_basic", - "ReduceSumFloatModule_basic", - "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", - "RepeatModule_basic", - "ReshapeAliasCollapseModule_basic", - "ReshapeAliasExpandModule_basic", - "ReshapeAsModule_basic", - "ReshapeExpandModule_basic", "ReturnThreeTensorFloat32_basic", "ReturnTwoTensorF32I64_basic", "RollModule_basic", @@ -734,8 +670,6 @@ STABLEHLO_PASS_SET = { "ScalarTensorFloat32Module_basic", "ScalarTensorInt32Module_basic", "ScalarTensorInt64Module_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", - "SelectScattertStaticModule_basic", "SliceModule_basic", "SliceNegIdxModule_basic", "SliceOutOfLowerBoundStartIndexModule_basic", @@ -761,10 +695,8 @@ STABLEHLO_PASS_SET = { "SqrtIntConstantModule_basic", "SqrtIntModule_basic", "SqueezeDimModule_identity", - "SqueezeDimModule_static", "SqueezeDimModule_unitDim", "SqueezeModule_allUnitDim", - "SqueezeModule_static", "SubFloatModule_basic", "SubIntModule_basic", "TModuleRank0_basic", @@ -784,17 +716,14 @@ STABLEHLO_PASS_SET = { "TestF16Return_basic", "TestMultipleTensorAndPrimitiveTypesReturn_basic", "TestMultipleTensorReturn_basic", - "TileBigDimsSizeModule_basic", - "TileSmallDimsSizeModule_basic", "ToCopyBoolDTypeStaticModule_basic", "ToDtypeBoolLayoutNoneStaticModule_basic", "ToDtypeLayoutCPUModule_basic", "ToDtypeLayoutNoneModule_basic", "ToDtypeLayoutStridedModule_basic", + "TorchPrimLoopForLikeTensorArgModule_basic", "TransposeIntModule_basic", "TransposeIntNegDimsModule_basic", - "TriuBroadcastModule_basic", - "TriuModule_basic", "TupleModule_basic", "TypeAsDifferentModule_basic", "TypeAsSameModule_basic", @@ -806,41 +735,8 @@ STABLEHLO_PASS_SET = { "TypeConversionI1ToI64Module_basic", "TypeConversionI32ToI64Module_basic", "TypeConversionI64ToI32Module_basic", - "UnbindIntGetItem_Module_basic", - "UnbindIntListUnpack_Module_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenStaticModule_basic", - "UniformNoCorrelationModule_basic", - "UniformStaticShapeModule_basic", "UnsafeView1DFoldModule_basic", - "UnsafeViewCollapseModule_basic", - "UnsafeViewDynamicExpandModule_basic", - "UnsafeViewExpandModule_basic", "View1DFoldModule_basic", - "ViewCollapseInferredDimModule_basic", - "ViewCollapseModule_basic", - "ViewCollapseOnesMiddleModule_basic", - "ViewDynamicExpandCollapseModule_basic", - "ViewDynamicExpandModule_basic", - "ViewExpandCollapseModule_basic", - "ViewExpandCollapseWithOnesModule_basic", - "ViewExpandDynamicDimModule_basic", - "ViewExpandInferredDimModule_basic", - "ViewExpandModule_basic", - "ViewExpandOnesBeforeAndAfterModule_basic", - "ViewExpandOnesMiddleModule_basic", - "ViewExpandOnesModule_basic", - "ViewNegativeStaticModule_basic", - "ViewNoChange1dModule_basic", - "ViewNoChange2dModule_basic", - "ViewNoChange3dModule_basic", - "ViewNoChangeStaticModule_basic", - "ViewOffsetBackwardTestStaticModule_basic", - "ViewOffsetTestStaticModule_basic", - "ViewTwoFiveThreeStaticModule_basic", - "ViewTwoToThreeStaticModule_basic", "ZeroFloat32Module_basic", "ZeroInt32Module_basic", "ZeroInt64Module_basic", @@ -854,18 +750,38 @@ STABLEHLO_PASS_SET = { "ZerosModuleFloat3D_basic", "ZerosModuleInt2D_basic", "ZerosModuleInt3D_basic", - "LinspaceDtypeModule_basic", - "LinspaceEmptyModule_basic", - "LinspaceModule_basic", - "LinspaceOneSizeModule_basic", - "LinspaceTwoSizeModule_basic", - "FakeQuantizePerTensorAffineModule_basic", - "FakeQuantizePerTensorAffineRoundToEvenModule_basic", - "TorchPrimLoopForLikeTensorArgModule_basic", } STABLEHLO_CRASHING_SET = { "AtenEmbeddingBagSumExample_basic", + + # Something is broken with stablehlo.reduce right now as any + # of these tests can randomly fail. Removing until someone can debug: + "ReduceAmaxMultiDim_basic", + "ReduceAmaxOutOfOrderDim_basic", + "ReduceAmaxSingleDim_basic", + "ReduceMaxAllDims_basic", + "ReduceMaxAlongDimNegative_basic", + "ReduceMaxAlongDimSignedInt_basic", + "ReduceMaxAlongDim_basic", + "ReduceMaxFloatModule_basic", + "ReduceMaxSignedIntModule_basic", + "ReduceMaxUnsignedIntModule_basic", + "ReduceMinFloatModule_basic", + "ReduceMinSignedIntModule_basic", + "ReduceMinUnsignedIntModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", + "ReduceSumDimIntListFloatModule_basic", + "ReduceSumDimIntListIntModule_basic", + "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", } # Write the TOSA set as a "passing" set as it is very early in development @@ -900,7 +816,6 @@ TOSA_PASS_SET = { "ArgmaxIntModule_multiple_maxs", "ArgmaxModule_basic", "ArgmaxModule_keepDim", - "ArgmaxModule_with_dim", "AtenComplex64Module_basic", "AtenEyeMModuleCPUDevice_basic", "AtenEyeMModuleDefaultDtype_basic", @@ -1185,7 +1100,6 @@ TOSA_PASS_SET = { "PrimsSqueezeModule_basic", "PrimsViewOfModule_basic", "PrimsViewOfZeroRankModule_basic", - "ReduceAmaxKeepDim_basic", "ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListIntModule_basic", "ReduceSumDimIntListKeepDimFloatModule_basic", @@ -1196,8 +1110,11 @@ TOSA_PASS_SET = { "ReduceSumUnsignedIntModule_basic", "RepeatModule_basic", "ResNet18StaticModule_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", "ReshapeAsModule_basic", "ReshapeCollapseModule_basic", + "ReshapeExpandModule_basic", "ReturnThreeTensorFloat32_basic", "ReturnTwoTensorF32I64_basic", "RsubFloatModule_basic", @@ -1211,8 +1128,6 @@ TOSA_PASS_SET = { "SiluModule_basic", "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceStaticModule_basic", - "SoftmaxIntModule_basic", - "SoftmaxIntNegDimModule_basic", "SplitTensorGetItem_Module_basic", "SplitTensorLastSmallerModule_basic", "SplitTensorListUnpackModule_basic", @@ -1260,13 +1175,19 @@ TOSA_PASS_SET = { "UnflattenIntStaticModule_basic", "UnflattenStaticModule_basic", "UnsafeView1DFoldModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", "UnsafeViewExpandModule_basic", "View1DFoldModule_basic", + "ViewCollapseModule_basic", "ViewCollapseInferredDimModule_basic", "ViewCollapseOnesMiddleModule_basic", "ViewDoubleMergeStaticModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandModule_basic", "ViewExpandCollapseModule_basic", "ViewExpandCollapseWithOnesModule_basic", + "ViewExpandDynamicDimModule_basic", "ViewExpandInferredDimModule_basic", "ViewExpandModule_basic", "ViewExpandOnesBeforeAndAfterModule_basic", @@ -1275,6 +1196,9 @@ TOSA_PASS_SET = { "ViewExpandOnesModule_basic", "ViewFiveTestStaticModule_basic", "ViewNegativeStaticModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", "ViewNoChangeStaticModule_basic", "ViewOffsetBackwardTestStaticModule_basic", "ViewOffsetTestStaticModule_basic", @@ -1287,8 +1211,6 @@ TOSA_PASS_SET = { "ZerosModuleInt2D_basic", "ZerosModuleInt3D_basic", "_LogSoftmaxModuleStable_basic", - "_LogSoftmaxModule_basic", - "_SoftmaxModule_basic", "LinspaceModule_basic", "LinspaceOneSizeModule_basic", "LinspaceTwoSizeModule_basic", @@ -1310,10 +1232,14 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "TorchPrimLoopForLikeTensorArgModule_basic", + "ViewSizeDimFollowedByCollapsedOnesModule_basic", + "ViewSizeDimFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic", + "ViewSizeDimLedByCollapsedOnesModule_basic", + "ViewSizeFromOtherTensor_basic", }) - { ### Test failing in make_fx_tosa but not in tosa - "FlattenDynamicModuleCollapseAll_basic", # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", "MatmulStaticBroadcast_basic", @@ -1343,6 +1269,18 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { # failed to legalize operation 'torch.operator' "ElementwisePreluModule_basic", "ElementwisePreluStaticModule_basic", + + # Shape Related failures + "ReshapeExpandModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", + "ViewCollapseModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandModule_basic", + "ViewExpandDynamicDimModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", } LTC_CRASHING_SET = { diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index f622a0b93..23ed415d5 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -328,7 +328,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: ) if batch_dim > 0: - batch = ",".join(f"d{d}:dense" for d in range(batch_dim)) + batch = ",".join(f"d{d}:batch" for d in range(batch_dim)) lvls = f"{batch},{lvls}" if dense_dim > 0: diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 52f10de32..93144daf9 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -160,7 +160,6 @@ def sparse_jit(f, *args, **kwargs): xargs = [] for a in args: if a.layout is torch.sparse_coo: - xargs.append(a.values().numpy()) # Construct the additional position array required by MLIR with data # array([0, nnz]). xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy()) @@ -168,14 +167,15 @@ def sparse_jit(f, *args, **kwargs): # MLIR SoA COO representation. for idx in a.indices(): xargs.append(idx.numpy()) - elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: xargs.append(a.values().numpy()) + elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: xargs.append(a.crow_indices().numpy()) xargs.append(a.col_indices().numpy()) - elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: xargs.append(a.values().numpy()) + elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: xargs.append(a.ccol_indices().numpy()) xargs.append(a.row_indices().numpy()) + xargs.append(a.values().numpy()) else: xargs.append(a.numpy()) # Invoke. @@ -302,7 +302,7 @@ def test_sparse_SpMM(): @run # CHECK-LABEL: test_sparse_eltwise -# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> +# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( # CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32> { # CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32> @@ -331,6 +331,12 @@ def test_sparse_SpMM(): # CHECK: [-61. -62.] # CHECK: [-63. -64.]{{\]\]}} # +# CHECK: torch.mlir.batch +# CHECK: {{\[\[}}[ -1. -2.] +# CHECK: [ -3. -4.] +# ... +# CHECK: [-61. -62.] +# CHECK: [-63. -64.]{{\]\]}} def test_sparse_eltwise(): class EltNet(torch.nn.Module): def __init__(self): @@ -345,8 +351,8 @@ def test_sparse_eltwise(): ) # This yields a **batched** CSR. - sparse_input = dense_input.to_sparse_csr(dense_dim=0) - m = export_and_import(net, sparse_input) + batch_input = dense_input.to_sparse_csr(dense_dim=0) + m = export_and_import(net, batch_input) print(m) # This yields a plain CSR with dense **sub**tensor @@ -358,11 +364,12 @@ def test_sparse_eltwise(): # # TODO: note several issues that need to be fixed # (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result - # (2) for dense_dim=0, this will need a dense(batched) property - sparse_input = dense_input.to_sparse_csr(dense_dim=1) res1 = net(sparse_input) res2 = sparse_jit(net, sparse_input) + res3 = sparse_jit(net, batch_input) print("torch.sparse") print(res1) print("torch.mlir") print(res2) + print("torch.mlir.batch") + print(res3)