[onnx] Update the importer to create a `none` for missing operands (#2931)

Some operands are optional so we require a placeholder for missing
operands. We invent an `onnx.None` operation as our placeholder.
pull/2936/head
Rob Suderman 2024-02-20 09:30:30 -08:00 committed by GitHub
parent 4446fa00d8
commit 13553d49c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 61 deletions

View File

@ -2184,6 +2184,7 @@ ONNX_XFAIL_SET = {
"ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseWhereScalarModule_basic",
"FlattenDynamicModule_basic",
"FlipModule_basic",
"FlipModuleStaticShape_basic",
"GluStaticModule_basic",
"MaskedFillTensorFloatValueModule_basic",
@ -2193,17 +2194,9 @@ ONNX_XFAIL_SET = {
"ReduceMinAlongDimUnsignedInt_basic",
"TensorsStackNegativeDimModule_basic",
"TensorsStackPromoteDTypeModule_basic",
}
ONNX_CRASHING_SET = {
"FlipModule_basic",
"IndexTensorNegativeIndexModule_basic",
"MoveDimIntNegativeIndexModule_basic",
"PermuteNegativeIndexModule_basic",
"RollModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundEndIndexModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceSizeTwoStepModule_basic",
}
ONNX_CRASHING_SET = { }

View File

@ -258,6 +258,8 @@ class NodeImporter:
# much unused crap.
for init in self._gi.initializer_map.values():
self.import_initializer(init)
self.get_none()
for node in self._gi.graph_proto.node:
self.import_node(node)
@ -272,6 +274,20 @@ class NodeImporter:
with InsertionPoint(self._b), Location.unknown():
func_dialect.ReturnOp(outputs)
def get_none(self):
if '' in self._nv_map:
return self._nv_map['']
with InsertionPoint(self._b), Location.name("onnx_importer.none"):
nne = Operation.create(
name="torch.constant.none",
results=[self._cc.get_none_type()],
operands=[],
attributes={},
).results[0]
self._nv_map[''] = nne
return nne
def import_node(self, node: onnx.NodeProto):
with InsertionPoint(self._b), Location.name(node.name):
op_type = node.op_type
@ -283,7 +299,6 @@ class NodeImporter:
was_handled = getattr(self, special_key)(node)
if was_handled:
return
# General node import.
input_values = []
for input_name in node.input:
@ -449,6 +464,9 @@ class ContextCache:
self._elem_type_map[elem_type] = t
return t
def get_none_type(self):
return IrType.parse("!torch.none", context=self._c)
def get_vtensor_type(
self, dims: tuple[Optional[int]], element_type: IrType
) -> IrType:

View File

@ -102,22 +102,12 @@ TEST_CAST_XFAILS = [
"node_test_castlike_FLOAT_to_STRING_model",
"node_test_castlike_STRING_to_FLOAT_expanded_model",
"node_test_castlike_STRING_to_FLOAT_model",
"node_test_center_crop_pad_crop_axes_chw_expanded_model",
"node_test_center_crop_pad_crop_axes_hwc_expanded_model",
"node_test_center_crop_pad_crop_negative_axes_hwc_expanded_model",
"node_test_clip_default_inbounds_model",
"node_test_clip_default_int8_inbounds_model",
"node_test_clip_default_int8_max_model",
"node_test_clip_default_max_model",
"node_test_constantofshape_float_ones_model",
"node_test_constantofshape_int_shape_zero_model",
"node_test_constantofshape_int_zeros_model",
"node_test_dequantizelinear_e4m3fn_model",
"node_test_dequantizelinear_e4m3fn_zero_point_model",
"node_test_dequantizelinear_e5m2_model",
"node_test_dft_axis_model",
"node_test_dft_inverse_model",
"node_test_dft_model",
"node_test_equal_string_broadcast_model",
"node_test_equal_string_model",
"node_test_gru_defaults_model",
@ -175,8 +165,6 @@ TEST_CAST_XFAILS = [
"node_test_optional_get_element_optional_sequence_model",
"node_test_optional_get_element_optional_tensor_model",
"node_test_optional_get_element_sequence_model",
"node_test_optional_has_element_empty_no_input_name_optional_input_model",
"node_test_optional_has_element_empty_no_input_name_tensor_input_model",
"node_test_optional_has_element_empty_optional_input_model",
"node_test_optional_has_element_optional_input_model",
"node_test_optional_has_element_tensor_input_model",
@ -187,43 +175,6 @@ TEST_CAST_XFAILS = [
"node_test_regex_full_match_basic_model",
"node_test_regex_full_match_email_domain_model",
"node_test_regex_full_match_empty_model",
"node_test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_model",
"node_test_resize_downsample_scales_cubic_align_corners_model",
"node_test_resize_downsample_scales_cubic_antialias_model",
"node_test_resize_downsample_scales_cubic_model",
"node_test_resize_downsample_scales_linear_align_corners_model",
"node_test_resize_downsample_scales_linear_antialias_model",
"node_test_resize_downsample_scales_linear_half_pixel_symmetric_model",
"node_test_resize_downsample_scales_linear_model",
"node_test_resize_downsample_scales_nearest_model",
"node_test_resize_downsample_sizes_cubic_antialias_model",
"node_test_resize_downsample_sizes_cubic_model",
"node_test_resize_downsample_sizes_linear_antialias_model",
"node_test_resize_downsample_sizes_linear_pytorch_half_pixel_model",
"node_test_resize_downsample_sizes_nearest_model",
"node_test_resize_downsample_sizes_nearest_not_larger_model",
"node_test_resize_downsample_sizes_nearest_not_smaller_model",
"node_test_resize_tf_crop_and_resize_axes_2_3_model",
"node_test_resize_tf_crop_and_resize_axes_3_2_model",
"node_test_resize_tf_crop_and_resize_model",
"node_test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_model",
"node_test_resize_upsample_scales_cubic_align_corners_model",
"node_test_resize_upsample_scales_cubic_asymmetric_model",
"node_test_resize_upsample_scales_cubic_model",
"node_test_resize_upsample_scales_linear_align_corners_model",
"node_test_resize_upsample_scales_linear_half_pixel_symmetric_model",
"node_test_resize_upsample_scales_linear_model",
"node_test_resize_upsample_scales_nearest_axes_2_3_model",
"node_test_resize_upsample_scales_nearest_axes_3_2_model",
"node_test_resize_upsample_scales_nearest_model",
"node_test_resize_upsample_sizes_cubic_model",
"node_test_resize_upsample_sizes_nearest_axes_2_3_model",
"node_test_resize_upsample_sizes_nearest_axes_3_2_model",
"node_test_resize_upsample_sizes_nearest_ceil_half_pixel_model",
"node_test_resize_upsample_sizes_nearest_floor_align_corners_model",
"node_test_resize_upsample_sizes_nearest_model",
"node_test_resize_upsample_sizes_nearest_not_larger_model",
"node_test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_model",
"node_test_rnn_seq_length_model",
"node_test_scan9_sum_model",
"node_test_scan_sum_model",
@ -246,7 +197,6 @@ TEST_CAST_XFAILS = [
"node_test_split_to_sequence_1_model",
"node_test_split_to_sequence_2_model",
"node_test_split_to_sequence_nokeepdims_model",
"node_test_stft_model",
"node_test_string_concat_broadcasting_model",
"node_test_string_concat_empty_string_model",
"node_test_string_concat_model",
@ -281,6 +231,9 @@ TEST_CAST_XFAILS = [
]
class ImportSmokeTest(unittest.TestCase):
@classmethod
def setUpClass(cls):