[Stablehlo] add refbackend (#2712)

pull/2911/head
Yuanqiang Liu 2024-02-16 01:08:48 +08:00 committed by GitHub
parent 8e2e5eeae9
commit f3e8199a6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 485 additions and 478 deletions

View File

@ -20,6 +20,10 @@ echo "::group::Run TOSA e2e integration tests"
python -m e2e_testing.main --config=tosa -v
echo "::endgroup::"
echo "::group::Run Stablehlo e2e integration tests"
python -m e2e_testing.main --config=stablehlo -v
echo "::endgroup::"
case $torch_version in
nightly)
# Failing with: NotImplementedError:

2
externals/stablehlo vendored

@ -1 +1 @@
Subproject commit e191eb4c3c3f3144503a8a117d760de5ddcc7e89
Subproject commit 4ac26f8786d491c5d8376e6e563d1b72af09de75

View File

@ -31,6 +31,10 @@ set(LinkedLibs
TorchMLIRTorchOnnxToTorch
)
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs StablehloPasses StablehloLinalgTransforms)
endif()
if(TORCH_MLIR_ENABLE_REFBACKEND)
add_subdirectory(RefBackend)
list(APPEND LinkedLibs TorchMLIRRefBackend)

View File

@ -24,6 +24,9 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include <unordered_set>
#include <vector>
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
@ -116,6 +119,12 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
initIndex = hlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
}
std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
outputShape.erase(outputShape.begin() + dim);
auto outputTy = RankedTensorType::get(outputShape, inputElemTy);
auto outputIndexTy =
RankedTensorType::get(outputShape, rewriter.getIntegerType(64));
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
auto indexTensor = rewriter.create<stablehlo::DynamicIotaOp>(
@ -125,7 +134,8 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
inputShapeTensor, static_cast<uint64_t>(dim));
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), ValueRange{input, indexTensor},
op->getLoc(), TypeRange{outputTy, outputIndexTy},
ValueRange{input, indexTensor},
ValueRange{
initValue,
initIndex,
@ -412,7 +422,8 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input,
initValue, rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@ -473,7 +484,8 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@ -535,7 +547,8 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
return failure();
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
rewriter.getDenseI64ArrayAttr(dims));
Block &block = stablehloReduceOp.getBody().emplaceBlock();
auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
@ -614,6 +627,14 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
}
}
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputTy.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputTy.getDimSize(i));
}
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
@ -625,7 +646,9 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
llvm::sort(dims.begin(), dims.end());
auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims));
op.getLoc(),
RankedTensorType::get(reduceResultShape, outTy.getElementType()), input,
initValue, rewriter.getDenseI64ArrayAttr(dims));
Region &region = stablehloReduceOp.getBody();
Block &block = region.emplaceBlock();
@ -714,6 +737,14 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
// stable with unordered dims.
std::sort(dims.begin(), dims.end());
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputRank; i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputType.getDimSize(i));
}
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(
@ -728,8 +759,8 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
}
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), squareOp.getResult(), initValue,
rewriter.getDenseI64ArrayAttr(dims));
op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType),
squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims));
Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();
@ -832,6 +863,14 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
std::sort(dims.begin(), dims.end());
}
std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape;
for (int64_t i = 0; i < inputType.getRank(); i++) {
if (dimsSet.find(i) == dimsSet.end()) {
reduceResultShape.push_back(inputType.getDimSize(i));
}
}
bool keepDim = false;
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
return rewriter.notifyMatchFailure(
@ -848,7 +887,8 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
ord, nullptr);
auto reduceOp = rewriter.create<stablehlo::ReduceOp>(
op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));
op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType),
powValue, initValue, rewriter.getDenseI64ArrayAttr(dims));
Region &region = reduceOp.getBody();
Block &block = region.emplaceBlock();

View File

@ -29,6 +29,11 @@
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/RefBackend/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "stablehlo/conversions/linalg/transforms/Passes.h"
#include "stablehlo/transforms/Passes.h"
#endif
void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::torch::Torch::TorchDialect>();
@ -52,6 +57,11 @@ void mlir::torch::registerAllPasses() {
mlir::torch::onnx_c::registerTorchOnnxToTorchPasses();
mlir::torch::TMTensor::registerPasses();
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::stablehlo::registerChloLegalizeToStablehloPass();
mlir::stablehlo::registerStablehloLegalizeToLinalgPass();
#endif
#ifdef TORCH_MLIR_ENABLE_REFBACKEND
mlir::torch::RefBackend::registerRefBackendPasses();
#endif

View File

@ -25,6 +25,7 @@ from torch_mlir_e2e_test.configs import (
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
from .xfail_sets import (
LINALG_XFAIL_SET,
@ -43,7 +44,7 @@ from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()
def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config",
choices=config_choices,
@ -52,6 +53,7 @@ def _get_argparse():
Meaning of options:
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
"tosa": run through torch-mlir"s default TOSA backend.
"stablehlo": run through torch-mlir"s default Stablehlo backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
@ -90,6 +92,10 @@ def main():
config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = LINALG_XFAIL_SET
crashing_set = set()
elif args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
crashing_set = STABLEHLO_CRASHING_SET
elif args.config == "tosa":
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET

View File

@ -378,84 +378,16 @@ TORCHDYNAMO_CRASHING_SET = {
}
STABLEHLO_PASS_SET = {
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"AddIntModule_basic",
"AtenIntBoolOpModule_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"BoolFloatFalseModule_basic",
"BoolFloatTrueModule_basic",
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
"CeilFloatModule_basic",
"DivFloatModule_basic",
"DivIntModule_basic",
"EqIntModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"MulIntModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"SqrtIntModule_basic",
"SubFloatModule_basic",
"SubIntModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToIntZeroRank_basic",
"TensorToFloatZeroRank_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"AliasModule_basic",
"TensorIntModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
"AnyBoolFalseModule_basic",
"AnyBoolTrueModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
"AtenFloatScalarModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"AtenSubFloatModule_basic",
"BoolFloatConstantModule_basic",
"BoolIntConstantModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
"IntFloatModule_basic",
"IsFloatingPointFloat_True",
"IsFloatingPointInt_False",
"LenStrModule_basic",
"MeanDimAllReduceKeepdimModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimDtypeModule_basic",
"MeanDimKeepdimModule_basic",
"MeanDimModule_basic",
"MeanDimNegativeModule_basic",
"NumelZeroRankModule_basic",
"PowIntFloatModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntModule_basic",
"PrimMinIntDynamicModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
"SqrtIntConstantModule_basic",
"StdBiasedModule_basic",
"StdDimBiasedModule_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"VarBiasedModule_basic",
"VarDimBiasedModule_basic",
"VarMeanBiasedModule_basic",
"VarMeanDimBiasedModule_basic",
"ConstantBoolParameterModule_basic",
"MaskedFillScalarIntValueStaticModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddSizeIntModule_basic",
"AddSizeIntNegDimModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
@ -467,139 +399,161 @@ STABLEHLO_PASS_SET = {
"ArangeStartIntModule_basic",
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartNegativeStepIntModule_basic",
"ArangeStartOutDtypeModule_basic",
"ArangeStartOutModule_basic",
"ArangeStartOutViewModule_basic",
"ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"BatchMlpLayerModule_basic",
"BatchNorm1DModule_basic",
"BatchNorm1DWith2DInputModule_basic",
"BatchNorm2DModule_basic",
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"ResNet18StaticModule_basic",
"BmmFloatModule_basic",
"BmmIntModule_basic",
"BroadcastToModule_basic",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic",
"AtenEyeMModuleCPUDevice_basic",
"AtenEyeMModuleDefaultDtype_basic",
"AtenEyeMModuleFalsePinMemory_basic",
"AtenEyeMModuleFloat2D_basic",
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleCPUDevice_basic",
"AtenEyeModuleDefaultDtype_basic",
"AtenEyeModuleFalsePinMemory_basic",
"AtenEyeModuleFloat2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenFloatScalarModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
"AtenIntBoolOpModule_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"AtenItemFpOpModule_basic",
"AtenItemIntOpModule_basic",
"AtenMmFloatTypes_basic",
"AtenMmIntTypes_basic",
"AtenRoundIntModule_basic",
"AtenSubFloatModule_basic",
"AtenToDeviceModule_basic",
"AvgPool1dStaticModule_basic",
"AvgPool2dStaticModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
"BaddbmmStaticModule_basic",
"BoolFloatConstantModule_basic",
"BoolFloatFalseModule_basic",
"BoolFloatTrueModule_basic",
"BoolIntConstantModule_basic",
"BoolIntFalseModule_basic",
"BoolIntTrueModule_basic",
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnMixedModule_basic",
"BoolTensorReturnTrueModule_basic",
"BroadcastListConstructWithMinusOneModule_basic",
"BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"BroadcastListConstructWithMinusOneModule_basic",
"BucketizeTensorStaticFloatModule_basic",
"BucketizeTensorStaticModule_basic",
"CeilFloatModule_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"CloneModule_basic",
"ConstantBoolParameterModule_basic",
"ContainsIntList_False",
"ContainsIntList_True",
"ContiguousModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Convolution2DStaticModule_basic",
"ConvolutionBackwardModule2DStatic_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"CosineSimilarityStaticModule_basic",
"CumsumInputDtypeInt32Module_basic",
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"CosineSimilarityStaticModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
"DetachModule_basic",
"ElementwiseIsnanModule_basic",
"DivFloatModule_basic",
"DivIntModule_basic",
"DropoutEvalFloatModule_basic",
"DropoutEvalIntModule_basic",
"DropoutTrainStaticShapeModule_basic",
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseAtenIsinfOpModule_basic",
"ElementwiseAtenIsneginfOpModule_basic",
"ElementwiseAtenIsposinfOpModule_basic",
"ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalNotOpModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic",
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseBinaryStaticShapeModule_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseBitwiseNotInt32Module_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseBitwiseOrStaticShapeModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseSignModule_basic",
"ElementwisePowModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic",
"ElementwiseSeluModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseNegModule_basic",
"ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseCeilModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseClampTensorInt8Module_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseCosModule_basic",
"ElementwiseErfModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseFloorIntModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseAddModule_basic",
"ElementwiseAddScalarFloatModule_basic",
"ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseDivScalarModule_basic",
"ElementwiseAtenDivIntScalarModule_basic",
"ElementwiseEqDiffWidthScalarModule_basic",
"ElementwiseEqFloatScalarModule_basic",
"ElementwiseEqIntScalarModule_basic",
"ElementwiseNeFloatScalarModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseLeakyReluStaticModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseNeFloatTensorStaticModule_basic",
"ElementwiseNeIntTensorStaticModule_basic",
"ElementwiseEqBoolScalarModule_basic",
"ElementwiseErfModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseGtFloatScalarModule_basic",
"ElementwiseGtIntScalarModule_basic",
"ElementwiseGtMixed2ScalarModule_basic",
"ElementwiseGeFloatIntScalarModule_basic",
"ElementwiseGeFloatScalarModule_basic",
"ElementwiseGeIntScalarModule_basic",
"ElementwiseGeMixedIntScalarModule_basic",
"ElementwiseLeakyReluStaticModule_basic",
"ElementwiseLeFloatIntScalarModule_basic",
"ElementwiseLeFloatScalarModule_basic",
"ElementwiseLeIntScalarModule_basic",
"ElementwiseLeMixedIntScalarModule_basic",
"ElementwiseLtDiffWidthScalarModule_basic",
"ElementwiseLtFloatScalarModule_basic",
"ElementwiseLtIntScalarModule_basic",
"ElementwiseMulScalarModule_basic",
"ElementwiseMulScalarModule_float",
"ElementwiseMulScalarModule_int",
"ElementwiseNeIntScalarModule_basic",
"ElementwiseNegModule_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwiseReciprocalModule_basic",
"ElementwiseRelu6Module_basic",
"ElementwiseReluModule_basic",
"ElementwiseRemainderScalarModule_Bool_basic",
"ElementwiseRemainderScalarModule_Float_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"ElementwiseSubScalarFloatModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleI32Static_basic",
"EmbeddingModuleI32_basic",
"EmbeddingModuleI64_basic",
"EmbeddingModuleF16_basic",
"ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSinModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToI8Module_basic",
"ElementwiseToDtypeIdentityModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory",
"EmptyLikeModule_float",
"EmptyLikeModule_int",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
"EmptyModule_float",
"EmptyModule_int",
"EmptyStridedModule_basic",
"EqIntModule_basic",
"ExpandAsIntModule_basic",
"ExpandModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticContractRhsModule_basic",
"Fill_TensorFloat64WithFloat32Static_basic",
"Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic",
"Fill_TensorFloat64WithInt64_basic",
"Fill_TensorFloat64WithFloat32Static_basic",
"Fill_TensorFloat64WithInt64Static_basic",
"Fill_TensorFloat64WithInt64_basic",
"FlattenRank0Module_basic",
"FlattenStaticModule_basic",
"FlipModuleStaticShape_basic",
"FlipNegativeIndexModule_basic",
"FullLikeModuleDefaultDtype_basic",
@ -616,188 +570,67 @@ STABLEHLO_PASS_SET = {
"FullModuleFloat3D_basic",
"FullModuleInt2D_basic",
"FullModuleInt3D_basic",
"NewFullModuleDefaultDtype_basic",
"NewFullModuleFalsePinMemory_basic",
"NewFullModuleFloat2D_basic",
"NewFullModuleFloat3DStatic_basic",
"NewFullModuleFloat3D_basic",
"NewFullModuleInt2DStatic_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"GroupNormModule_basic",
"GatherStaticModule_basic",
"GatherModule_basic",
"Gather2DInputModdule_basic",
"GatherRandomIndexModule_basic",
"GatherNegativeDimModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
"GeluBackwardModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"HardsigmoidModule_basic",
"HardsigmoidRandomModule_basic",
"IndexSelectDynamicIndexSizeModule_basic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexSelectNegativeDimModule_basic",
"IndexTensorStaticModule_basic",
"GluStaticModule_basic",
"GroupNormModule_basic",
"GroupNormNoWeightAndBiasModule_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"IntFloatModule_basic",
"IsFloatingPointFloat_True",
"IsFloatingPointInt_False",
"LayerNormLastDimModule_basic",
"LayerNormModule_basic",
"LayerNormNormalizeOverAllDimsModule_basic",
"LeakyReluBackwardStaticModule_basic",
"LinalgVectorNormModule_basic",
"LinalgVectorNormKeepDimModule_basic",
"MatmulBroadcastBatchDim_basic",
"MatmulSingleDynamicBatchDim_basic",
"Matmul_3d",
"Matmul_4d",
"LenStrModule_basic",
"LiftFreshCopyModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"MaskedFillScalarIntValueStaticModule_basic",
"Matmul4dStatic_basic",
"Matmul_2d",
"Matmul_dot",
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool2dStaticModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
"MeanDimAllReduceKeepdimModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic",
"MeanDtypeModule_basic",
"MeanDynamicSizesModule_basic",
"MeanLargeInputModule_basic",
"MeanModule_basic",
"Mlp1LayerModule_basic",
"Mlp2LayerModule_basic",
"MmTanhModule_basic",
"Mv_basic",
"NativeLayerNormModule4D_basic",
"NativeLayerNormModule_basic",
"OneHotModule_basic",
"PrimsConvertElementTypeModule_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListKeepDimFloatModule_basic",
"ReduceSumDimIntListKeepDimIntModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceL1NormModule_basic",
"ReduceL1NormWithDTypeModule_basic",
"ReduceL2NormModule_basic",
"ReduceL3NormAllDimsModule_basic",
"ReduceL3NormKeepDimModule_basic",
"ReduceLN3NormModule_basic",
"NormScalarOptDimKeepDimModule_basic",
"NormScalarOptDimModule_basic",
"NormalizeModule_basic",
"ScalarConstantTupleModule_basic",
"SelectIntModule_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SliceSingleIdxModule_basic",
"SqueezeDimModule_dynamic",
"SqueezeDimModule_negDim",
"ToCopyBoolDTypeStaticModule_basic",
"ToCopyModule_basic",
"ToCopyWithDTypeFalsePinMemoryModule_basic",
"ToCopyWithDTypeModule_basic",
"ReduceFrobeniusNormModule_basic",
"FlattenStaticModule_basic",
"FlattenRank0Module_basic",
"TensorsConcatNegativeDimModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsConcatStaticModule_basic",
"TensorsConcatNegativeDimStaticModule_basic",
"TensorsStackModule_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
"LiftFreshCopyModule_basic",
"Mlp2LayerModuleNoBias_basic",
"NumelModule_basic",
"SiluModule_basic",
"SquareModule_basic",
"SqueezeModule_allUnitDim",
"SqueezeDimModule_unitDim",
"ViewCollapseOnesMiddleModule_basic",
"ViewDoubleMergeStaticModule_basic",
"ViewExpandDynamicDimModule_basic",
"ViewFlattenAndExpandModule_basic",
"ViewFiveTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic",
"ViewTwoFiveThreeStaticModule_basic",
"ViewTwoToThreeStaticModule_basic",
"ViewExpandOnesMiddleOppModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"NumToTensorFloatModule_basic",
"AtenToDeviceModule_basic",
"AvgPool1dStaticModule_basic",
"AvgPool2dStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"Conv2dWithPaddingDilationStrideStaticModule_grouped",
"Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier",
"Convolution2DStaticModule_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneChannelsLastMemoryFormatModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
"ReturnThreeTensorFloat32_basic",
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnTrueModule_basic",
"BoolTensorReturnMixedModule_basic",
"SqueezeModule_static",
"TModuleRank1_basic",
"TModuleRank0_basic",
"ElementwiseToDtypeIdentityModule_basic",
"View1DFoldModule_basic",
"UnsafeView1DFoldModule_basic",
"UnflattenStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubIntModule_basic",
"RsubIntModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ScalarTensorDefaultDtypeModule_basic",
"ScalarTensorFloat32Module_basic",
"ScalarTensorInt32Module_basic",
"ScalarTensorInt64Module_basic",
"SelectScattertModule_basic",
"SelectScattertStaticModule_basic",
"SliceStaticModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceStartEqEndModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceWholeTensorModule_basic",
"SliceScatterModule_basic",
"SliceScatterNegativeDimModule_basic",
"SliceScatterNegativeEndModule_basic",
"SliceScatterStaticModule_basic",
"SliceScatterStepVariationModule_basic",
"SliceScatterZeroDimModule_basic",
"SqueezeDimModule_static",
"SqueezeDimModule_identity",
"SqueezeModule_broadcast",
"ReturnTwoTensorF32I64_basic",
"Matmul4dStatic_basic",
"Matmul_dot",
"Matmul_2d",
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool2dWithIndicesStaticModule_basic",
"MmDagModule_basic",
"MmModule_basic",
"MmModule_chained",
"MaxPool2dStaticModule_basic",
"EmptyModule_contiguous",
"EmptyModule_defaultDtype",
"EmptyModule_falsePinMemory",
"EmptyModule_int",
"EmptyModule_float",
"MmTanhModule_basic",
"MoveDimIntModule_basic",
"MoveDimIntNegativeIndexModule_basic",
"MulFloatModule_basic",
"MulIntModule_basic",
"Mv_basic",
"NarrowHorizontalTest2_basic",
"NarrowHorizontalTest_basic",
"NarrowTensorHorizontalModule_basic",
"NarrowTensorVerticalModule_basic",
"NarrowVerticalTest2_basic",
"NarrowVerticalTest_basic",
"NativeDropoutEvalFloatModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"NativeGroupNormModule_basic",
"NativeLayerNormModule4D_basic",
"NativeLayerNormModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewEmptyModuleDefaultDtype_basic",
"NewEmptyModuleFalsePinMemory_basic",
"NewEmptyModuleFloat2D_basic",
@ -808,117 +641,169 @@ STABLEHLO_PASS_SET = {
"NewEmptyModuleNonDefaultFloatDtype_basic",
"NewEmptyModuleNonDefaultIntDtype_basic",
"NewEmptyStridedModuleDefaultDtype_basic",
"EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ZeroFloat32Module_basic",
"ZeroInt32Module_basic",
"ZeroInt64Module_basic",
"ZerosLikeModule_defaultDtype",
"ZerosLikeModule_falsePinMemory",
"ZerosLikeModule_float",
"ZerosLikeModule_int",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleFalsePinMemory_basic",
"OnesModuleDefaultDtype_basic",
"OnesModuleInt_basic",
"OnesModuleFloat_basic",
"OnesModuleFalsePinMemory_basic",
"NewFullModuleDefaultDtype_basic",
"NewFullModuleFalsePinMemory_basic",
"NewFullModuleFloat3DStatic_basic",
"NewFullModuleFloat3D_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"NewOnesModuleDefaultDtype_basic",
"NewOnesModuleFalsePinMemory_basic",
"NewOnesModuleFloat2D_basic",
"NewOnesModuleFloat3D_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NewZerosStaticModuleLayoutStrided_basic",
"NormalizeModule_basic",
"NumToTensorFloatModule_basic",
"NumToTensorIntModule_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"NumpyTRank0Module_basic",
"NumpyTRank1Module_basic",
"NumpyTRank2Module_basic",
"NumpyTRankNDynamicModule_basic",
"NumpyTRankNStaticModule_basic",
"OnesLikeModule_defaultDtype",
"OnesLikeModule_falsePinMemory",
"OnesLikeModule_float",
"OnesLikeModule_int",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewOnesModuleDefaultDtype_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewOnesModuleFloat2D_basic",
"NewOnesModuleFloat3D_basic",
"NewOnesModuleFalsePinMemory_basic",
"NewZerosStaticModuleLayoutStrided_basic",
"DropoutEvalIntModule_basic",
"DropoutEvalFloatModule_basic",
"DropoutTrainStaticShapeModule_basic",
"NativeDropoutEvalFloatModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ContiguousModule_basic",
"DropoutModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
"ViewExpandOnesMiddleModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"UnsafeViewExpandModule_basic",
"OnesModuleCPUDevice_basic",
"OnesModuleDefaultDtype_basic",
"OnesModuleFalsePinMemory_basic",
"OnesModuleFloat_basic",
"OnesModuleInt_basic",
"Permute0RankModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"PowIntFloatModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
"PrimMinIntModule_basic",
"PrimsConvertElementTypeModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsSqueezeModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RandModule_basic",
"ReduceAmaxMultiDim_basic",
"ReduceAmaxOutOfOrderDim_basic",
"ReduceAmaxSingleDim_basic",
"ReduceFrobeniusNormModule_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxAlongDimNegative_basic",
"ReduceMaxAlongDimSignedInt_basic",
"ReduceMaxAlongDim_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceMinAllDims_basic",
"ReduceMinFloatModule_basic",
"ReduceMinSignedIntModule_basic",
"ReduceMinUnsignedIntModule_basic",
"ReduceSumDimIntListDtypeFloatModule_basic",
"ReduceSumDimIntListDtypeIntModule_basic",
"ReduceSumDimIntListElementTypeBoolModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic",
"ReduceSumDtypeFloatModule_basic",
"ReduceSumDtypeIntModule_basic",
"ReduceSumElementTypeBoolModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
"ReshapeAsModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"BaddbmmStaticModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
"NarrowHorizontalTest2_basic",
"NarrowHorizontalTest_basic",
"NarrowVerticalTest2_basic",
"NarrowVerticalTest_basic",
"NarrowTensorHorizontalModule_basic",
"NarrowTensorVerticalModule_basic",
"NumToTensorIntModule_basic",
"NumpyTRank0Module_basic",
"NumpyTRank1Module_basic",
"NumpyTRank2Module_basic",
"NumpyTRankNStaticModule_basic",
"NumpyTRankNDynamicModule_basic",
"ReshapeExpandModule_basic",
"ReturnThreeTensorFloat32_basic",
"ReturnTwoTensorF32I64_basic",
"RollModule_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ScalarConstantTupleModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
"ScalarTensorDefaultDtypeModule_basic",
"ScalarTensorFloat32Module_basic",
"ScalarTensorInt32Module_basic",
"ScalarTensorInt64Module_basic",
"SelectIntNegativeDimAndIndexStaticModule_basic",
"SelectScattertStaticModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceOutOfUpperBoundIndexStaticModule_basic",
"SliceScatterModule_basic",
"SliceScatterNegativeDimModule_basic",
"SliceScatterNegativeEndModule_basic",
"SliceScatterStaticModule_basic",
"SliceScatterStepVariationModule_basic",
"SliceScatterZeroDimModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceStartEqEndModule_basic",
"SliceStaticModule_basic",
"SliceWholeTensorModule_basic",
"SortIntListReverse_basic",
"SortIntList_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SqrtIntConstantModule_basic",
"SqrtIntModule_basic",
"SqueezeDimModule_identity",
"SqueezeDimModule_static",
"SqueezeDimModule_unitDim",
"SqueezeModule_allUnitDim",
"SqueezeModule_static",
"SubFloatModule_basic",
"SubIntModule_basic",
"TModuleRank0_basic",
"TModuleRank1_basic",
"TModuleRank2_basic",
"TensorIntModule_basic",
"TensorLiteralModule_basic",
"TensorsConcatModule_basic",
"TensorOpaqueLiteralModule_basic",
"TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic",
"ToDtypeBoolLayoutNoneModule_basic",
"TensorToBoolZeroRank_basic",
"TensorToFloatZeroRank_basic",
"TensorToIntZeroRank_basic",
"TensorsConcatModule_basic",
"TensorsConcatNegativeDimModule_basic",
"TensorsConcatNegativeDimStaticModule_basic",
"TensorsConcatPromoteDTypeModule_basic",
"TensorsConcatStaticModule_basic",
"TestF16Return_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"TestMultipleTensorReturn_basic",
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"ToCopyBoolDTypeStaticModule_basic",
"ToDtypeBoolLayoutNoneStaticModule_basic",
"ToDtypeLayoutCPUModule_basic",
"ToDtypeLayoutNoneModule_basic",
"ToDtypeLayoutStridedModule_basic",
"TypeAsSameModule_basic",
"TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic",
"TriuBroadcastModule_basic",
"TriuModule_basic",
"TupleModule_basic",
"TypeAsDifferentModule_basic",
"TypeAsSameModule_basic",
"TypeConversionF32ToF64Module_basic",
"TypeConversionF64ToF32Module_basic",
"TypeConversionI1ToF32Module_basic",
@ -927,57 +812,58 @@ STABLEHLO_PASS_SET = {
"TypeConversionI1ToI64Module_basic",
"TypeConversionI32ToI64Module_basic",
"TypeConversionI64ToI32Module_basic",
"TypePromotionAlphaWiderModule_basic",
"TypePromotionSameCategoryZeroRankWider_basic",
"TypePromotionZeroRankHigherCategoryModule_basic",
"OnesModuleCPUDevice_basic",
"Permute0RankModule_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic",
"UniformNoCorrelationModule_basic",
"UniformStaticShapeModule_basic",
"UnsafeView1DFoldModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
"AtenRoundIntModule_basic",
"TestF16Return_basic",
"_LogSoftmaxModuleStable_basic",
"PrimsSqueezeModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"MoveDimIntModule_basic",
"MoveDimIntNegativeIndexModule_basic",
"ConvolutionBackwardModule2DStatic_basic",
"ConvolutionBackwardModule2DStrided_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"AtenComplex64Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitTensorLastSmallerModule_basic",
"SplitWithSizesListUnpackModule_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
"ChunkListUnpack_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandIntLowModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RandModule_basic",
"UniformStaticShapeModule_basic",
"UniformNoCorrelationModule_basic",
"TupleModule_basic",
"AtenEmbeddingBagStaticModule_basic",
"UnsafeViewExpandModule_basic",
"View1DFoldModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseOnesMiddleModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandDynamicDimModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
"ViewExpandOnesMiddleModule_basic",
"ViewExpandOnesModule_basic",
"ViewNegativeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewOffsetBackwardTestStaticModule_basic",
"ViewOffsetTestStaticModule_basic",
"ViewTwoFiveThreeStaticModule_basic",
"ViewTwoToThreeStaticModule_basic",
"ZeroFloat32Module_basic",
"ZeroInt32Module_basic",
"ZeroInt64Module_basic",
"ZerosLikeModule_defaultDtype",
"ZerosLikeModule_falsePinMemory",
"ZerosLikeModule_float",
"ZerosLikeModule_int",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleFalsePinMemory_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
}
STABLEHLO_CRASHING_SET = {
# These e2e tests crash because currently mlir-hlo's shape-component-analysis
# only support exact one index in tensor::ExtractOp when it's related with
# some tensors' shape. REF:
# https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/analysis/shape_component_analysis.cc#L586
# FIXME if upstream mlir-hlo fix this.
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"Aten_EmbeddingBagExample_basic",
"AtenEmbeddingBagSumExample_basic"
STABLEHLO_CRASHING_SET = {
"AtenEmbeddingBagSumExample_basic",
}
# Write the TOSA set as a "passing" set as it is very early in development

View File

@ -0,0 +1,57 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from torch_mlir.ir import *
from torch_mlir.passmanager import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from .abc import StablehloBackend
__all__ = [
"LinalgOnTensorsStablehloBackend",
]
# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
# Linalg-on-Tensors backend contract accepted by RefBackend.
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
"func.func(chlo-legalize-to-stablehlo)",
"canonicalize",
"stablehlo-legalize-to-linalg"
])
class LinalgOnTensorsStablehloBackend(StablehloBackend):
"""Main entry-point for the linalg-on-tensors based TOSA backend.
This currently uses the linalg-on-tensors RefBackend for actual execution.
"""
def __init__(self):
super().__init__()
self.refbackend = RefBackendLinalgOnTensorsBackend()
def compile(self, imported_module: Module):
"""Compiles an imported module that satisfied the TOSA backend contract.
Args:
imported_module: The MLIR module consisting of funcs in the TOSA
dialect.
Returns:
An opaque, backend specific compiled artifact object that can be
passed to `load`.
"""
run_pipeline_with_repro_report(
imported_module,
f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})",
"Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract")
return self.refbackend.compile(imported_module)
def load(self, module):
"""Loads a compiled artifact into the runtime."""
return self.refbackend.load(module)