mirror of https://github.com/llvm/torch-mlir
[onnx] Add testing using the `onnx` compilation using torch tests (#2795)
We can route the torch tests via `onnx` using the `torch.onnx.export` tooling. We can then reimport, lower to torch, and compile to linalg to validate the onnx path is working correctly. The current implementation exposes some failures in the `onnx` path so we cannot enable the onnx test suite yet due to segmentation faults.pull/2914/head
parent
49f63df068
commit
074f112d6a
|
@ -24,6 +24,10 @@ echo "::group::Run Stablehlo e2e integration tests"
|
||||||
python -m e2e_testing.main --config=stablehlo -v
|
python -m e2e_testing.main --config=stablehlo -v
|
||||||
echo "::endgroup::"
|
echo "::endgroup::"
|
||||||
|
|
||||||
|
echo "::group::Run ONNX e2e integration tests"
|
||||||
|
python -m e2e_testing.main --config=onnx -v
|
||||||
|
echo "::endgroup::"
|
||||||
|
|
||||||
case $torch_version in
|
case $torch_version in
|
||||||
nightly)
|
nightly)
|
||||||
# Failing with: NotImplementedError:
|
# Failing with: NotImplementedError:
|
||||||
|
|
|
@ -305,6 +305,9 @@ function test_in_tree() {
|
||||||
echo ":::: Run Linalg e2e integration tests"
|
echo ":::: Run Linalg e2e integration tests"
|
||||||
python -m e2e_testing.main --config=linalg -v
|
python -m e2e_testing.main --config=linalg -v
|
||||||
|
|
||||||
|
echo ":::: Run Onnx e2e integration tests"
|
||||||
|
python -m e2e_testing.main --config=onnx -v
|
||||||
|
|
||||||
# Dynamo is changing a lot in nightly versions, and thus the implementation
|
# Dynamo is changing a lot in nightly versions, and thus the implementation
|
||||||
# tends to become incompatible to the stable version.
|
# tends to become incompatible to the stable version.
|
||||||
echo ":::: Run TorchDynamo e2e integration tests"
|
echo ":::: Run TorchDynamo e2e integration tests"
|
||||||
|
|
|
@ -524,6 +524,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
loc, rewriter.getI64IntegerAttr(i))));
|
loc, rewriter.getI64IntegerAttr(i))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Correct for negative axis:
|
||||||
|
if (axis < 0)
|
||||||
|
axis += dataRank;
|
||||||
|
|
||||||
// 4. We can not directly perform torch.gather as the onnx.gather op
|
// 4. We can not directly perform torch.gather as the onnx.gather op
|
||||||
// collects the input data at different location of output compared to
|
// collects the input data at different location of output compared to
|
||||||
// torch.gather op. The output of torch.gather and onnx.gather ops are
|
// torch.gather op. The output of torch.gather and onnx.gather ops are
|
||||||
|
|
|
@ -586,14 +586,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
||||||
AtenAddTensorOp::Adaptor adaptor(operands);
|
AtenAddTensorOp::Adaptor adaptor(operands);
|
||||||
|
Type resultElementType = add.getType().cast<BaseTensorType>().getDtype();
|
||||||
Type dtype = converter->convertType(add.getType())
|
Type dtype = converter->convertType(add.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
|
/*dstOriginalDtype=*/resultElementType);
|
||||||
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype,
|
||||||
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
|
/*dstOriginalDtype=*/resultElementType);
|
||||||
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
|
Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype,
|
||||||
/*srcOriginalDtype=*/std::nullopt,
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
/*dstOriginalDtype=*/dtype);
|
/*dstOriginalDtype=*/resultElementType);
|
||||||
if (dtype.isa<mlir::FloatType>()) {
|
if (dtype.isa<mlir::FloatType>()) {
|
||||||
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
Value scaled = b.create<arith::MulFOp>(loc, rhs, alpha);
|
||||||
return b.create<arith::AddFOp>(loc, lhs, scaled);
|
return b.create<arith::AddFOp>(loc, lhs, scaled);
|
||||||
|
|
|
@ -18,12 +18,14 @@ from torch_mlir_e2e_test.configs import (
|
||||||
LinalgOnTensorsBackendTestConfig,
|
LinalgOnTensorsBackendTestConfig,
|
||||||
StablehloBackendTestConfig,
|
StablehloBackendTestConfig,
|
||||||
NativeTorchTestConfig,
|
NativeTorchTestConfig,
|
||||||
|
OnnxBackendTestConfig,
|
||||||
TorchScriptTestConfig,
|
TorchScriptTestConfig,
|
||||||
TosaBackendTestConfig,
|
TosaBackendTestConfig,
|
||||||
TorchDynamoTestConfig,
|
TorchDynamoTestConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||||
|
from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import LinalgOnTensorsOnnxBackend
|
||||||
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend
|
||||||
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
|
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
|
||||||
|
|
||||||
|
@ -36,7 +38,9 @@ from .xfail_sets import (
|
||||||
LTC_XFAIL_SET,
|
LTC_XFAIL_SET,
|
||||||
LTC_CRASHING_SET,
|
LTC_CRASHING_SET,
|
||||||
TORCHDYNAMO_XFAIL_SET,
|
TORCHDYNAMO_XFAIL_SET,
|
||||||
TORCHDYNAMO_CRASHING_SET
|
TORCHDYNAMO_CRASHING_SET,
|
||||||
|
ONNX_CRASHING_SET,
|
||||||
|
ONNX_XFAIL_SET,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Import tests to register them in the global registry.
|
# Import tests to register them in the global registry.
|
||||||
|
@ -44,7 +48,7 @@ from torch_mlir_e2e_test.test_suite import register_all_tests
|
||||||
register_all_tests()
|
register_all_tests()
|
||||||
|
|
||||||
def _get_argparse():
|
def _get_argparse():
|
||||||
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"]
|
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo", "onnx"]
|
||||||
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
|
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
|
||||||
parser.add_argument("-c", "--config",
|
parser.add_argument("-c", "--config",
|
||||||
choices=config_choices,
|
choices=config_choices,
|
||||||
|
@ -58,6 +62,7 @@ Meaning of options:
|
||||||
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
|
||||||
"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
|
"lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph.
|
||||||
"torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors.
|
"torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors.
|
||||||
|
"onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path.
|
||||||
""")
|
""")
|
||||||
parser.add_argument("-f", "--filter", default=".*", help="""
|
parser.add_argument("-f", "--filter", default=".*", help="""
|
||||||
Regular expression specifying which tests to include in this run.
|
Regular expression specifying which tests to include in this run.
|
||||||
|
@ -120,6 +125,10 @@ def main():
|
||||||
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
|
config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend())
|
||||||
xfail_set = TORCHDYNAMO_XFAIL_SET
|
xfail_set = TORCHDYNAMO_XFAIL_SET
|
||||||
crashing_set = TORCHDYNAMO_CRASHING_SET
|
crashing_set = TORCHDYNAMO_CRASHING_SET
|
||||||
|
elif args.config == "onnx":
|
||||||
|
config = OnnxBackendTestConfig(LinalgOnTensorsOnnxBackend())
|
||||||
|
xfail_set = ONNX_XFAIL_SET
|
||||||
|
crashing_set = ONNX_CRASHING_SET
|
||||||
|
|
||||||
do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set)
|
do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set)
|
||||||
available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt]
|
available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt]
|
||||||
|
|
|
@ -1451,3 +1451,759 @@ LTC_XFAIL_SET = {
|
||||||
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||||
"Conv2dQInt8Module_basic",
|
"Conv2dQInt8Module_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ONNX_XFAIL_SET = {
|
||||||
|
# Failure - onnx_export
|
||||||
|
"AdaptiveAvgPool1dGeneralDynamic_basic",
|
||||||
|
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
|
||||||
|
"AdaptiveAvgPool1dStaticLargerOutput_basic",
|
||||||
|
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
|
||||||
|
"AdaptiveMaxPool2dDynamicWithIndices_basic",
|
||||||
|
"AdaptiveMaxPool2dDynamic_basic",
|
||||||
|
"AdaptiveMaxPool2dStaticWithIndices_basic",
|
||||||
|
"AdaptiveMaxPool2dStatic_basic",
|
||||||
|
"AddCDivModule_basic",
|
||||||
|
"AddIntModule_basic",
|
||||||
|
"Add_Module_basic",
|
||||||
|
"AllBoolFalseModule_basic",
|
||||||
|
"AllBoolTrueModule_basic",
|
||||||
|
"AnyBoolFalseModule_basic",
|
||||||
|
"AnyBoolTrueModule_basic",
|
||||||
|
"AtenComplex64Module_basic",
|
||||||
|
"AtenComplexImagModule_basic",
|
||||||
|
"AtenComplexRealModule_basic",
|
||||||
|
"AtenComplexViewModule_basic",
|
||||||
|
"AtenEmbeddingBagStaticModule_basic",
|
||||||
|
"AtenEmbeddingBagSumExample_basic",
|
||||||
|
"AtenFloatScalarModule_basic",
|
||||||
|
"AtenIntBoolOpConstFalseModule_basic",
|
||||||
|
"AtenIntBoolOpConstTrueModule_basic",
|
||||||
|
"AtenIntBoolOpModule_basic",
|
||||||
|
"AtenIntTensorByteDtypeModule_basic",
|
||||||
|
"AtenIntTensorCharDtypeModule_basic",
|
||||||
|
"AtenItemFpOpModule_basic",
|
||||||
|
"AtenItemIntOpModule_basic",
|
||||||
|
"AtenMmQuint8_basic",
|
||||||
|
"AtenRealView128Module_basic",
|
||||||
|
"AtenRealView64Module_basic",
|
||||||
|
"AtenSubFloatModule_basic",
|
||||||
|
"AtenTopKModule_basic",
|
||||||
|
"AtenTopKSmallestModule_basic",
|
||||||
|
"Aten_EmbeddingBagExample_basic",
|
||||||
|
"AvgPool2dWithoutPadModule_basic",
|
||||||
|
"BatchMlpLayerModule_basic",
|
||||||
|
"BincountMinlengthModule_basic",
|
||||||
|
"BincountModule_basic",
|
||||||
|
"BincountStaticSizeModule_basic",
|
||||||
|
"BoolFloatConstantModule_basic",
|
||||||
|
"BoolFloatFalseModule_basic",
|
||||||
|
"BoolFloatTrueModule_basic",
|
||||||
|
"BoolIntConstantModule_basic",
|
||||||
|
"BoolIntFalseModule_basic",
|
||||||
|
"BoolIntTrueModule_basic",
|
||||||
|
"CeilFloatModule_basic",
|
||||||
|
"ChunkListUnpackDynamic_Module_basic",
|
||||||
|
"ChunkListUnpackUnevenDynamic_Module_basic",
|
||||||
|
"CollapseAllDimensionsModule_basic",
|
||||||
|
"CollapseFullDynamicModule_basic",
|
||||||
|
"CollapsePartialDynamicModule_basic",
|
||||||
|
"CollapseRank1DynamicModule_basic",
|
||||||
|
"CollapseStaticModule_basic",
|
||||||
|
"ConstantBoolParameterModule_basic",
|
||||||
|
"ContainsIntList_False",
|
||||||
|
"ContainsIntList_True",
|
||||||
|
"Conv1dModule_basic",
|
||||||
|
"Conv2dBiasNoPaddingModule_basic",
|
||||||
|
"Conv2dModule_basic",
|
||||||
|
"Conv2dNoPaddingModule_basic",
|
||||||
|
"Conv2dQInt8Module_basic",
|
||||||
|
"Conv2dWithPaddingDilationStrideModule_basic",
|
||||||
|
"Conv2dWithPaddingModule_basic",
|
||||||
|
"Conv3dModule_basic",
|
||||||
|
"ConvTbcModule_basic",
|
||||||
|
"Conv_Transpose2dModule_basic",
|
||||||
|
"Convolution2DModule_basic",
|
||||||
|
"Convolution2DStridedModule_basic",
|
||||||
|
"ConvolutionBackwardModule2DPadded_basic",
|
||||||
|
"ConvolutionBackwardModule2DStatic_basic",
|
||||||
|
"ConvolutionBackwardModule2DStrided_basic",
|
||||||
|
"ConvolutionBackwardModule2D_basic",
|
||||||
|
"ConvolutionModule2DGroups_basic",
|
||||||
|
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
|
||||||
|
"ConvolutionModule2DTransposeStrided_basic",
|
||||||
|
"ConvolutionModule2DTranspose_basic",
|
||||||
|
"DivFloatModule_basic",
|
||||||
|
"DivIntModule_basic",
|
||||||
|
"ElementwiseAcoshIntModule_basic",
|
||||||
|
"ElementwiseAcoshModule_basic",
|
||||||
|
"ElementwiseAsinhIntModule_basic",
|
||||||
|
"ElementwiseAsinhModule_basic",
|
||||||
|
"ElementwiseAtanhIntModule_basic",
|
||||||
|
"ElementwiseAtanhModule_basic",
|
||||||
|
"ElementwiseAtenIsneginfOpModule_basic",
|
||||||
|
"ElementwiseAtenIsposinfOpModule_basic",
|
||||||
|
"ElementwiseBitwiseAndModule_basic",
|
||||||
|
"ElementwiseBitwiseAndScalarInt32Module_basic",
|
||||||
|
"ElementwiseBitwiseAndScalarInt64Module_basic",
|
||||||
|
"ElementwiseBitwiseAndScalarInt8Module_basic",
|
||||||
|
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
||||||
|
"ElementwiseBitwiseLeftShiftInt32Module_basic",
|
||||||
|
"ElementwiseBitwiseLeftShiftInt64Module_basic",
|
||||||
|
"ElementwiseBitwiseLeftShiftInt8Module_basic",
|
||||||
|
"ElementwiseBitwiseNotInt32Module_basic",
|
||||||
|
"ElementwiseBitwiseNotInt64Module_basic",
|
||||||
|
"ElementwiseBitwiseOrModule_basic",
|
||||||
|
"ElementwiseBitwiseOrStaticShapeModule_basic",
|
||||||
|
"ElementwiseBitwiseRightShiftInt32Module_basic",
|
||||||
|
"ElementwiseBitwiseRightShiftInt64Module_basic",
|
||||||
|
"ElementwiseBitwiseRightShiftInt8Module_basic",
|
||||||
|
"ElementwiseBitwiseXorModule_basic",
|
||||||
|
"ElementwiseBitwiseXorStaticShapeModule_basic",
|
||||||
|
"ElementwiseCoshIntModule_basic",
|
||||||
|
"ElementwiseCoshModule_basic",
|
||||||
|
"ElementwiseDequantizePerChannelModule_basic",
|
||||||
|
"ElementwiseDequantizePerTensorModule_basic",
|
||||||
|
"ElementwiseEluNonDefaultModule_basic",
|
||||||
|
"ElementwiseExpm1IntModule_basic",
|
||||||
|
"ElementwiseExpm1Module_basic",
|
||||||
|
"ElementwiseMulTensorComplexModule_basic",
|
||||||
|
"ElementwiseOrTensorModule_basic",
|
||||||
|
"ElementwiseOrTensorStaticShapeModule_basic",
|
||||||
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
|
"ElementwiseRemainderTensorModule_Int_basic",
|
||||||
|
"EmptyStridedModule_basic",
|
||||||
|
"EmptyStridedSizeIntStrideModule_basic",
|
||||||
|
"EqIntModule_basic",
|
||||||
|
"ExponentialModule_basic",
|
||||||
|
"GeFloatIntModule_basic",
|
||||||
|
"GeFloatModule_basic",
|
||||||
|
"GeIntModule_basic",
|
||||||
|
"GeluBackwardModule_basic",
|
||||||
|
"GtFloatIntModule_basic",
|
||||||
|
"GtIntModule_basic",
|
||||||
|
"HardtanhBackward_basic",
|
||||||
|
"IndexPutImpl1DFloatAccumulateModule_basic",
|
||||||
|
"IndexPutImpl1DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPutImpl1DIntAccumulateModule_basic",
|
||||||
|
"IndexPutImpl1DIntNonAccumulateModule_basic",
|
||||||
|
"IndexPutImpl2DFloatAccumulateModule_basic",
|
||||||
|
"IndexPutImpl2DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPutImpl2DIndexModule_basic",
|
||||||
|
"IndexPutImpl2DNoneIndexStaticModule_basic",
|
||||||
|
"IndexPutImpl3DFloatAccumulateModule_basic",
|
||||||
|
"IndexPutImpl3DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPutImplIndexWithNoneModule_basic",
|
||||||
|
"IntFloatModule_basic",
|
||||||
|
"IouOfModule_basic",
|
||||||
|
"IsFloatingPointFloat_True",
|
||||||
|
"IsFloatingPointInt_False",
|
||||||
|
"IscloseStaticModuleTrue_basic",
|
||||||
|
"IscloseStaticModule_basic",
|
||||||
|
"LeakyReluBackwardModule_basic",
|
||||||
|
"LeakyReluBackwardStaticModule_basic",
|
||||||
|
"LenStrModule_basic",
|
||||||
|
"LiftFreshCopyModule_basic",
|
||||||
|
"LogSoftmaxBackwardModule_basic",
|
||||||
|
"MaxPool2dCeilModeTrueModule_basic",
|
||||||
|
"MaxPool2dModule_basic",
|
||||||
|
"MaxPool2dWithIndicesAllOnesModule_basic",
|
||||||
|
"MaxPool2dWithIndicesBackwardDynamic3DModule_basic",
|
||||||
|
"MaxPool2dWithIndicesBackwardDynamic4DModule_basic",
|
||||||
|
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
|
||||||
|
"MaxPool2dWithIndicesBackwardStatic4DModule_basic",
|
||||||
|
"MaxPool2dWithIndicesCeilModeTrueModule_basic",
|
||||||
|
"MaxPool2dWithIndicesFullSizeKernelModule_basic",
|
||||||
|
"MaxPool2dWithIndicesModule_basic",
|
||||||
|
"MaxPool2dWithIndicesNonDefaultDilationModule_basic",
|
||||||
|
"MaxPool2dWithIndicesNonDefaultParamsModule_basic",
|
||||||
|
"MaxPool2dWithIndicesNonDefaultStrideModule_basic",
|
||||||
|
"MaxPool3dCeilModeTrueModule_basic",
|
||||||
|
"MaxPool3dLargeDatadModule_basic",
|
||||||
|
"MaxPool3dModuleRandomSimple_basic",
|
||||||
|
"MaxPool3dModule_basic",
|
||||||
|
"MeanDimEmptyDimModule_basic",
|
||||||
|
"Mlp1LayerModule_basic",
|
||||||
|
"Mlp2LayerModuleNoBias_basic",
|
||||||
|
"Mlp2LayerModule_basic",
|
||||||
|
"MulFloatModule_basic",
|
||||||
|
"MulIntModule_basic",
|
||||||
|
"NarrowHorizontalTest2_basic",
|
||||||
|
"NarrowHorizontalTest_basic",
|
||||||
|
"NarrowTensorHorizontalModule_basic",
|
||||||
|
"NarrowTensorVerticalModule_basic",
|
||||||
|
"NarrowVerticalTest2_basic",
|
||||||
|
"NarrowVerticalTest_basic",
|
||||||
|
"NativeBatchNorm1DModule_basic",
|
||||||
|
"NativeBatchNorm2DModule_basic",
|
||||||
|
"NativeBatchNorm3DModule_basic",
|
||||||
|
"NativeBatchNormNoneWeightModule_basic",
|
||||||
|
"NativeDropoutEvalFloatModule_basic",
|
||||||
|
"NativeGroupNormBackwardModule_basic",
|
||||||
|
"NativeGroupNormModule_basic",
|
||||||
|
"NativeLayerNormDynamicModule_basic",
|
||||||
|
"NeFloatIntModule_basic",
|
||||||
|
"NeIntModule_basic",
|
||||||
|
"NewEmptyStridedModuleDefaultDtype_basic",
|
||||||
|
"NllLossModuleBackward1DMeanWeight_basic",
|
||||||
|
"NllLossModuleBackward1DMean_basic",
|
||||||
|
"NllLossModuleBackward1DSumWeight_basic",
|
||||||
|
"NllLossModuleBackward1DSum_basic",
|
||||||
|
"NllLossModuleBackward1DWeight_basic",
|
||||||
|
"NllLossModuleBackward1D_basic",
|
||||||
|
"NllLossModuleBackwardMeanWeight_basic",
|
||||||
|
"NllLossModuleBackwardMean_basic",
|
||||||
|
"NllLossModuleBackwardSumWeight_basic",
|
||||||
|
"NllLossModuleBackwardSum_basic",
|
||||||
|
"NllLossModuleBackwardWeight_basic",
|
||||||
|
"NllLossModuleBackward_basic",
|
||||||
|
"NllLossModuleBackward_ignore_index",
|
||||||
|
"NllLossModule_1D_basic",
|
||||||
|
"NllLossModule_basic",
|
||||||
|
"NllLossModule_ignore_index_out_of_bounds_basic",
|
||||||
|
"NllLossModule_mean_basic",
|
||||||
|
"NllLossModule_sum_basic",
|
||||||
|
"NormScalarOptDimKeepDimModule_basic",
|
||||||
|
"NormScalarOptDimModule_basic",
|
||||||
|
"NormalFunctionalModule_basic",
|
||||||
|
"NumToTensorFloatModule_basic",
|
||||||
|
"NumToTensorIntModule_basic",
|
||||||
|
"NumelModule_basic",
|
||||||
|
"NumelZeroRankModule_basic",
|
||||||
|
"PixelShuffleModuleFullDynamic_basic",
|
||||||
|
"PixelShuffleModuleSpatiallyDynamic_basic",
|
||||||
|
"PixelShuffleModuleSpatiallyStatic_basic",
|
||||||
|
"PixelShuffleModuleStaticRank3Int64_basic",
|
||||||
|
"PowIntFloatModule_basic",
|
||||||
|
"PrimMaxIntModule_basic",
|
||||||
|
"PrimMinIntDynamicModule_basic",
|
||||||
|
"PrimMinIntModule_basic",
|
||||||
|
"PrimsConvertElementTypeModule_basic",
|
||||||
|
"PrimsSqueezeEmptyDimensionsModule_basic",
|
||||||
|
"PrimsSqueezeModule_basic",
|
||||||
|
"PrimsViewOfModule_basic",
|
||||||
|
"PrimsViewOfZeroRankModule_basic",
|
||||||
|
"RandIntDtypeModule_basic",
|
||||||
|
"RandIntModule_basic",
|
||||||
|
"RandIntPinMemoryModule_basic",
|
||||||
|
"ReshapeAliasCollapseModule_basic",
|
||||||
|
"ReshapeAliasExpandModule_basic",
|
||||||
|
"ReshapeExpandModule_basic",
|
||||||
|
"ScalarConstantTupleModule_basic",
|
||||||
|
"ScalarImplicitFloatModule_basic",
|
||||||
|
"ScalarImplicitIntModule_basic",
|
||||||
|
"ScatterReduceFloatMaxModule",
|
||||||
|
"ScatterReduceFloatMeanModule",
|
||||||
|
"ScatterReduceFloatMeanModuleIncludeSelf",
|
||||||
|
"ScatterReduceFloatMinModule",
|
||||||
|
"ScatterReduceFloatProdModule",
|
||||||
|
"ScatterReduceFloatSumModule",
|
||||||
|
"ScatterReduceIntMaxModule",
|
||||||
|
"ScatterReduceIntMeanModule",
|
||||||
|
"ScatterReduceIntMeanModuleIncludeSelf",
|
||||||
|
"ScatterReduceIntMinModule",
|
||||||
|
"ScatterReduceIntProdModule",
|
||||||
|
"ScatterReduceIntSumModule",
|
||||||
|
"SelectScattertModule_basic",
|
||||||
|
"SelectScattertStaticModule_basic",
|
||||||
|
"SliceEndSleStartModule_basic",
|
||||||
|
"SliceOutOfUpperBoundIndexModule_basic",
|
||||||
|
"SliceScatterModule_basic",
|
||||||
|
"SliceScatterNegativeDimModule_basic",
|
||||||
|
"SliceScatterNegativeEndModule_basic",
|
||||||
|
"SliceScatterStaticModule_basic",
|
||||||
|
"SliceScatterStepVariationModule_basic",
|
||||||
|
"SliceScatterZeroDimModule_basic",
|
||||||
|
"SliceStartEqEndModule_basic",
|
||||||
|
"SoftmaxBackwardModule_basic",
|
||||||
|
"SortIntListReverse_basic",
|
||||||
|
"SortIntList_basic",
|
||||||
|
"SplitDimDynamicModule_basic",
|
||||||
|
"SplitDimStaticModule_basic",
|
||||||
|
"SqrtIntConstantModule_basic",
|
||||||
|
"SqrtIntModule_basic",
|
||||||
|
"StdCorrectionEmptyDimModule_basic",
|
||||||
|
"StdDimEmptyDimModule_basic",
|
||||||
|
"SubFloatModule_basic",
|
||||||
|
"SubIntModule_basic",
|
||||||
|
"TanhBackward_basic",
|
||||||
|
"TensorToBoolZeroRank_basic",
|
||||||
|
"TensorToBool_basic",
|
||||||
|
"TensorToFloatZeroRank_basic",
|
||||||
|
"TensorToFloat_basic",
|
||||||
|
"TensorToIntZeroRank_basic",
|
||||||
|
"TensorToInt_basic",
|
||||||
|
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
||||||
|
"Threshold1dFloatModule_basic",
|
||||||
|
"Threshold1dIntI32Module_basic",
|
||||||
|
"Threshold1dIntModule_basic",
|
||||||
|
"Threshold2dFloatModule_basic",
|
||||||
|
"Threshold2dIntModule_basic",
|
||||||
|
"Threshold3dFloatModule_basic",
|
||||||
|
"Threshold3dIntModule_basic",
|
||||||
|
"ThresholdBackward1dFloatModule_basic",
|
||||||
|
"ThresholdBackward1dIntModule_basic",
|
||||||
|
"ThresholdBackward1dMixedModule_basic",
|
||||||
|
"ThresholdBackward2dFloatModule_basic",
|
||||||
|
"ThresholdBackward2dIntModule_basic",
|
||||||
|
"ThresholdBackward2dMixedModule_basic",
|
||||||
|
"ThresholdBackward3dFloatModule_basic",
|
||||||
|
"ThresholdBackward3dIntModule_basic",
|
||||||
|
"ThresholdBackward3dMixedModule_basic",
|
||||||
|
"ToCopyBoolDTypeStaticModule_basic",
|
||||||
|
"ToCopyModule_basic",
|
||||||
|
"ToCopyWithDTypeFalsePinMemoryModule_basic",
|
||||||
|
"ToCopyWithDTypeModule_basic",
|
||||||
|
"TorchPrimLoopForLikeModule_basic",
|
||||||
|
"TorchPrimLoopWhileLikeModule_basic",
|
||||||
|
"TraceModule_basic",
|
||||||
|
"TraceModule_empty",
|
||||||
|
"TraceModule_nonsquare",
|
||||||
|
"TraceSignedIntModule_basic",
|
||||||
|
"TraceUnsignedIntModule_basic",
|
||||||
|
"TraceUnsignedIntModule_empty",
|
||||||
|
"UniformModule_basic",
|
||||||
|
"UniformNoCorrelationModule_basic",
|
||||||
|
"UniformStaticShapeModule_basic",
|
||||||
|
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||||
|
"UnsafeView1DFoldModule_basic",
|
||||||
|
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
|
"UnsafeViewCollapseModule_basic",
|
||||||
|
"UnsafeViewDynamicExpandModule_basic",
|
||||||
|
"UnsafeViewDynamicExpandWithAtenSizeIntModule_basic",
|
||||||
|
"UnsafeViewExpandModule_basic",
|
||||||
|
"UpSampleNearest2dBackwardScalesNone_basic",
|
||||||
|
"UpSampleNearest2dBackward_basic",
|
||||||
|
"UpSampleNearest2dDynamicFactor_basic",
|
||||||
|
"UpSampleNearest2dStaticFactor_basic",
|
||||||
|
"UpSampleNearest2d_basic",
|
||||||
|
"VarCorrectionEmptyDimModule_basic",
|
||||||
|
"VarDimEmptyDimModule_basic",
|
||||||
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
|
"ViewCollapseModule_basic",
|
||||||
|
"ViewDynamicExpandCollapseModule_basic",
|
||||||
|
"ViewDynamicExpandCollapseWithAtenIntModule_basic",
|
||||||
|
"ViewDynamicExpandModule_basic",
|
||||||
|
"ViewDynamicExpandWithAtenSizeIntModule_basic",
|
||||||
|
"ViewExpandDynamicDimModule_basic",
|
||||||
|
"ViewNoChange1dModule_basic",
|
||||||
|
"ViewNoChange2dModule_basic",
|
||||||
|
"ViewNoChange3dModule_basic",
|
||||||
|
"_Convolution2DAllFalseModule_basic",
|
||||||
|
"_Convolution2DBenchmarkModule_basic",
|
||||||
|
"_Convolution2DCudnnModule_basic",
|
||||||
|
"_Convolution2DDeterministicModule_basic",
|
||||||
|
"_Convolution2DTF32Module_basic",
|
||||||
|
"_ConvolutionDeprecated2DAllFalseModule_basic",
|
||||||
|
"_ConvolutionDeprecated2DBenchmarkModule_basic",
|
||||||
|
"_ConvolutionDeprecated2DCudnnModule_basic",
|
||||||
|
"_ConvolutionDeprecated2DDeterministicModule_basic",
|
||||||
|
"_SoftmaxModule_basic",
|
||||||
|
|
||||||
|
# Failure - onnx_import
|
||||||
|
"BucketizeTensorFloatModule_basic",
|
||||||
|
"BucketizeTensorModule_basic",
|
||||||
|
"BucketizeTensorOutInt32RightModule_basic",
|
||||||
|
"BucketizeTensorStaticFloatModule_basic",
|
||||||
|
"BucketizeTensorStaticModule_basic",
|
||||||
|
"DiagonalModule_basic",
|
||||||
|
"DiagonalModule_nonsquare",
|
||||||
|
"DiagonalModule_transposed",
|
||||||
|
"DiagonalModule_with_dims",
|
||||||
|
"DiagonalModule_with_dims_and_offset",
|
||||||
|
"DiagonalModule_with_negative_dims",
|
||||||
|
"DiagonalModule_with_offset",
|
||||||
|
"ElementwiseClampMaxModule_basic",
|
||||||
|
"ElementwiseClampMinModule_basic",
|
||||||
|
"ElementwiseClampMinTensorFloatModule_basic",
|
||||||
|
"ElementwiseClampMinTensorIntModule_basic",
|
||||||
|
"ElementwiseClampModule_basic",
|
||||||
|
"ElementwiseClampTensorFloatModule_basic",
|
||||||
|
"ElementwiseClampTensorInt8Module_basic",
|
||||||
|
"ElementwiseClampTensorIntModule_basic",
|
||||||
|
"EmptyLikeMemoryFormatModule_basic",
|
||||||
|
"EmptyLikeModule_defaultDtype",
|
||||||
|
"EmptyLikeModule_falsePinMemory",
|
||||||
|
"EmptyLikeModule_float",
|
||||||
|
"EmptyLikeModule_int",
|
||||||
|
"Fill_TensorFloat32WithFloat32_basic",
|
||||||
|
"Fill_TensorFloat32WithFloat64_basic",
|
||||||
|
"Fill_TensorFloat32WithInt64_basic",
|
||||||
|
"Fill_TensorFloat64WithFloat32_basic",
|
||||||
|
"Fill_TensorFloat64WithFloat64_basic",
|
||||||
|
"Fill_TensorFloat64WithInt64_basic",
|
||||||
|
"FullLikeModuleDefaultDtype_basic",
|
||||||
|
"FullLikeModuleFalsePinMemory_basic",
|
||||||
|
"FullLikeModuleFloat2D_basic",
|
||||||
|
"FullLikeModuleFloat3D_basic",
|
||||||
|
"FullLikeModuleInt2D_basic",
|
||||||
|
"FullLikeModuleInt3D_basic",
|
||||||
|
"HBC_basic",
|
||||||
|
"IndexPut1DFloatAccumulateModule_basic",
|
||||||
|
"IndexPut1DIntAccumulateModule_basic",
|
||||||
|
"IndexPut2DFloatAccumulateModule_basic",
|
||||||
|
"IndexPut2DIntAccumulateModule_basic",
|
||||||
|
"IndexPut3DFloatAccumulateModule_basic",
|
||||||
|
"IndexPut3DIntAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin1DFloatAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin1DIntAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin2DFloatAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin2DIntAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin3DFloatAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin3DIntAccumulateModule_basic",
|
||||||
|
"NormalizeModule_basic",
|
||||||
|
"OnesLikeModule_defaultDtype",
|
||||||
|
"OnesLikeModule_falsePinMemory",
|
||||||
|
"OnesLikeModule_float",
|
||||||
|
"OnesLikeModule_int",
|
||||||
|
"PadWithNoneValModule_basic",
|
||||||
|
"QuantizedMLP_basic",
|
||||||
|
"RandModule_basic",
|
||||||
|
"ScatterReduceFloatMaxModuleIncludeSelf",
|
||||||
|
"ScatterReduceFloatMinModuleIncludeSelf",
|
||||||
|
"ScatterReduceFloatProdModuleIncludeSelf",
|
||||||
|
"ScatterReduceFloatSumModuleIncludeSelf",
|
||||||
|
"ScatterReduceIntMaxModuleIncludeSelf",
|
||||||
|
"ScatterReduceIntMinModuleIncludeSelf",
|
||||||
|
"ScatterReduceIntProdModuleIncludeSelf",
|
||||||
|
"ScatterReduceIntSumModuleIncludeSelf",
|
||||||
|
"TileBigDimsSizeModule_basic",
|
||||||
|
"TileSmallDimsSizeModule_basic",
|
||||||
|
"UpSampleNearest2dDynamicSize_basic",
|
||||||
|
"UpSampleNearest2dStaticSize_basic",
|
||||||
|
"ZeroFloat32Module_basic",
|
||||||
|
"ZeroInt32Module_basic",
|
||||||
|
"ZeroInt64Module_basic",
|
||||||
|
"ZerosLikeModule_defaultDtype",
|
||||||
|
"ZerosLikeModule_falsePinMemory",
|
||||||
|
"ZerosLikeModule_float",
|
||||||
|
"ZerosLikeModule_int",
|
||||||
|
|
||||||
|
# Failure - onnx_lowering
|
||||||
|
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
|
||||||
|
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
|
||||||
|
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
|
||||||
|
"AtenMmFloatTypes_basic",
|
||||||
|
"AtenMmIntTypes_basic",
|
||||||
|
"AtenTrilModule_basic",
|
||||||
|
"AtenTrilWithNegDiagonalModule_basic",
|
||||||
|
"AtenTrilWithPosDiagonalModule_basic",
|
||||||
|
"AtenTriuModule_basic",
|
||||||
|
"AtenTriuWithNegDiagonalModule_basic",
|
||||||
|
"AtenTriuWithPosDiagonalModule_basic",
|
||||||
|
"AvgPool1dFloatModule_basic",
|
||||||
|
"AvgPool1dIntModule_basic",
|
||||||
|
"AvgPool1dStaticModule_basic",
|
||||||
|
"AvgPool2dCeilModeTrueModule_basic",
|
||||||
|
"AvgPool2dDivisorOverrideModule_basic",
|
||||||
|
"AvgPool2dFloatModule_basic",
|
||||||
|
"AvgPool2dIntModule_basic",
|
||||||
|
"AvgPool2dStaticModule_basic",
|
||||||
|
"BernoulliFloatModule_basic",
|
||||||
|
"BernoulliModule_basic",
|
||||||
|
"BernoulliPModule_basic",
|
||||||
|
"BernoulliTensorModule_basic",
|
||||||
|
"ConstantPad2dStaticModule_basic",
|
||||||
|
"ConstantPadNdModule_basic",
|
||||||
|
"ConstantPadNdPartialStaticModule_basic",
|
||||||
|
"ConstantPadNdStaticModule_basic",
|
||||||
|
"CrossEntropyLossModule_basic",
|
||||||
|
"CrossEntropyLossNoReductionModule_basic",
|
||||||
|
"DropoutTrainModule_basic",
|
||||||
|
"DropoutTrainStaticShapeModule_basic",
|
||||||
|
"EinsumStaticContractRhsModule_basic",
|
||||||
|
"EinsumStaticFourDimensionModule_basic",
|
||||||
|
"EinsumStaticModule_basic",
|
||||||
|
"ElementwiseMishModule_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||||
|
"ElementwiseRemainderScalarModule_Int_basic",
|
||||||
|
"ElementwiseToDtypeI64ToI8Module_basic",
|
||||||
|
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||||
|
"GroupNormModule_basic",
|
||||||
|
"GroupNormNoWeightAndBiasModule_basic",
|
||||||
|
"HardswishModule_basic",
|
||||||
|
"HardswishRandomModule_basic",
|
||||||
|
"IndexPut1DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPut1DIntNonAccumulateModule_basic",
|
||||||
|
"IndexPut2DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPut2DIntNonAccumulateModule_basic",
|
||||||
|
"IndexPut3DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPut3DIntNonAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin1DIntNonAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin2DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin2DIntNonAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin3DFloatNonAccumulateModule_basic",
|
||||||
|
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
|
||||||
|
"LogSoftmaxIntModule_basic",
|
||||||
|
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||||
|
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
|
||||||
|
"MaxPool2dWithIndicesStaticModule_basic",
|
||||||
|
"MmDagModule_basic",
|
||||||
|
"MmModule_basic",
|
||||||
|
"MmModule_chained",
|
||||||
|
"MmTanhModule_basic",
|
||||||
|
"MobilenetV3Module_basic",
|
||||||
|
"MseLossSumReductionWithDifferentElemTypeModule_basic",
|
||||||
|
"NativeDropoutTrainModule_basic",
|
||||||
|
"NativeDropoutTrainStaticShapeModule_basic",
|
||||||
|
"OneHotModule_basic",
|
||||||
|
"PadModule_basic",
|
||||||
|
"RandIntLowDtypeModule_basic",
|
||||||
|
"RandIntLowModule_basic",
|
||||||
|
"RandLikeDtypeModule_basic",
|
||||||
|
"RandLikeModule_basic",
|
||||||
|
"RandnDtypeDeviceModule_basic",
|
||||||
|
"RandnGeneratorF64Module_basic",
|
||||||
|
"RandnGeneratorModule_basic",
|
||||||
|
"RandnLikeDtypeModule_basic",
|
||||||
|
"RandnLikeModule_basic",
|
||||||
|
"RandnModule_basic",
|
||||||
|
"ReduceL1NormModule_basic",
|
||||||
|
"ReduceL1NormWithDTypeModule_basic",
|
||||||
|
"ReduceL2NormModule_basic",
|
||||||
|
"ReduceL3NormAllDimsModule_basic",
|
||||||
|
"ReduceL3NormKeepDimModule_basic",
|
||||||
|
"ReduceProdDimIntFloatModule_basic",
|
||||||
|
"ReduceSumDtypeFloatModule_basic",
|
||||||
|
"ReduceSumDtypeIntModule_basic",
|
||||||
|
"ReduceSumElementTypeBoolModule_basic",
|
||||||
|
"ReduceSumFloatModule_basic",
|
||||||
|
"ReduceSumSignedIntModule_basic",
|
||||||
|
"ReduceSumUnsignedIntModule_basic",
|
||||||
|
"ReflectionPad1dModule2dInput_Right",
|
||||||
|
"ReflectionPad1dModule2dInput_basic",
|
||||||
|
"ReflectionPad1dModule3dInput_Left",
|
||||||
|
"ReflectionPad1dModule3dInput_basic",
|
||||||
|
"ReflectionPad2dModule_Bottom",
|
||||||
|
"ReflectionPad2dModule_Left",
|
||||||
|
"ReflectionPad2dModule_Right",
|
||||||
|
"ReflectionPad2dModule_Top",
|
||||||
|
"ReflectionPad2dModule_basic",
|
||||||
|
"ReplicationPad2dModule_basic",
|
||||||
|
"ReplicationPad2dModule_bottom0",
|
||||||
|
"ReplicationPad2dModule_left0",
|
||||||
|
"ReplicationPad2dModule_right0",
|
||||||
|
"ReplicationPad2dModule_top0",
|
||||||
|
"ScatterSrcModule_basic",
|
||||||
|
"ScatterSrcStaticModule_basic",
|
||||||
|
"ScatterValueFloatModule_basic",
|
||||||
|
"ScatterValueIntModule_basic",
|
||||||
|
"SoftplusModule_basic",
|
||||||
|
"SortTensorDescending_basic",
|
||||||
|
"SortTensorInteger_basic",
|
||||||
|
"SortTensorNegativeDimension_basic",
|
||||||
|
"SortTensorSpecificDimension_basic",
|
||||||
|
"SortTensor_basic",
|
||||||
|
"SqueezeModule_allUnitDim",
|
||||||
|
"SqueezeModule_broadcast",
|
||||||
|
"SqueezeModule_static",
|
||||||
|
"StdCorrectionAllDimReduceModule_basic",
|
||||||
|
"StdCorrectionKeepDimModule_basic",
|
||||||
|
"StdCorrectionLargeInputModule_basic",
|
||||||
|
"StdCorrectionModule_basic",
|
||||||
|
"StdCorrectionNoneModule_basic",
|
||||||
|
"StdCorrectionSingleDimReduceModule_basic",
|
||||||
|
"StdDimKeepDimFalseModule_basic",
|
||||||
|
"StdDimKeepDimTrueModule_basic",
|
||||||
|
"StdDimNoneDimModule_basic",
|
||||||
|
"StdUnbiasedModule_basic",
|
||||||
|
"TriuBroadcastModule_basic",
|
||||||
|
"TriuModule_basic",
|
||||||
|
"TypeConversionI1ToI32Module_basic",
|
||||||
|
"TypeConversionI64ToI32Module_basic",
|
||||||
|
"UnflattenIntNegativeOneDimStaticModule_basic",
|
||||||
|
"UnflattenIntNegativeOneSizeStaticModule_basic",
|
||||||
|
"UnflattenIntStaticModule_basic",
|
||||||
|
"UnflattenStaticModule_basic",
|
||||||
|
"VarCorrectionAllDimReduceModule_basic",
|
||||||
|
"VarCorrectionKeepDimModule_basic",
|
||||||
|
"VarCorrectionLargeInputModule_basic",
|
||||||
|
"VarCorrectionModule_basic",
|
||||||
|
"VarCorrectionNoneModule_basic",
|
||||||
|
"VarCorrectionSingleDimReduceModule_basic",
|
||||||
|
"VarDimAllDimReduceModule_basic",
|
||||||
|
"VarDimModule_basic",
|
||||||
|
"VarDimMultiDimModule_basic",
|
||||||
|
"VarDimNegativeModule_basic",
|
||||||
|
"VarDimNoneDimModule_basic",
|
||||||
|
"VarDimSingleDimModule_basic",
|
||||||
|
"VarDimUnbiasedModule_basic",
|
||||||
|
"VarMeanCorrectionModule_basic",
|
||||||
|
"VarMeanCorrectionNoneModule_basic",
|
||||||
|
"VarMeanDimModule_basic",
|
||||||
|
"VarMeanUnbiasedModule_basic",
|
||||||
|
"VarUnbiasedModule_basic",
|
||||||
|
"_LogSoftmaxModuleStable_basic",
|
||||||
|
"_LogSoftmaxModule_basic",
|
||||||
|
|
||||||
|
# Failure - cast_error
|
||||||
|
"MeanDimNoneDimModule_basic",
|
||||||
|
"MeanDtypeModule_basic",
|
||||||
|
"MeanDynamicSizesModule_basic",
|
||||||
|
"MeanModule_basic",
|
||||||
|
"MseLossMeanReductionModule_basic",
|
||||||
|
"StdBiasedModule_basic",
|
||||||
|
"VarBiasedModule_basic",
|
||||||
|
"VarMeanBiasedModule_basic",
|
||||||
|
|
||||||
|
# Failure - constant_int
|
||||||
|
"ReduceMinAlongDimNegative_basic",
|
||||||
|
"ReduceMinAlongDimSignedInt_basic",
|
||||||
|
"ReduceMinAlongDim_basic",
|
||||||
|
"ReduceMinFloatModule_basic",
|
||||||
|
"ReduceMinKeepDimReturnBoth_basic",
|
||||||
|
"ReduceMinSignedIntModule_basic",
|
||||||
|
"ReduceMinUnsignedIntModule_basic",
|
||||||
|
"SplitTensorGetItem_Module_basic",
|
||||||
|
"SplitTensorLastSmallerModule_basic",
|
||||||
|
"SplitTensorListUnpackModule_basic",
|
||||||
|
"SplitTensorNegativeDimModule_basic",
|
||||||
|
"SplitWithSizesListUnpackModule_basic",
|
||||||
|
"UnbindIntGetItem_Module_basic",
|
||||||
|
"UnbindIntListUnpack_Module_basic",
|
||||||
|
|
||||||
|
# Failure - operand_type
|
||||||
|
"ElementwiseAcosIntModule_basic",
|
||||||
|
"ElementwiseAsinIntModule_basic",
|
||||||
|
"ElementwiseAtanTensorIntModule_basic",
|
||||||
|
"ElementwiseCosIntModule_basic",
|
||||||
|
"ElementwiseErfIntModule_basic",
|
||||||
|
"ElementwiseExpIntModule_basic",
|
||||||
|
"ElementwiseLog10IntModule_basic",
|
||||||
|
"ElementwiseLog2IntModule_basic",
|
||||||
|
"ElementwiseLogIntModule_basic",
|
||||||
|
"ElementwiseSinIntModule_basic",
|
||||||
|
"ElementwiseTanIntModule_basic",
|
||||||
|
"ElementwiseUnaryIntModule_basic",
|
||||||
|
|
||||||
|
# Failure - expand_multidim
|
||||||
|
"IndexTensorHackedTwinModule3dInput_basic",
|
||||||
|
"IndexTensorHackedTwinModule_basic",
|
||||||
|
"IndexTensorModule3dInput_basic",
|
||||||
|
"IndexTensorModule_basic",
|
||||||
|
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
|
||||||
|
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
|
||||||
|
|
||||||
|
# Failure - rankless_return
|
||||||
|
"ReduceAmaxMultiDim_basic",
|
||||||
|
"ReduceAmaxOutOfOrderDim_basic",
|
||||||
|
"ReduceAmaxSingleDim_basic",
|
||||||
|
"ReduceMaxAllDims_basic",
|
||||||
|
"ReduceMaxAlongDimNegative_basic",
|
||||||
|
"ReduceMaxAlongDimSignedInt_basic",
|
||||||
|
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||||
|
"ReduceMaxAlongDim_basic",
|
||||||
|
"ReduceMaxFloatModule_basic",
|
||||||
|
"ReduceMaxSignedIntModule_basic",
|
||||||
|
"ReduceMaxUnsignedIntModule_basic",
|
||||||
|
|
||||||
|
# Failure - slice_lowering
|
||||||
|
"ScaledDotProductAttentionDifferentModule_basic",
|
||||||
|
"ScaledDotProductAttentionSameModule_basic",
|
||||||
|
|
||||||
|
# Failure - view_lowering
|
||||||
|
"AddSizeIntModule_basic",
|
||||||
|
"ElementwiseFlattenBroadcastModule_basic",
|
||||||
|
"FlattenRank0Module_basic",
|
||||||
|
"IndexTensorDyanmicInputContiguousWithNoneModule_basic",
|
||||||
|
"IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
|
||||||
|
"IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic",
|
||||||
|
"IndexTensorMultiInputContiguousCenter_basic",
|
||||||
|
"IndexTensorMultiInputNonContiguousDynamic_basic",
|
||||||
|
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
|
||||||
|
"IndexTensorMultiInputNonContiguous_basic",
|
||||||
|
"IndexTensorMultiInputOneDim_basic",
|
||||||
|
"IndexTensorMultiInputThreeIndexers_basic",
|
||||||
|
"IndexTensorMultiInput_basic",
|
||||||
|
"IndexTensorSelectDimModule_basic",
|
||||||
|
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||||
|
"RepeatModule_basic",
|
||||||
|
"SelectIntModule_basic",
|
||||||
|
"SelectIntNegativeDimAndIndexStaticModule_basic",
|
||||||
|
"SliceSingleIdxModule_basic",
|
||||||
|
"ViewFlattenAndExpandModule_basic",
|
||||||
|
"ViewSizeDimFollowedByCollapsedOnesModule_basic",
|
||||||
|
"ViewSizeDimFollowedByExpandedOnesModule_basic",
|
||||||
|
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
|
||||||
|
"ViewSizeDimLedAndFollowedByExpandedOnesModule_basic",
|
||||||
|
"ViewSizeDimLedByCollapsedOnesModule_basic",
|
||||||
|
"ViewSizeDimLedByExpandedOnesModule_basic",
|
||||||
|
|
||||||
|
# Failure - numerical
|
||||||
|
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
||||||
|
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
|
||||||
|
"ElementwiseSeluModule_basic",
|
||||||
|
"EmbeddingModule1DIndices_basic",
|
||||||
|
"EmbeddingModuleI32Static_basic",
|
||||||
|
"FlipNegativeIndexModule_basic",
|
||||||
|
"HardsigmoidModule_basic",
|
||||||
|
"HardsigmoidRandomModule_basic",
|
||||||
|
"IndexSelectDynamicIndexSizeModule_basic",
|
||||||
|
"IndexSelectDynamicInputSizeModule_basic",
|
||||||
|
"IndexSelectDynamicModulebasic",
|
||||||
|
"IndexSelectNegativeDimModule_basic",
|
||||||
|
"IndexSelectSingleIdxModule_basic",
|
||||||
|
"IndexSelectTwoIdxModule_basic",
|
||||||
|
"IndexSelectWholeDimensionModule_basic",
|
||||||
|
"IndexTensorStaticModule_basic",
|
||||||
|
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||||
|
"PixelShuffleModuleStaticRank4Float32_basic",
|
||||||
|
"ResNet18Module_basic",
|
||||||
|
"SliceCopyEndGreaterThanDimSize_Module_basic",
|
||||||
|
"SliceCopyNegative_Module_basic",
|
||||||
|
"SliceCopyNonZeroDim_Module_basic",
|
||||||
|
"SliceCopy_Module_basic",
|
||||||
|
"TupleModule_basic",
|
||||||
|
|
||||||
|
# Failure - shape
|
||||||
|
"ArangeStartOutDtypeModule_basic",
|
||||||
|
"ArangeStartOutViewModule_basic",
|
||||||
|
"BroadcastDynamicDimModule_basic",
|
||||||
|
"BroadcastToModule_basic",
|
||||||
|
"EmbeddingModuleF16_basic",
|
||||||
|
"EmbeddingModuleI32_basic",
|
||||||
|
"EmbeddingModuleI64_basic",
|
||||||
|
"ExpandModule_basic",
|
||||||
|
"ReduceAmaxKeepDim_basic",
|
||||||
|
"ReduceMaxKeepDimReturnBoth_basic",
|
||||||
|
"ReduceMaxNegativeDim_basic",
|
||||||
|
"ViewSizeFromOtherTensor_basic",
|
||||||
|
|
||||||
|
# Failure - unknown
|
||||||
|
"ChunkListUnpackUneven_Module_basic",
|
||||||
|
"ChunkListUnpack_Module_basic",
|
||||||
|
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||||
|
"CopyWithDifferentDTypesAndSizesModule_basic",
|
||||||
|
"CopyWithDifferentDTypesModule_basic",
|
||||||
|
"CosineSimilarityStaticBroadcastModule_basic",
|
||||||
|
"CumsumInputDtypeInt32Module_basic",
|
||||||
|
"ElementwiseAtan2TensorIntModule_basic",
|
||||||
|
"ElementwiseDivRoundingModeTruncModule_basic",
|
||||||
|
"ElementwisePreluModule_basic",
|
||||||
|
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||||
|
"ElementwiseWhereScalarModule_basic",
|
||||||
|
"FlattenDynamicModule_basic",
|
||||||
|
"FlipModuleStaticShape_basic",
|
||||||
|
"GluStaticModule_basic",
|
||||||
|
"MaskedFillTensorFloatValueModule_basic",
|
||||||
|
"ReduceAllDimEmpty_basic",
|
||||||
|
"ReduceAllDimFloat_basic",
|
||||||
|
"ReduceAllDimInt_basic",
|
||||||
|
"ReduceMinAlongDimUnsignedInt_basic",
|
||||||
|
"TensorsStackNegativeDimModule_basic",
|
||||||
|
"TensorsStackPromoteDTypeModule_basic",
|
||||||
|
}
|
||||||
|
|
||||||
|
ONNX_CRASHING_SET = {
|
||||||
|
"ElementwiseSigmoidIntModule_basic",
|
||||||
|
"FlipModule_basic",
|
||||||
|
"IndexTensorNegativeIndexModule_basic",
|
||||||
|
"MoveDimIntNegativeIndexModule_basic",
|
||||||
|
"PermuteNegativeIndexModule_basic",
|
||||||
|
"RollModule_basic",
|
||||||
|
"SliceModule_basic",
|
||||||
|
"SliceNegIdxModule_basic",
|
||||||
|
"SliceOutOfLowerBoundEndIndexModule_basic",
|
||||||
|
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||||
|
"SliceSizeTwoStepModule_basic",
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
from .lazy_tensor_core import LazyTensorCoreTestConfig
|
from .lazy_tensor_core import LazyTensorCoreTestConfig
|
||||||
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
|
||||||
from .native_torch import NativeTorchTestConfig
|
from .native_torch import NativeTorchTestConfig
|
||||||
|
from .onnx_backend import OnnxBackendTestConfig
|
||||||
from .torchscript import TorchScriptTestConfig
|
from .torchscript import TorchScriptTestConfig
|
||||||
from .stablehlo_backend import StablehloBackendTestConfig
|
from .stablehlo_backend import StablehloBackendTestConfig
|
||||||
from .tosa_backend import TosaBackendTestConfig
|
from .tosa_backend import TosaBackendTestConfig
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import io
|
||||||
|
import onnx
|
||||||
|
import torch
|
||||||
|
import torch_mlir
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.onnx_backends.abc import OnnxBackend
|
||||||
|
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
|
||||||
|
from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
|
||||||
|
from .utils import (
|
||||||
|
recursively_convert_to_numpy,
|
||||||
|
recursively_convert_from_numpy,
|
||||||
|
)
|
||||||
|
|
||||||
|
from torch_mlir.extras import onnx_importer
|
||||||
|
from torch_mlir.dialects import torch as torch_d
|
||||||
|
from torch_mlir.ir import Context, Module
|
||||||
|
|
||||||
|
|
||||||
|
def import_onnx(contents):
|
||||||
|
# Import the ONNX model proto from the file contents:
|
||||||
|
raw_model = onnx.load_from_string(contents)
|
||||||
|
model_proto = onnx.shape_inference.infer_shapes(raw_model)
|
||||||
|
|
||||||
|
# Import the ONNX module into an MLIR module:
|
||||||
|
context = Context()
|
||||||
|
torch_d.register_dialect(context)
|
||||||
|
model_info = onnx_importer.ModelInfo(model_proto)
|
||||||
|
m = model_info.create_module(context=context)
|
||||||
|
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m.operation)
|
||||||
|
imp.import_all()
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def convert_onnx(model, inputs):
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
|
||||||
|
# Process the type information so we export with the dynamic shape information
|
||||||
|
examples = []
|
||||||
|
input_names = []
|
||||||
|
dynamic_tensors = {}
|
||||||
|
for (index, arg) in enumerate(inputs):
|
||||||
|
shape = map(lambda d : d if d >= 0 else 1, arg.shape)
|
||||||
|
shape = tuple(shape)
|
||||||
|
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
|
||||||
|
|
||||||
|
input_name = "input_{}".format(index)
|
||||||
|
input_names.append(input_name)
|
||||||
|
|
||||||
|
dynamic_dims = {}
|
||||||
|
for (dimindex, dim) in enumerate(arg.shape):
|
||||||
|
if (dim < 0):
|
||||||
|
dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex)
|
||||||
|
|
||||||
|
if (dynamic_dims):
|
||||||
|
dynamic_tensors[input_name] = dynamic_dims
|
||||||
|
|
||||||
|
|
||||||
|
examples=tuple(examples)
|
||||||
|
torch.onnx.export(model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors)
|
||||||
|
buffer = buffer.getvalue()
|
||||||
|
return import_onnx(buffer)
|
||||||
|
|
||||||
|
class OnnxBackendTestConfig(TestConfig):
|
||||||
|
"""Base class for TestConfig's that are implemented with ONNX.
|
||||||
|
|
||||||
|
This class handles all the common lowering that torch-mlir does before
|
||||||
|
reaching the ONNX abstraction level.
|
||||||
|
"""
|
||||||
|
def __init__(self, backend: OnnxBackend, use_make_fx: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.backend = backend
|
||||||
|
self.use_make_fx = use_make_fx
|
||||||
|
|
||||||
|
def compile(self, program: torch.nn.Module) -> Any:
|
||||||
|
example_args = convert_annotations_to_placeholders(program.forward)
|
||||||
|
onnx_module = convert_onnx(program, example_args)
|
||||||
|
compiled_module = self.backend.compile(onnx_module)
|
||||||
|
return compiled_module
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run(self, artifact: Any, trace: Trace) -> Trace:
|
||||||
|
backend_module = self.backend.load(artifact)
|
||||||
|
result: Trace = []
|
||||||
|
for item in trace:
|
||||||
|
numpy_inputs = recursively_convert_to_numpy(item.inputs)
|
||||||
|
outputs = getattr(backend_module, "main_graph")(*numpy_inputs)
|
||||||
|
output = recursively_convert_from_numpy(outputs)
|
||||||
|
result.append(
|
||||||
|
TraceItem(symbol=item.symbol,
|
||||||
|
inputs=item.inputs,
|
||||||
|
output=output))
|
||||||
|
return result
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
import abc
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch_mlir.ir import Module
|
||||||
|
|
||||||
|
# A type shared between the result of `OnnxBackend.compile` and the
|
||||||
|
# input to `OnnxBackend.load`. Each backend will likely have a
|
||||||
|
# different definition of this type.
|
||||||
|
CompiledArtifact = TypeVar('CompiledArtifact')
|
||||||
|
|
||||||
|
# A wrapper around a backend-specific loaded program representation
|
||||||
|
# that uniformly translates the `x.method(...)` interface expected of
|
||||||
|
# Torch modules into appropriate lower-level operations.
|
||||||
|
Invoker = TypeVar('Invoker')
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxBackend(abc.ABC):
|
||||||
|
"""The interface to an ONNX backend.
|
||||||
|
|
||||||
|
Backends are recommended to raise meaningful exceptions in case of error,
|
||||||
|
ideally with easy reproduction instructions.
|
||||||
|
"""
|
||||||
|
@abc.abstractmethod
|
||||||
|
def compile(self, module: Module) -> CompiledArtifact:
|
||||||
|
"""Compile the provided MLIR module into a compiled artifact.
|
||||||
|
|
||||||
|
The module adheres to the ONNX backend contract
|
||||||
|
(see the VerifyOnnxBackendContract pass).
|
||||||
|
|
||||||
|
The compiled artifact can be any type, but must be correctly
|
||||||
|
interpreted by the `load` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load(self, artifact: CompiledArtifact) -> Invoker:
|
||||||
|
"""Load the compiled artifact into a uniformly invokable form.
|
||||||
|
|
||||||
|
The compiled artifact is the result of a previous call to `compile`.
|
||||||
|
|
||||||
|
See the description of `Invoker` for the requirements on the returned
|
||||||
|
type.
|
||||||
|
"""
|
|
@ -0,0 +1,65 @@
|
||||||
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
# See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
# Also available under a BSD-style license. See LICENSE.
|
||||||
|
|
||||||
|
|
||||||
|
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||||
|
from torch_mlir.ir import *
|
||||||
|
from torch_mlir.passmanager import *
|
||||||
|
from torch_mlir.torchscript import OutputType
|
||||||
|
from torch_mlir.torchscript import _lower_mlir_module
|
||||||
|
|
||||||
|
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
|
||||||
|
|
||||||
|
from .abc import OnnxBackend
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LinalgOnTensorsOnnxBackend",
|
||||||
|
]
|
||||||
|
|
||||||
|
# The pipeline of func.func passes that lower the ONNX backend contract to the
|
||||||
|
# Linalg-on-Tensors backend contract accepted by RefBackend.
|
||||||
|
ONNX_TO_TORCH_FUNC_PIPELINE = ",".join([
|
||||||
|
"convert-torch-onnx-to-torch",
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
|
class LinalgOnTensorsOnnxBackend(OnnxBackend):
|
||||||
|
"""Main entry-point for the linalg-on-tensors based ONNX backend.
|
||||||
|
|
||||||
|
This currently uses the linalg-on-tensors RefBackend for actual execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.refbackend = RefBackendLinalgOnTensorsBackend()
|
||||||
|
|
||||||
|
def compile(self, imported_module: Module):
|
||||||
|
"""Compiles an imported module that satisfied the ONNX backend contract.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
imported_module: The MLIR module consisting of ONNX operations wrapped by
|
||||||
|
torch.operator.
|
||||||
|
Returns:
|
||||||
|
An opaque, backend specific compiled artifact object that can be
|
||||||
|
passed to `load`.
|
||||||
|
"""
|
||||||
|
run_pipeline_with_repro_report(
|
||||||
|
imported_module,
|
||||||
|
f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
|
||||||
|
"Lowering Onnx backend contract to Linalg-on-Tensors backend contract")
|
||||||
|
|
||||||
|
run_pipeline_with_repro_report(
|
||||||
|
imported_module,
|
||||||
|
f"builtin.module(torch-lower-to-backend-contract)",
|
||||||
|
"Lowering TorchFX IR -> Torch Backend IR",
|
||||||
|
)
|
||||||
|
|
||||||
|
imported_module = _lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
|
||||||
|
compiled_module = self.refbackend.compile(imported_module)
|
||||||
|
return compiled_module
|
||||||
|
|
||||||
|
def load(self, module):
|
||||||
|
"""Loads a compiled artifact into the runtime."""
|
||||||
|
return self.refbackend.load(module)
|
|
@ -102,7 +102,7 @@ class ModelInfo:
|
||||||
def create_module(self, context: Optional[Context] = None) -> Operation:
|
def create_module(self, context: Optional[Context] = None) -> Operation:
|
||||||
if not context:
|
if not context:
|
||||||
context = Context()
|
context = Context()
|
||||||
module_op = Module.create(Location.unknown(context)).operation
|
module_op = Module.create(Location.unknown(context))
|
||||||
# TODO: Populate module level metadata from the ModelProto
|
# TODO: Populate module level metadata from the ModelProto
|
||||||
return module_op
|
return module_op
|
||||||
|
|
||||||
|
@ -334,7 +334,8 @@ class NodeImporter:
|
||||||
f"This likely means that this is a special node which requires specific "
|
f"This likely means that this is a special node which requires specific "
|
||||||
f"handling in the importer: {onnx_attr}"
|
f"handling in the importer: {onnx_attr}"
|
||||||
)
|
)
|
||||||
attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc)
|
result = handler(onnx_attr, self._cc)
|
||||||
|
attrs[f"torch.onnx.{onnx_attr.name}"] = result
|
||||||
|
|
||||||
def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value:
|
def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value:
|
||||||
# If an explicitly specified name is given, use that; otherwise, pick
|
# If an explicitly specified name is given, use that; otherwise, pick
|
||||||
|
@ -502,9 +503,10 @@ class ContextCache:
|
||||||
if tp.HasField("raw_data"):
|
if tp.HasField("raw_data"):
|
||||||
# Conveniently, DenseResourceElementsAttr shares the raw data
|
# Conveniently, DenseResourceElementsAttr shares the raw data
|
||||||
# format. We just give it maximum numeric alignment.
|
# format. We just give it maximum numeric alignment.
|
||||||
return DenseResourceElementsAttr.get_from_buffer(
|
resource = DenseResourceElementsAttr.get_from_buffer(
|
||||||
tp.raw_data, self._sanitize_name(tp.name), tensor_type, alignment=8
|
tp.raw_data, self._sanitize_name(tp.name), tensor_type, alignment=8
|
||||||
)
|
)
|
||||||
|
return resource
|
||||||
else:
|
else:
|
||||||
# We have to do a data type specific instantiation from proto fields.
|
# We have to do a data type specific instantiation from proto fields.
|
||||||
# Since this is typically used for small tensor constants, we instantiate
|
# Since this is typically used for small tensor constants, we instantiate
|
||||||
|
|
|
@ -34,7 +34,7 @@ def main(args: argparse.Namespace):
|
||||||
context = Context()
|
context = Context()
|
||||||
torch_d.register_dialect(context)
|
torch_d.register_dialect(context)
|
||||||
model_info = onnx_importer.ModelInfo(model_proto)
|
model_info = onnx_importer.ModelInfo(model_proto)
|
||||||
m = model_info.create_module(context=context)
|
m = model_info.create_module(context=context).operation
|
||||||
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
|
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
|
||||||
imp.import_all()
|
imp.import_all()
|
||||||
if not args.no_verify:
|
if not args.no_verify:
|
||||||
|
|
|
@ -326,7 +326,7 @@ class ImportSmokeTest(unittest.TestCase):
|
||||||
model_info = onnx_importer.ModelInfo(
|
model_info = onnx_importer.ModelInfo(
|
||||||
self.load_onnx_model(ONNX_TEST_DATA_DIR / rel_path),
|
self.load_onnx_model(ONNX_TEST_DATA_DIR / rel_path),
|
||||||
)
|
)
|
||||||
m = model_info.create_module(context=context)
|
m = model_info.create_module(context=context).operation
|
||||||
try:
|
try:
|
||||||
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
|
imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m)
|
||||||
imp.import_all()
|
imp.import_all()
|
||||||
|
|
Loading…
Reference in New Issue