diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 8272bc4b0..adfb68b94 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2572,8 +2572,6 @@ ONNX_XFAIL_SET = { "SplitDimStaticModule_basic", "SqrtIntConstantModule_basic", "SqrtIntModule_basic", - "StdCorrectionEmptyDimModule_basic", - "StdDimEmptyDimModule_basic", "SubFloatModule_basic", "SubIntModule_basic", "TanhBackward_basic", @@ -2627,8 +2625,6 @@ ONNX_XFAIL_SET = { "UpSampleNearest2dDynamicFactor_basic", "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2d_basic", - "VarCorrectionEmptyDimModule_basic", - "VarDimEmptyDimModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseModule_basic", "ViewDynamicExpandCollapseModule_basic", @@ -2797,6 +2793,10 @@ ONNX_CRASHING_SET = { # Runtime crash: mismatched size for broadcast "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "StdDimEmptyDimModule_basic", + "StdCorrectionEmptyDimModule_basic", + "VarCorrectionEmptyDimModule_basic", + "VarDimEmptyDimModule_basic", } FX_IMPORTER_TOSA_XFAIL_SET = { diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index f8b10a2a4..9fe292123 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -1098,6 +1098,8 @@ ELEM_TYPE_TO_IR_TYPE_CB = { onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(), onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(), onnx.TensorProto.DataType.STRING: lambda: "!torch.str", + onnx.TensorProto.DataType.UINT4: lambda: IntegerType.get_unsigned(4), + onnx.TensorProto.DataType.INT4: lambda: IntegerType.get_signed(4), # Ommitted: STRING, } @@ -1134,6 +1136,9 @@ ELEM_TYPE_INLINE_TENSOR_PROTO_CB = { ), signless=False, ), + onnx.TensorProto.DataType.UINT8: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.uint8).reshape(tp.dims), signless=False + ), onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get( np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False ), diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index d20c212d0..fa0e2a89d 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -84,7 +84,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: raw_model = onnx.load(args.input_file) else: raw_model = onnx.load(args.input_file, load_external_data=False) - onnx.load_external_data_for_model(raw_model, args.data_dir) + onnx.load_external_data_for_model(raw_model, str(args.data_dir)) if args.opset_version: raw_model = onnx.version_converter.convert_version( diff --git a/test-requirements.txt b/test-requirements.txt index b21e8dfcd..42278b3cb 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,5 +1,5 @@ pillow dill multiprocess -onnx==1.15.0 +onnx==1.16.1 mpmath==1.3.0