mirror of https://github.com/llvm/torch-mlir
Bump Onnx Version to 1.16.1 (#3515)
This commit adds the support for new data types: uint4, and int4 and uint8 tensor protos. Also, it moves some tests from failing to crashing. Fixes https://github.com/llvm/torch-mlir/issues/3507 Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>pull/3517/head
parent
0e71a192d8
commit
2f231f394e
|
@ -2572,8 +2572,6 @@ ONNX_XFAIL_SET = {
|
||||||
"SplitDimStaticModule_basic",
|
"SplitDimStaticModule_basic",
|
||||||
"SqrtIntConstantModule_basic",
|
"SqrtIntConstantModule_basic",
|
||||||
"SqrtIntModule_basic",
|
"SqrtIntModule_basic",
|
||||||
"StdCorrectionEmptyDimModule_basic",
|
|
||||||
"StdDimEmptyDimModule_basic",
|
|
||||||
"SubFloatModule_basic",
|
"SubFloatModule_basic",
|
||||||
"SubIntModule_basic",
|
"SubIntModule_basic",
|
||||||
"TanhBackward_basic",
|
"TanhBackward_basic",
|
||||||
|
@ -2627,8 +2625,6 @@ ONNX_XFAIL_SET = {
|
||||||
"UpSampleNearest2dDynamicFactor_basic",
|
"UpSampleNearest2dDynamicFactor_basic",
|
||||||
"UpSampleNearest2dStaticFactor_basic",
|
"UpSampleNearest2dStaticFactor_basic",
|
||||||
"UpSampleNearest2d_basic",
|
"UpSampleNearest2d_basic",
|
||||||
"VarCorrectionEmptyDimModule_basic",
|
|
||||||
"VarDimEmptyDimModule_basic",
|
|
||||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||||
"ViewCollapseModule_basic",
|
"ViewCollapseModule_basic",
|
||||||
"ViewDynamicExpandCollapseModule_basic",
|
"ViewDynamicExpandCollapseModule_basic",
|
||||||
|
@ -2797,6 +2793,10 @@ ONNX_CRASHING_SET = {
|
||||||
# Runtime crash: mismatched size for broadcast
|
# Runtime crash: mismatched size for broadcast
|
||||||
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
|
||||||
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
|
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
|
||||||
|
"StdDimEmptyDimModule_basic",
|
||||||
|
"StdCorrectionEmptyDimModule_basic",
|
||||||
|
"VarCorrectionEmptyDimModule_basic",
|
||||||
|
"VarDimEmptyDimModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
FX_IMPORTER_TOSA_XFAIL_SET = {
|
FX_IMPORTER_TOSA_XFAIL_SET = {
|
||||||
|
|
|
@ -1098,6 +1098,8 @@ ELEM_TYPE_TO_IR_TYPE_CB = {
|
||||||
onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(),
|
onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(),
|
||||||
onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(),
|
onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(),
|
||||||
onnx.TensorProto.DataType.STRING: lambda: "!torch.str",
|
onnx.TensorProto.DataType.STRING: lambda: "!torch.str",
|
||||||
|
onnx.TensorProto.DataType.UINT4: lambda: IntegerType.get_unsigned(4),
|
||||||
|
onnx.TensorProto.DataType.INT4: lambda: IntegerType.get_signed(4),
|
||||||
# Ommitted: STRING,
|
# Ommitted: STRING,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1134,6 +1136,9 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = {
|
||||||
),
|
),
|
||||||
signless=False,
|
signless=False,
|
||||||
),
|
),
|
||||||
|
onnx.TensorProto.DataType.UINT8: lambda tp: DenseElementsAttr.get(
|
||||||
|
np.asarray(tp.int32_data, dtype=np.uint8).reshape(tp.dims), signless=False
|
||||||
|
),
|
||||||
onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get(
|
onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get(
|
||||||
np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False
|
np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False
|
||||||
),
|
),
|
||||||
|
|
|
@ -84,7 +84,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
|
||||||
raw_model = onnx.load(args.input_file)
|
raw_model = onnx.load(args.input_file)
|
||||||
else:
|
else:
|
||||||
raw_model = onnx.load(args.input_file, load_external_data=False)
|
raw_model = onnx.load(args.input_file, load_external_data=False)
|
||||||
onnx.load_external_data_for_model(raw_model, args.data_dir)
|
onnx.load_external_data_for_model(raw_model, str(args.data_dir))
|
||||||
|
|
||||||
if args.opset_version:
|
if args.opset_version:
|
||||||
raw_model = onnx.version_converter.convert_version(
|
raw_model = onnx.version_converter.convert_version(
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
pillow
|
pillow
|
||||||
dill
|
dill
|
||||||
multiprocess
|
multiprocess
|
||||||
onnx==1.15.0
|
onnx==1.16.1
|
||||||
mpmath==1.3.0
|
mpmath==1.3.0
|
||||||
|
|
Loading…
Reference in New Issue