diff --git a/build_tools/ci/test_posix.sh b/build_tools/ci/test_posix.sh index 8cc68d77b..73818051d 100755 --- a/build_tools/ci/test_posix.sh +++ b/build_tools/ci/test_posix.sh @@ -20,6 +20,10 @@ echo "::group::Run TOSA e2e integration tests" python -m e2e_testing.main --config=tosa -v echo "::endgroup::" +echo "::group::Run Stablehlo e2e integration tests" +python -m e2e_testing.main --config=stablehlo -v +echo "::endgroup::" + case $torch_version in nightly) # Failing with: NotImplementedError: diff --git a/externals/stablehlo b/externals/stablehlo index e191eb4c3..4ac26f878 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit e191eb4c3c3f3144503a8a117d760de5ddcc7e89 +Subproject commit 4ac26f8786d491c5d8376e6e563d1b72af09de75 diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 0db753e47..e4ba46138 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -31,6 +31,10 @@ set(LinkedLibs TorchMLIRTorchOnnxToTorch ) +if(TORCH_MLIR_ENABLE_STABLEHLO) +list(APPEND LinkedLibs StablehloPasses StablehloLinalgTransforms) +endif() + if(TORCH_MLIR_ENABLE_REFBACKEND) add_subdirectory(RefBackend) list(APPEND LinkedLibs TorchMLIRRefBackend) diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index e413fe532..0b27d0748 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -24,6 +24,9 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include +#include + using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; @@ -116,6 +119,12 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); } + std::vector outputShape(inputShape.begin(), inputShape.end()); + outputShape.erase(outputShape.begin() + dim); + auto outputTy = RankedTensorType::get(outputShape, inputElemTy); + auto outputIndexTy = + RankedTensorType::get(outputShape, rewriter.getIntegerType(64)); + auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); auto indexTensor = rewriter.create( @@ -125,7 +134,8 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, inputShapeTensor, static_cast(dim)); auto stablehloReduceOp = rewriter.create( - op->getLoc(), ValueRange{input, indexTensor}, + op->getLoc(), TypeRange{outputTy, outputIndexTy}, + ValueRange{input, indexTensor}, ValueRange{ initValue, initIndex, @@ -412,7 +422,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input, + initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -473,7 +484,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -535,7 +547,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -614,6 +627,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (int64_t i = 0; i < inputTy.getRank(); i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputTy.getDimSize(i)); + } + } + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); @@ -625,7 +646,9 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op.getLoc(), + RankedTensorType::get(reduceResultShape, outTy.getElementType()), input, + initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -714,6 +737,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( // stable with unordered dims. std::sort(dims.begin(), dims.end()); + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (int64_t i = 0; i < inputRank; i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputType.getDimSize(i)); + } + } + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure( @@ -728,8 +759,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } auto reduceOp = rewriter.create( - op->getLoc(), squareOp.getResult(), initValue, - rewriter.getDenseI64ArrayAttr(dims)); + op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType), + squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -832,6 +863,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( std::sort(dims.begin(), dims.end()); } + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (int64_t i = 0; i < inputType.getRank(); i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputType.getDimSize(i)); + } + } + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure( @@ -848,7 +887,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ord, nullptr); auto reduceOp = rewriter.create( - op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType), + powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index ace6c1a40..eebfc9408 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -29,6 +29,11 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/RefBackend/Passes.h" +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +#include "stablehlo/conversions/linalg/transforms/Passes.h" +#include "stablehlo/transforms/Passes.h" +#endif + void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); @@ -52,6 +57,11 @@ void mlir::torch::registerAllPasses() { mlir::torch::onnx_c::registerTorchOnnxToTorchPasses(); mlir::torch::TMTensor::registerPasses(); +#ifdef TORCH_MLIR_ENABLE_STABLEHLO + mlir::stablehlo::registerChloLegalizeToStablehloPass(); + mlir::stablehlo::registerStablehloLegalizeToLinalgPass(); +#endif + #ifdef TORCH_MLIR_ENABLE_REFBACKEND mlir::torch::RefBackend::registerRefBackendPasses(); #endif diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index 885f34477..b9cd04c1e 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -25,6 +25,7 @@ from torch_mlir_e2e_test.configs import ( from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend +from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend from .xfail_sets import ( LINALG_XFAIL_SET, @@ -43,7 +44,7 @@ from torch_mlir_e2e_test.test_suite import register_all_tests register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] + config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -52,6 +53,7 @@ def _get_argparse(): Meaning of options: "linalg": run through torch-mlir"s default Linalg-on-Tensors backend. "tosa": run through torch-mlir"s default TOSA backend. +"stablehlo": run through torch-mlir"s default Stablehlo backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. @@ -90,6 +92,10 @@ def main(): config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = LINALG_XFAIL_SET crashing_set = set() + elif args.config == "stablehlo": + config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) + xfail_set = all_test_unique_names - STABLEHLO_PASS_SET + crashing_set = STABLEHLO_CRASHING_SET elif args.config == "tosa": config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 3048ac04a..36a1d5662 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -378,84 +378,16 @@ TORCHDYNAMO_CRASHING_SET = { } STABLEHLO_PASS_SET = { - "TileBigDimsSizeModule_basic", - "TileSmallDimsSizeModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddIntModule_basic", - "AtenIntBoolOpModule_basic", - "AtenIntTensorByteDtypeModule_basic", - "AtenIntTensorCharDtypeModule_basic", - "BoolFloatFalseModule_basic", - "BoolFloatTrueModule_basic", - "BoolIntFalseModule_basic", - "BoolIntTrueModule_basic", - "CeilFloatModule_basic", - "DivFloatModule_basic", - "DivIntModule_basic", - "EqIntModule_basic", - "GeFloatIntModule_basic", - "GeFloatModule_basic", - "GeIntModule_basic", - "GtFloatIntModule_basic", - "GtIntModule_basic", - "MulIntModule_basic", - "NeFloatIntModule_basic", - "NeIntModule_basic", - "SqrtIntModule_basic", - "SubFloatModule_basic", - "SubIntModule_basic", - "TensorToBoolZeroRank_basic", - "TensorToIntZeroRank_basic", - "TensorToFloatZeroRank_basic", - "IndexTensorStaticContiguousWithNoneModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", "AliasModule_basic", - "TensorIntModule_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", "AnyBoolFalseModule_basic", "AnyBoolTrueModule_basic", - "AtenIntBoolOpConstFalseModule_basic", - "AtenIntBoolOpConstTrueModule_basic", - "AtenFloatScalarModule_basic", - "ScalarImplicitFloatModule_basic", - "ScalarImplicitIntModule_basic", - "AtenSubFloatModule_basic", - "BoolFloatConstantModule_basic", - "BoolIntConstantModule_basic", - "ContainsIntList_False", - "ContainsIntList_True", - "IntFloatModule_basic", - "IsFloatingPointFloat_True", - "IsFloatingPointInt_False", - "LenStrModule_basic", - "MeanDimAllReduceKeepdimModule_basic", - "MeanDimAllReduceModule_basic", - "MeanDimDtypeModule_basic", - "MeanDimKeepdimModule_basic", - "MeanDimModule_basic", - "MeanDimNegativeModule_basic", - "NumelZeroRankModule_basic", - "PowIntFloatModule_basic", - "PrimMaxIntModule_basic", - "PrimMinIntModule_basic", - "PrimMinIntDynamicModule_basic", - "SortIntListReverse_basic", - "SortIntList_basic", - "SqrtIntConstantModule_basic", - "StdBiasedModule_basic", - "StdDimBiasedModule_basic", - "TestMultipleTensorAndPrimitiveTypesReturn_basic", - "VarBiasedModule_basic", - "VarDimBiasedModule_basic", - "VarMeanBiasedModule_basic", - "VarMeanDimBiasedModule_basic", - "ConstantBoolParameterModule_basic", - "MaskedFillScalarIntValueStaticModule_basic", - "MaskedFillScalarFloatValueStaticModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AddSizeIntModule_basic", - "AddSizeIntNegDimModule_basic", "ArangeDtypeFloatModule_basic", "ArangeDtypeIntModule_basic", "ArangeFalsePinMemoryModule_basic", @@ -467,139 +399,161 @@ STABLEHLO_PASS_SET = { "ArangeStartIntModule_basic", "ArangeStartNegativeStepFloatModule_basic", "ArangeStartNegativeStepIntModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutModule_basic", + "ArangeStartOutViewModule_basic", "ArangeStartStepFloatModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", - "BatchMlpLayerModule_basic", - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "BatchNorm1DStaticShapeModule_basic", - "ResNet18StaticModule_basic", - "BmmFloatModule_basic", - "BmmIntModule_basic", - "BroadcastToModule_basic", + "ArgmaxModule_with_dim", + "AtenComplex64Module_basic", + "AtenEyeMModuleCPUDevice_basic", + "AtenEyeMModuleDefaultDtype_basic", + "AtenEyeMModuleFalsePinMemory_basic", + "AtenEyeMModuleFloat2D_basic", + "AtenEyeMModuleInt2D_basic", + "AtenEyeModuleCPUDevice_basic", + "AtenEyeModuleDefaultDtype_basic", + "AtenEyeModuleFalsePinMemory_basic", + "AtenEyeModuleFloat2D_basic", + "AtenEyeModuleInt2D_basic", + "AtenFloatScalarModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemFpOpModule_basic", + "AtenItemIntOpModule_basic", + "AtenMmFloatTypes_basic", + "AtenMmIntTypes_basic", + "AtenRoundIntModule_basic", + "AtenSubFloatModule_basic", + "AtenToDeviceModule_basic", + "AvgPool1dStaticModule_basic", + "AvgPool2dStaticModule_basic", + "BaddbmmBroadcast1DInputModule_basic", + "BaddbmmBroadcast2DInputModule_basic", + "BaddbmmStaticModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnMixedModule_basic", + "BoolTensorReturnTrueModule_basic", + "BroadcastListConstructWithMinusOneModule_basic", "BroadcastToSameRankStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic", - "BroadcastListConstructWithMinusOneModule_basic", "BucketizeTensorStaticFloatModule_basic", "BucketizeTensorStaticModule_basic", + "CeilFloatModule_basic", + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "CloneModule_basic", + "ConstantBoolParameterModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "ContiguousModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", + "Convolution2DStaticModule_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionModule2DTransposeStridedStatic_basic", + "CosineSimilarityStaticBroadcastModule_basic", + "CosineSimilarityStaticModule_basic", + "CumsumInputDtypeInt32Module_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", - "CosineSimilarityStaticModule_basic", - "CosineSimilarityStaticBroadcastModule_basic", "DetachModule_basic", - "ElementwiseIsnanModule_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "DropoutEvalFloatModule_basic", + "DropoutEvalIntModule_basic", + "DropoutTrainStaticShapeModule_basic", + "EinsumStaticContractRhsModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticModule_basic", + "ElementwiseAbsFloatModule_basic", + "ElementwiseAbsIntModule_basic", + "ElementwiseAddScalar_NumToTensorFloat_Module_basic", + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", + "ElementwiseAtenIsinfOpModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalNotOpModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenWhereSelfModule_basic", - "ElementwiseWhereScalarOtherStaticModule_basic", - "ElementwiseWhereScalarSelfStaticModule_basic", - "ElementwiseNanToNumModule_Basic", + "ElementwiseBinaryStaticShapeModule_basic", "ElementwiseBitwiseAndStaticShapeModule_basic", - "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseNotInt32Module_basic", - "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseOrStaticShapeModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", - "ElementwiseClampModule_basic", - "ElementwiseClampMinModule_basic", - "ElementwiseClampMaxModule_basic", - "ElementwiseSignModule_basic", - "ElementwisePowModule_basic", - "ElementwisePowTensorStaticModule_basic", - "ElementwisePowTensorBroadcastStaticModule_basic", - "ElementwiseExpModule_basic", - "ElementwiseFlattenBroadcastModule_basic", - "ElementwiseLeakyReluModule_basic", - "ElementwiseEluModule_basic", - "ElementwiseEluNonDefaultModule_basic", - "ElementwiseSeluModule_basic", - "ElementwiseLogModule_basic", - "ElementwiseNegModule_basic", - "ElementwiseRsqrtModule_basic", - "ElementwiseSigmoidModule_basic", - "ElementwiseSqrtModule_basic", - "ElementwiseSinModule_basic", - "ElementwiseCosModule_basic", "ElementwiseCeilModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseClampTensorInt8Module_basic", + "ElementwiseCloneChannelsLastMemoryFormatModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseCosModule_basic", + "ElementwiseErfModule_basic", + "ElementwiseExpModule_basic", + "ElementwiseFloorIntModule_basic", "ElementwiseFloorModule_basic", - "ElementwiseUnaryModule_basic", - "ElementwiseUnsqueezeBroadcastModule_basic", - "ElementwiseUnsqueezeNegDimsModule_basic", - "ElementwiseToDtypeF32ToI64Module_basic", - "ElementwiseAddModule_basic", - "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalarInt64Module_basic", - "ElementwiseAddScalarIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", - "ElementwiseDivScalarModule_basic", - "ElementwiseAtenDivIntScalarModule_basic", - "ElementwiseEqDiffWidthScalarModule_basic", - "ElementwiseEqFloatScalarModule_basic", - "ElementwiseEqIntScalarModule_basic", - "ElementwiseNeFloatScalarModule_basic", + "ElementwiseGeluModule_basic", + "ElementwiseLeakyReluStaticModule_basic", + "ElementwiseLogModule_basic", + "ElementwiseNanToNumModule_Basic", "ElementwiseNeFloatTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic", - "ElementwiseEqBoolScalarModule_basic", - "ElementwiseErfModule_basic", - "ElementwiseGeluModule_basic", - "ElementwiseGtFloatScalarModule_basic", - "ElementwiseGtIntScalarModule_basic", - "ElementwiseGtMixed2ScalarModule_basic", - "ElementwiseGeFloatIntScalarModule_basic", - "ElementwiseGeFloatScalarModule_basic", - "ElementwiseGeIntScalarModule_basic", - "ElementwiseGeMixedIntScalarModule_basic", - "ElementwiseLeakyReluStaticModule_basic", - "ElementwiseLeFloatIntScalarModule_basic", - "ElementwiseLeFloatScalarModule_basic", - "ElementwiseLeIntScalarModule_basic", - "ElementwiseLeMixedIntScalarModule_basic", - "ElementwiseLtDiffWidthScalarModule_basic", - "ElementwiseLtFloatScalarModule_basic", - "ElementwiseLtIntScalarModule_basic", - "ElementwiseMulScalarModule_basic", - "ElementwiseMulScalarModule_float", - "ElementwiseMulScalarModule_int", - "ElementwiseNeIntScalarModule_basic", + "ElementwiseNegModule_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwisePowTensorStaticModule_basic", "ElementwiseReciprocalModule_basic", - "ElementwiseRelu6Module_basic", "ElementwiseReluModule_basic", - "ElementwiseRemainderScalarModule_Bool_basic", - "ElementwiseRemainderScalarModule_Float_basic", - "ElementwiseRemainderScalarModule_Int_Float_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "ElementwiseSubScalarFloatModule_basic", - "ElementwiseSubScalarIntModule_basic", - "ElementwiseWhereScalarModule_basic", - "ElementwiseAbsFloatModule_basic", - "ElementwiseAbsIntModule_basic", - "EmbeddingModule1DIndices_basic", - "EmbeddingModuleI32Static_basic", - "EmbeddingModuleI32_basic", - "EmbeddingModuleI64_basic", - "EmbeddingModuleF16_basic", + "ElementwiseRsqrtModule_basic", + "ElementwiseSigmoidModule_basic", + "ElementwiseSinModule_basic", + "ElementwiseSqrtModule_basic", + "ElementwiseToDtypeF32ToI64Module_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ElementwiseToDtypeIdentityModule_basic", + "ElementwiseUnaryModule_basic", + "ElementwiseWhereScalarOtherStaticModule_basic", + "ElementwiseWhereScalarSelfStaticModule_basic", "EmptyLikeMemoryFormatModule_basic", "EmptyLikeModule_defaultDtype", "EmptyLikeModule_falsePinMemory", "EmptyLikeModule_float", "EmptyLikeModule_int", + "EmptyModule_contiguous", + "EmptyModule_defaultDtype", + "EmptyModule_falsePinMemory", + "EmptyModule_float", + "EmptyModule_int", + "EmptyStridedModule_basic", + "EqIntModule_basic", "ExpandAsIntModule_basic", - "ExpandModule_basic", - "EinsumStaticModule_basic", - "EinsumStaticFourDimensionModule_basic", - "EinsumStaticContractRhsModule_basic", + "Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithFloat32_basic", "Fill_TensorFloat64WithFloat64_basic", - "Fill_TensorFloat64WithInt64_basic", - "Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithInt64Static_basic", + "Fill_TensorFloat64WithInt64_basic", + "FlattenRank0Module_basic", + "FlattenStaticModule_basic", "FlipModuleStaticShape_basic", "FlipNegativeIndexModule_basic", "FullLikeModuleDefaultDtype_basic", @@ -616,188 +570,67 @@ STABLEHLO_PASS_SET = { "FullModuleFloat3D_basic", "FullModuleInt2D_basic", "FullModuleInt3D_basic", - "NewFullModuleDefaultDtype_basic", - "NewFullModuleFalsePinMemory_basic", - "NewFullModuleFloat2D_basic", - "NewFullModuleFloat3DStatic_basic", - "NewFullModuleFloat3D_basic", - "NewFullModuleInt2DStatic_basic", - "NewFullModuleInt2D_basic", - "NewFullModuleInt3D_basic", - "GroupNormModule_basic", "GatherStaticModule_basic", - "GatherModule_basic", - "Gather2DInputModdule_basic", - "GatherRandomIndexModule_basic", - "GatherNegativeDimModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", "GeluBackwardModule_basic", - "HardswishModule_basic", - "HardswishRandomModule_basic", - "HardTanhIntModule_basic", - "HardTanhModule_basic", - "HardsigmoidModule_basic", - "HardsigmoidRandomModule_basic", - "IndexSelectDynamicIndexSizeModule_basic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", - "IndexSelectNegativeDimModule_basic", - "IndexTensorStaticModule_basic", + "GluStaticModule_basic", + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorStaticContiguousWithNoneModule_basic", + "IndexTensorStaticModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", + "IntFloatModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", "LayerNormLastDimModule_basic", "LayerNormModule_basic", "LayerNormNormalizeOverAllDimsModule_basic", "LeakyReluBackwardStaticModule_basic", - "LinalgVectorNormModule_basic", - "LinalgVectorNormKeepDimModule_basic", - "MatmulBroadcastBatchDim_basic", - "MatmulSingleDynamicBatchDim_basic", - "Matmul_3d", - "Matmul_4d", + "LenStrModule_basic", + "LiftFreshCopyModule_basic", + "MaskedFillScalarFloatValueStaticModule_basic", + "MaskedFillScalarIntValueStaticModule_basic", + "Matmul4dStatic_basic", + "Matmul_2d", + "Matmul_dot", + "Matmul_matvec", + "Matmul_vecmat", + "MaxPool2dStaticModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + "MeanDimAllReduceKeepdimModule_basic", + "MeanDimAllReduceModule_basic", "MeanDimEmptyDimModule_basic", "MeanDtypeModule_basic", "MeanDynamicSizesModule_basic", - "MeanLargeInputModule_basic", "MeanModule_basic", - "Mlp1LayerModule_basic", - "Mlp2LayerModule_basic", - "MmTanhModule_basic", - "Mv_basic", - "NativeLayerNormModule4D_basic", - "NativeLayerNormModule_basic", - "OneHotModule_basic", - "PrimsConvertElementTypeModule_basic", - "ReduceFrobeniusNormKeepDimModule_basic", - "ReduceSumDimIntListElementTypeBoolModule_basic", - "ReduceSumElementTypeBoolModule_basic", - "ReduceSumDimIntListEmptyDimModule_basic", - "ReduceSumDimIntListDtypeFloatModule_basic", - "ReduceSumDimIntListDtypeIntModule_basic", - "ReduceSumDimIntListKeepDimFloatModule_basic", - "ReduceSumDimIntListKeepDimIntModule_basic", - "ReduceSumDtypeFloatModule_basic", - "ReduceSumDtypeIntModule_basic", - "ReduceL1NormModule_basic", - "ReduceL1NormWithDTypeModule_basic", - "ReduceL2NormModule_basic", - "ReduceL3NormAllDimsModule_basic", - "ReduceL3NormKeepDimModule_basic", - "ReduceLN3NormModule_basic", - "NormScalarOptDimKeepDimModule_basic", - "NormScalarOptDimModule_basic", - "NormalizeModule_basic", - "ScalarConstantTupleModule_basic", - "SelectIntModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", - "SliceSingleIdxModule_basic", - "SqueezeDimModule_dynamic", - "SqueezeDimModule_negDim", - "ToCopyBoolDTypeStaticModule_basic", - "ToCopyModule_basic", - "ToCopyWithDTypeFalsePinMemoryModule_basic", - "ToCopyWithDTypeModule_basic", - "ReduceFrobeniusNormModule_basic", - "FlattenStaticModule_basic", - "FlattenRank0Module_basic", - "TensorsConcatNegativeDimModule_basic", - "TensorsConcatPromoteDTypeModule_basic", - "TensorsConcatStaticModule_basic", - "TensorsConcatNegativeDimStaticModule_basic", - "TensorsStackModule_basic", - "TensorsStackNegativeDimModule_basic", - "TensorsStackPromoteDTypeModule_basic", - "LiftFreshCopyModule_basic", "Mlp2LayerModuleNoBias_basic", - "NumelModule_basic", - "SiluModule_basic", - "SquareModule_basic", - "SqueezeModule_allUnitDim", - "SqueezeDimModule_unitDim", - "ViewCollapseOnesMiddleModule_basic", - "ViewDoubleMergeStaticModule_basic", - "ViewExpandDynamicDimModule_basic", - "ViewFlattenAndExpandModule_basic", - "ViewFiveTestStaticModule_basic", - "ViewOffsetTestStaticModule_basic", - "ViewTwoFiveThreeStaticModule_basic", - "ViewTwoToThreeStaticModule_basic", - "ViewExpandOnesMiddleOppModule_basic", - "ViewOffsetBackwardTestStaticModule_basic", - "NumToTensorFloatModule_basic", - "AtenToDeviceModule_basic", - "AvgPool1dStaticModule_basic", - "AvgPool2dStaticModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - "Conv2dWithPaddingDilationStrideStaticModule_grouped", - "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", - "Convolution2DStaticModule_basic", - "ConvolutionModule2DTransposeStridedStatic_basic", - "ElementwiseCloneContiguousModule_basic", - "ElementwiseCloneChannelsLastMemoryFormatModule_basic", - "ElementwiseCloneModule_basic", - "ElementwiseBinaryStaticShapeModule_basic", - "ReturnThreeTensorFloat32_basic", - "BoolTensorReturnFalseModule_basic", - "BoolTensorReturnTrueModule_basic", - "BoolTensorReturnMixedModule_basic", - "SqueezeModule_static", - "TModuleRank1_basic", - "TModuleRank0_basic", - "ElementwiseToDtypeIdentityModule_basic", - "View1DFoldModule_basic", - "UnsafeView1DFoldModule_basic", - "UnflattenStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "RsubFloatModule_basic", - "RsubFloatModule_noalpha_basic", - "RsubIntModule_basic", - "RsubIntModule_noalpha_basic", - "RsubInt0d_NumToTensor_Module_basic", - "ScalarTensorDefaultDtypeModule_basic", - "ScalarTensorFloat32Module_basic", - "ScalarTensorInt32Module_basic", - "ScalarTensorInt64Module_basic", - "SelectScattertModule_basic", - "SelectScattertStaticModule_basic", - "SliceStaticModule_basic", - "SliceModule_basic", - "SliceNegIdxModule_basic", - "SliceOutOfLowerBoundStartIndexModule_basic", - "SliceOutOfUpperBoundIndexModule_basic", - "SliceOutOfUpperBoundIndexStaticModule_basic", - "SliceStartEqEndModule_basic", - "SliceSizeTwoStepModule_basic", - "SliceWholeTensorModule_basic", - "SliceScatterModule_basic", - "SliceScatterNegativeDimModule_basic", - "SliceScatterNegativeEndModule_basic", - "SliceScatterStaticModule_basic", - "SliceScatterStepVariationModule_basic", - "SliceScatterZeroDimModule_basic", - "SqueezeDimModule_static", - "SqueezeDimModule_identity", - "SqueezeModule_broadcast", - "ReturnTwoTensorF32I64_basic", - "Matmul4dStatic_basic", - "Matmul_dot", - "Matmul_2d", - "Matmul_matvec", - "Matmul_vecmat", - "MaxPool2dWithIndicesStaticModule_basic", "MmDagModule_basic", "MmModule_basic", "MmModule_chained", - "MaxPool2dStaticModule_basic", - "EmptyModule_contiguous", - "EmptyModule_defaultDtype", - "EmptyModule_falsePinMemory", - "EmptyModule_int", - "EmptyModule_float", + "MmTanhModule_basic", + "MoveDimIntModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "MulFloatModule_basic", + "MulIntModule_basic", + "Mv_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowTensorHorizontalModule_basic", + "NarrowTensorVerticalModule_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NativeDropoutEvalFloatModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "NativeGroupNormModule_basic", + "NativeLayerNormModule4D_basic", + "NativeLayerNormModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", "NewEmptyModuleDefaultDtype_basic", "NewEmptyModuleFalsePinMemory_basic", "NewEmptyModuleFloat2D_basic", @@ -808,117 +641,169 @@ STABLEHLO_PASS_SET = { "NewEmptyModuleNonDefaultFloatDtype_basic", "NewEmptyModuleNonDefaultIntDtype_basic", "NewEmptyStridedModuleDefaultDtype_basic", - "EmptyStridedModule_basic", - "EmptyStridedSizeIntStrideModule_basic", - "PermuteModule_basic", - "PermuteNegativeIndexModule_basic", - "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", - "ZeroFloat32Module_basic", - "ZeroInt32Module_basic", - "ZeroInt64Module_basic", - "ZerosLikeModule_defaultDtype", - "ZerosLikeModule_falsePinMemory", - "ZerosLikeModule_float", - "ZerosLikeModule_int", - "ZerosModuleDefaultDtype_basic", - "ZerosModuleInt2D_basic", - "ZerosModuleInt3D_basic", - "ZerosModuleFloat2D_basic", - "ZerosModuleFloat3D_basic", - "ZerosModuleFalsePinMemory_basic", - "OnesModuleDefaultDtype_basic", - "OnesModuleInt_basic", - "OnesModuleFloat_basic", - "OnesModuleFalsePinMemory_basic", + "NewFullModuleDefaultDtype_basic", + "NewFullModuleFalsePinMemory_basic", + "NewFullModuleFloat3DStatic_basic", + "NewFullModuleFloat3D_basic", + "NewFullModuleInt2D_basic", + "NewFullModuleInt3D_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleFalsePinMemory_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosStaticModuleLayoutStrided_basic", + "NormalizeModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "NumpyTRank0Module_basic", + "NumpyTRank1Module_basic", + "NumpyTRank2Module_basic", + "NumpyTRankNDynamicModule_basic", + "NumpyTRankNStaticModule_basic", "OnesLikeModule_defaultDtype", "OnesLikeModule_falsePinMemory", "OnesLikeModule_float", "OnesLikeModule_int", - "NewZerosModuleDefaultDtype_basic", - "NewZerosModuleInt2D_basic", - "NewZerosModuleInt3D_basic", - "NewZerosModuleFloat2D_basic", - "NewZerosModuleFloat3D_basic", - "NewZerosModuleFalsePinMemory_basic", - "NewOnesModuleDefaultDtype_basic", - "NewOnesModuleInt2D_basic", - "NewOnesModuleInt3D_basic", - "NewOnesModuleFloat2D_basic", - "NewOnesModuleFloat3D_basic", - "NewOnesModuleFalsePinMemory_basic", - "NewZerosStaticModuleLayoutStrided_basic", - "DropoutEvalIntModule_basic", - "DropoutEvalFloatModule_basic", - "DropoutTrainStaticShapeModule_basic", - "NativeDropoutEvalFloatModule_basic", - "NativeDropoutTrainStaticShapeModule_basic", - "ContiguousModule_basic", - "DropoutModule_basic", - "ViewCollapseModule_basic", - "ViewCollapseInferredDimModule_basic", - "ViewDynamicExpandCollapseModule_basic", - "ViewDynamicExpandModule_basic", - "ViewExpandModule_basic", - "ViewExpandOnesModule_basic", - "ViewExpandOnesBeforeAndAfterModule_basic", - "ViewExpandOnesMiddleModule_basic", - "ViewExpandCollapseModule_basic", - "ViewExpandCollapseWithOnesModule_basic", - "ViewExpandInferredDimModule_basic", - "ViewNegativeStaticModule_basic", - "ViewNoChangeStaticModule_basic", - "ViewNoChange1dModule_basic", - "ViewNoChange2dModule_basic", - "ViewNoChange3dModule_basic", - "UnsafeViewExpandModule_basic", + "OnesModuleCPUDevice_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleFloat_basic", + "OnesModuleInt_basic", + "Permute0RankModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "PowIntFloatModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsConvertElementTypeModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "RandIntDtypeModule_basic", + "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "RandModule_basic", + "ReduceAmaxMultiDim_basic", + "ReduceAmaxOutOfOrderDim_basic", + "ReduceAmaxSingleDim_basic", + "ReduceFrobeniusNormModule_basic", "ReduceMaxAllDims_basic", + "ReduceMaxAlongDimNegative_basic", + "ReduceMaxAlongDimSignedInt_basic", + "ReduceMaxAlongDim_basic", "ReduceMaxFloatModule_basic", "ReduceMaxSignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic", - "ReduceMinAllDims_basic", "ReduceMinFloatModule_basic", "ReduceMinSignedIntModule_basic", "ReduceMinUnsignedIntModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", "ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListIntModule_basic", + "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", "RepeatModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", - "ReshapeExpandModule_basic", "ReshapeAsModule_basic", - "TestMultipleTensorReturn_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", - "BaddbmmStaticModule_basic", - "BaddbmmBroadcast1DInputModule_basic", - "BaddbmmBroadcast2DInputModule_basic", - "NarrowHorizontalTest2_basic", - "NarrowHorizontalTest_basic", - "NarrowVerticalTest2_basic", - "NarrowVerticalTest_basic", - "NarrowTensorHorizontalModule_basic", - "NarrowTensorVerticalModule_basic", - "NumToTensorIntModule_basic", - "NumpyTRank0Module_basic", - "NumpyTRank1Module_basic", - "NumpyTRank2Module_basic", - "NumpyTRankNStaticModule_basic", - "NumpyTRankNDynamicModule_basic", + "ReshapeExpandModule_basic", + "ReturnThreeTensorFloat32_basic", + "ReturnTwoTensorF32I64_basic", + "RollModule_basic", + "RsubInt0d_NumToTensor_Module_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "ScalarTensorDefaultDtypeModule_basic", + "ScalarTensorFloat32Module_basic", + "ScalarTensorInt32Module_basic", + "ScalarTensorInt64Module_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SelectScattertStaticModule_basic", + "SliceModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceStartEqEndModule_basic", + "SliceStaticModule_basic", + "SliceWholeTensorModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "SqueezeDimModule_identity", + "SqueezeDimModule_static", + "SqueezeDimModule_unitDim", + "SqueezeModule_allUnitDim", + "SqueezeModule_static", + "SubFloatModule_basic", + "SubIntModule_basic", + "TModuleRank0_basic", + "TModuleRank1_basic", "TModuleRank2_basic", + "TensorIntModule_basic", "TensorLiteralModule_basic", - "TensorsConcatModule_basic", "TensorOpaqueLiteralModule_basic", - "TransposeIntModule_basic", - "TransposeIntNegDimsModule_basic", - "ToDtypeBoolLayoutNoneModule_basic", + "TensorToBoolZeroRank_basic", + "TensorToFloatZeroRank_basic", + "TensorToIntZeroRank_basic", + "TensorsConcatModule_basic", + "TensorsConcatNegativeDimModule_basic", + "TensorsConcatNegativeDimStaticModule_basic", + "TensorsConcatPromoteDTypeModule_basic", + "TensorsConcatStaticModule_basic", + "TestF16Return_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "TestMultipleTensorReturn_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "ToCopyBoolDTypeStaticModule_basic", "ToDtypeBoolLayoutNoneStaticModule_basic", + "ToDtypeLayoutCPUModule_basic", "ToDtypeLayoutNoneModule_basic", "ToDtypeLayoutStridedModule_basic", - "TypeAsSameModule_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", + "TriuBroadcastModule_basic", + "TriuModule_basic", + "TupleModule_basic", "TypeAsDifferentModule_basic", + "TypeAsSameModule_basic", "TypeConversionF32ToF64Module_basic", "TypeConversionF64ToF32Module_basic", "TypeConversionI1ToF32Module_basic", @@ -927,57 +812,58 @@ STABLEHLO_PASS_SET = { "TypeConversionI1ToI64Module_basic", "TypeConversionI32ToI64Module_basic", "TypeConversionI64ToI32Module_basic", - "TypePromotionAlphaWiderModule_basic", - "TypePromotionSameCategoryZeroRankWider_basic", - "TypePromotionZeroRankHigherCategoryModule_basic", - "OnesModuleCPUDevice_basic", - "Permute0RankModule_basic", + "UnbindIntGetItem_Module_basic", + "UnbindIntListUnpack_Module_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", + "UniformNoCorrelationModule_basic", + "UniformStaticShapeModule_basic", + "UnsafeView1DFoldModule_basic", "UnsafeViewCollapseModule_basic", "UnsafeViewDynamicExpandModule_basic", - "AtenRoundIntModule_basic", - "TestF16Return_basic", - "_LogSoftmaxModuleStable_basic", - "PrimsSqueezeModule_basic", - "PrimsSqueezeEmptyDimensionsModule_basic", - "MoveDimIntModule_basic", - "MoveDimIntNegativeIndexModule_basic", - "ConvolutionBackwardModule2DStatic_basic", - "ConvolutionBackwardModule2DStrided_basic", - "PrimsViewOfModule_basic", - "PrimsViewOfZeroRankModule_basic", - "AtenComplex64Module_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitTensorLastSmallerModule_basic", - "SplitWithSizesListUnpackModule_basic", - "UnbindIntListUnpack_Module_basic", - "UnbindIntGetItem_Module_basic", - "ChunkListUnpack_Module_basic", - "ChunkListUnpackUneven_Module_basic", - "RandIntDtypeModule_basic", - "RandIntLowDtypeModule_basic", - "RandIntLowModule_basic", - "RandIntModule_basic", - "RandIntPinMemoryModule_basic", - "RandModule_basic", - "UniformStaticShapeModule_basic", - "UniformNoCorrelationModule_basic", - "TupleModule_basic", - "AtenEmbeddingBagStaticModule_basic", + "UnsafeViewExpandModule_basic", + "View1DFoldModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewCollapseModule_basic", + "ViewCollapseOnesMiddleModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewExpandDynamicDimModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandOnesModule_basic", + "ViewNegativeStaticModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "ViewNoChangeStaticModule_basic", + "ViewOffsetBackwardTestStaticModule_basic", + "ViewOffsetTestStaticModule_basic", + "ViewTwoFiveThreeStaticModule_basic", + "ViewTwoToThreeStaticModule_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", + "ZerosLikeModule_defaultDtype", + "ZerosLikeModule_falsePinMemory", + "ZerosLikeModule_float", + "ZerosLikeModule_int", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleFalsePinMemory_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", } -STABLEHLO_CRASHING_SET = { - # These e2e tests crash because currently mlir-hlo's shape-component-analysis - # only support exact one index in tensor::ExtractOp when it's related with - # some tensors' shape. REF: - # https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/analysis/shape_component_analysis.cc#L586 - # FIXME if upstream mlir-hlo fix this. - "ViewCollapseDynamicWithAtenSizeIntModule_basic", - "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", - - "Aten_EmbeddingBagExample_basic", - "AtenEmbeddingBagSumExample_basic" +STABLEHLO_CRASHING_SET = { + "AtenEmbeddingBagSumExample_basic", } # Write the TOSA set as a "passing" set as it is very early in development diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py new file mode 100644 index 000000000..9143ae5ea --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -0,0 +1,57 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from torch_mlir.ir import * +from torch_mlir.passmanager import * +from torch_mlir.compiler_utils import run_pipeline_with_repro_report + +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend + +from .abc import StablehloBackend + +__all__ = [ + "LinalgOnTensorsStablehloBackend", +] + +# The pipeline of func.func passes that lower the STABLEHLO backend contract to the +# Linalg-on-Tensors backend contract accepted by RefBackend. +STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([ + "func.func(chlo-legalize-to-stablehlo)", + "canonicalize", + "stablehlo-legalize-to-linalg" +]) + + +class LinalgOnTensorsStablehloBackend(StablehloBackend): + """Main entry-point for the linalg-on-tensors based TOSA backend. + + This currently uses the linalg-on-tensors RefBackend for actual execution. + """ + + def __init__(self): + super().__init__() + self.refbackend = RefBackendLinalgOnTensorsBackend() + + def compile(self, imported_module: Module): + """Compiles an imported module that satisfied the TOSA backend contract. + + Args: + imported_module: The MLIR module consisting of funcs in the TOSA + dialect. + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + + run_pipeline_with_repro_report( + imported_module, + f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})", + "Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract") + + return self.refbackend.compile(imported_module) + + def load(self, module): + """Loads a compiled artifact into the runtime.""" + return self.refbackend.load(module)