Co-authored-by: Peiming Liu <peiming@google.com>
pull/3093/head
Rob Suderman 2024-04-01 16:34:59 -07:00 committed by GitHub
parent 532d297c46
commit ec4cb8be44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 85 additions and 140 deletions

@ -1 +1 @@
Subproject commit e5ed7b6e2fd368b722b6359556cd0125881e7638
Subproject commit 0030fc4ac74a9ce645adb9d59e108da4d4d11818

2
externals/stablehlo vendored

@ -1 +1 @@
Subproject commit 4ac26f8786d491c5d8376e6e563d1b72af09de75
Subproject commit 271e8634de184fbfafd677d3876170feb6d08c97

View File

@ -388,24 +388,12 @@ STABLEHLO_PASS_SET = {
"ArangeStartNegativeStepIntModule_basic",
"ArangeStartOutDtypeModule_basic",
"ArangeStartOutModule_basic",
"ArangeStartOutViewModule_basic",
"ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic",
"AtenEyeMModuleCPUDevice_basic",
"AtenEyeMModuleDefaultDtype_basic",
"AtenEyeMModuleFalsePinMemory_basic",
"AtenEyeMModuleFloat2D_basic",
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleCPUDevice_basic",
"AtenEyeModuleDefaultDtype_basic",
"AtenEyeModuleFalsePinMemory_basic",
"AtenEyeModuleFloat2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenFloatScalarModule_basic",
"AtenInstanceNormModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
"AtenIntBoolOpModule_basic",
@ -437,8 +425,6 @@ STABLEHLO_PASS_SET = {
"BroadcastListConstructWithMinusOneModule_basic",
"BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CeilFloatModule_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
@ -454,7 +440,6 @@ STABLEHLO_PASS_SET = {
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Convolution2DStaticModule_basic",
"ConvolutionBackwardModule2DStatic_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"CosineSimilarityStaticModule_basic",
@ -466,12 +451,6 @@ STABLEHLO_PASS_SET = {
"DivIntModule_basic",
"DropoutEvalFloatModule_basic",
"DropoutEvalIntModule_basic",
"DropoutTrainStaticShapeModule_basic",
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@ -504,8 +483,8 @@ STABLEHLO_PASS_SET = {
"ElementwiseExpModule_basic",
"ElementwiseFloorIntModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseGeluApproximateTanhModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseLeakyReluStaticModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseNanToNumModule_Basic",
@ -513,9 +492,9 @@ STABLEHLO_PASS_SET = {
"ElementwiseNeIntTensorStaticModule_basic",
"ElementwiseNegModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwisePreluStaticModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwisePreluStaticModule_basic",
"ElementwiseReciprocalModule_basic",
"ElementwiseReluModule_basic",
"ElementwiseRsqrtModule_basic",
@ -526,8 +505,6 @@ STABLEHLO_PASS_SET = {
"ElementwiseToDtypeI64ToI8Module_basic",
"ElementwiseToDtypeIdentityModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory",
@ -541,13 +518,14 @@ STABLEHLO_PASS_SET = {
"EmptyStridedModule_basic",
"EqIntModule_basic",
"ExpandAsIntModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"Fill_TensorFloat64WithFloat32Static_basic",
"Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic",
"Fill_TensorFloat64WithInt64Static_basic",
"Fill_TensorFloat64WithInt64_basic",
"FlattenRank0Module_basic",
"FlattenStaticModule_basic",
"FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic",
"FullLikeModuleDefaultDtype_basic",
@ -564,29 +542,26 @@ STABLEHLO_PASS_SET = {
"FullModuleFloat3D_basic",
"FullModuleInt2D_basic",
"FullModuleInt3D_basic",
"GatherStaticModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
"GeluBackwardModule_basic",
"GluStaticModule_basic",
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"IntFloatModule_basic",
"IsFloatingPointFloat_True",
"IsFloatingPointInt_False",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"LeakyReluBackwardStaticModule_basic",
"LenStrModule_basic",
"LiftFreshCopyModule_basic",
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"MaskedFillScalarIntValueStaticModule_basic",
"Matmul4dStatic_basic",
@ -595,8 +570,6 @@ STABLEHLO_PASS_SET = {
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool2dStaticModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
"MeanDimAllReduceKeepdimModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic",
"MeanDtypeModule_basic",
@ -619,10 +592,6 @@ STABLEHLO_PASS_SET = {
"NarrowVerticalTest2_basic",
"NarrowVerticalTest_basic",
"NativeDropoutEvalFloatModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"NativeGroupNormModule_basic",
"NativeLayerNormModule4D_basic",
"NativeLayerNormModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic",
@ -654,7 +623,6 @@ STABLEHLO_PASS_SET = {
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NewZerosStaticModuleLayoutStrided_basic",
"NormalizeModule_basic",
"NumToTensorFloatModule_basic",
"NumToTensorIntModule_basic",
"NumelModule_basic",
@ -682,7 +650,6 @@ STABLEHLO_PASS_SET = {
"PrimMinIntModule_basic",
"PrimsConvertElementTypeModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"RandIntDtypeModule_basic",
@ -690,39 +657,8 @@ STABLEHLO_PASS_SET = {
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RandModule_basic",
"ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic",
"ReduceFrobeniusNormModule_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxAlongDimNegative_basic",
"ReduceMaxAlongDimSignedInt_basic",
"ReduceMaxAlongDim_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceMinFloatModule_basic",
"ReduceMinSignedIntModule_basic",
"ReduceMinUnsignedIntModule_basic",
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeAsModule_basic",
"ReshapeExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"RollModule_basic",
@ -734,8 +670,6 @@ STABLEHLO_PASS_SET = {
"ScalarTensorFloat32Module_basic",
"ScalarTensorInt32Module_basic",
"ScalarTensorInt64Module_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SelectScattertStaticModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
@ -761,10 +695,8 @@ STABLEHLO_PASS_SET = {
"SqrtIntConstantModule_basic",
"SqrtIntModule_basic",
"SqueezeDimModule_identity",
"SqueezeDimModule_static",
"SqueezeDimModule_unitDim",
"SqueezeModule_allUnitDim",
"SqueezeModule_static",
"SubFloatModule_basic",
"SubIntModule_basic",
"TModuleRank0_basic",
@ -784,17 +716,14 @@ STABLEHLO_PASS_SET = {
"TestF16Return_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"TestMultipleTensorReturn_basic",
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"ToCopyBoolDTypeStaticModule_basic",
"ToDtypeBoolLayoutNoneStaticModule_basic",
"ToDtypeLayoutCPUModule_basic",
"ToDtypeLayoutNoneModule_basic",
"ToDtypeLayoutStridedModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic",
"TriuBroadcastModule_basic",
"TriuModule_basic",
"TupleModule_basic",
"TypeAsDifferentModule_basic",
"TypeAsSameModule_basic",
@ -806,41 +735,8 @@ STABLEHLO_PASS_SET = {
"TypeConversionI1ToI64Module_basic",
"TypeConversionI32ToI64Module_basic",
"TypeConversionI64ToI32Module_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic",
"UniformNoCorrelationModule_basic",
"UniformStaticShapeModule_basic",
"UnsafeView1DFoldModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
"UnsafeViewExpandModule_basic",
"View1DFoldModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseOnesMiddleModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandDynamicDimModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
"ViewExpandOnesMiddleModule_basic",
"ViewExpandOnesModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic",
"ViewTwoFiveThreeStaticModule_basic",
"ViewTwoToThreeStaticModule_basic",
"ZeroFloat32Module_basic",
"ZeroInt32Module_basic",
"ZeroInt64Module_basic",
@ -854,18 +750,38 @@ STABLEHLO_PASS_SET = {
"ZerosModuleFloat3D_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"LinspaceDtypeModule_basic",
"LinspaceEmptyModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
}
STABLEHLO_CRASHING_SET = {
"AtenEmbeddingBagSumExample_basic",
# Something is broken with stablehlo.reduce right now as any
# of these tests can randomly fail. Removing until someone can debug:
"ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxAlongDimNegative_basic",
"ReduceMaxAlongDimSignedInt_basic",
"ReduceMaxAlongDim_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceMinFloatModule_basic",
"ReduceMinSignedIntModule_basic",
"ReduceMinUnsignedIntModule_basic",
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
}
# Write the TOSA set as a "passing" set as it is very early in development
@ -900,7 +816,6 @@ TOSA_PASS_SET = {
"ArgmaxIntModule_multiple_maxs",
"ArgmaxModule_basic",
"ArgmaxModule_keepDim",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic",
"AtenEyeMModuleCPUDevice_basic",
"AtenEyeMModuleDefaultDtype_basic",
@ -1185,7 +1100,6 @@ TOSA_PASS_SET = {
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"ReduceAmaxKeepDim_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimFloatModule_basic",
@ -1196,8 +1110,11 @@ TOSA_PASS_SET = {
"ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"ResNet18StaticModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeAsModule_basic",
"ReshapeCollapseModule_basic",
"ReshapeExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"RsubFloatModule_basic",
@ -1211,8 +1128,6 @@ TOSA_PASS_SET = {
"SiluModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceStaticModule_basic",
"SoftmaxIntModule_basic",
"SoftmaxIntNegDimModule_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
@ -1260,13 +1175,19 @@ TOSA_PASS_SET = {
"UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic",
"UnsafeView1DFoldModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
"UnsafeViewExpandModule_basic",
"View1DFoldModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewCollapseOnesMiddleModule_basic",
"ViewDoubleMergeStaticModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandDynamicDimModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
@ -1275,6 +1196,9 @@ TOSA_PASS_SET = {
"ViewExpandOnesModule_basic",
"ViewFiveTestStaticModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic",
@ -1287,8 +1211,6 @@ TOSA_PASS_SET = {
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"_LogSoftmaxModuleStable_basic",
"_LogSoftmaxModule_basic",
"_SoftmaxModule_basic",
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
@ -1310,10 +1232,14 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
"ViewSizeDimFollowedByExpandedOnesModule_basic",
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeFromOtherTensor_basic",
}) - {
### Test failing in make_fx_tosa but not in tosa
"FlattenDynamicModuleCollapseAll_basic",
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",
"MatmulStaticBroadcast_basic",
@ -1343,6 +1269,18 @@ MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
# failed to legalize operation 'torch.operator'
"ElementwisePreluModule_basic",
"ElementwisePreluStaticModule_basic",
# Shape Related failures
"ReshapeExpandModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
"ViewCollapseModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandDynamicDimModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
}
LTC_CRASHING_SET = {

View File

@ -328,7 +328,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
)
if batch_dim > 0:
batch = ",".join(f"d{d}:dense" for d in range(batch_dim))
batch = ",".join(f"d{d}:batch" for d in range(batch_dim))
lvls = f"{batch},{lvls}"
if dense_dim > 0:

View File

@ -160,7 +160,6 @@ def sparse_jit(f, *args, **kwargs):
xargs = []
for a in args:
if a.layout is torch.sparse_coo:
xargs.append(a.values().numpy())
# Construct the additional position array required by MLIR with data
# array([0, nnz]).
xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy())
@ -168,14 +167,15 @@ def sparse_jit(f, *args, **kwargs):
# MLIR SoA COO representation.
for idx in a.indices():
xargs.append(idx.numpy())
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
xargs.append(a.values().numpy())
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
xargs.append(a.crow_indices().numpy())
xargs.append(a.col_indices().numpy())
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
xargs.append(a.values().numpy())
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
xargs.append(a.ccol_indices().numpy())
xargs.append(a.row_indices().numpy())
xargs.append(a.values().numpy())
else:
xargs.append(a.numpy())
# Invoke.
@ -302,7 +302,7 @@ def test_sparse_SpMM():
@run
# CHECK-LABEL: test_sparse_eltwise
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }>
# CHECK: func.func @main(
# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32> {
# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32>
@ -331,6 +331,12 @@ def test_sparse_SpMM():
# CHECK: [-61. -62.]
# CHECK: [-63. -64.]{{\]\]}}
#
# CHECK: torch.mlir.batch
# CHECK: {{\[\[}}[ -1. -2.]
# CHECK: [ -3. -4.]
# ...
# CHECK: [-61. -62.]
# CHECK: [-63. -64.]{{\]\]}}
def test_sparse_eltwise():
class EltNet(torch.nn.Module):
def __init__(self):
@ -345,8 +351,8 @@ def test_sparse_eltwise():
)
# This yields a **batched** CSR.
sparse_input = dense_input.to_sparse_csr(dense_dim=0)
m = export_and_import(net, sparse_input)
batch_input = dense_input.to_sparse_csr(dense_dim=0)
m = export_and_import(net, batch_input)
print(m)
# This yields a plain CSR with dense **sub**tensor
@ -358,11 +364,12 @@ def test_sparse_eltwise():
#
# TODO: note several issues that need to be fixed
# (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result
# (2) for dense_dim=0, this will need a dense(batched) property
sparse_input = dense_input.to_sparse_csr(dense_dim=1)
res1 = net(sparse_input)
res2 = sparse_jit(net, sparse_input)
res3 = sparse_jit(net, batch_input)
print("torch.sparse")
print(res1)
print("torch.mlir")
print(res2)
print("torch.mlir.batch")
print(res3)