mirror of https://github.com/llvm/torch-mlir
[onnx] Update the importer to create a `none` for missing operands (#2931)
Some operands are optional so we require a placeholder for missing operands. We invent an `onnx.None` operation as our placeholder.pull/2936/head
parent
4446fa00d8
commit
13553d49c9
|
@ -2184,6 +2184,7 @@ ONNX_XFAIL_SET = {
|
|||
"ElementwiseUnsqueezeNegDimsModule_basic",
|
||||
"ElementwiseWhereScalarModule_basic",
|
||||
"FlattenDynamicModule_basic",
|
||||
"FlipModule_basic",
|
||||
"FlipModuleStaticShape_basic",
|
||||
"GluStaticModule_basic",
|
||||
"MaskedFillTensorFloatValueModule_basic",
|
||||
|
@ -2193,17 +2194,9 @@ ONNX_XFAIL_SET = {
|
|||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"TensorsStackNegativeDimModule_basic",
|
||||
"TensorsStackPromoteDTypeModule_basic",
|
||||
}
|
||||
|
||||
ONNX_CRASHING_SET = {
|
||||
"FlipModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"MoveDimIntNegativeIndexModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"RollModule_basic",
|
||||
"SliceModule_basic",
|
||||
"SliceNegIdxModule_basic",
|
||||
"SliceOutOfLowerBoundEndIndexModule_basic",
|
||||
"SliceOutOfLowerBoundStartIndexModule_basic",
|
||||
"SliceSizeTwoStepModule_basic",
|
||||
}
|
||||
|
||||
ONNX_CRASHING_SET = { }
|
||||
|
|
|
@ -258,6 +258,8 @@ class NodeImporter:
|
|||
# much unused crap.
|
||||
for init in self._gi.initializer_map.values():
|
||||
self.import_initializer(init)
|
||||
|
||||
self.get_none()
|
||||
for node in self._gi.graph_proto.node:
|
||||
self.import_node(node)
|
||||
|
||||
|
@ -272,6 +274,20 @@ class NodeImporter:
|
|||
with InsertionPoint(self._b), Location.unknown():
|
||||
func_dialect.ReturnOp(outputs)
|
||||
|
||||
def get_none(self):
|
||||
if '' in self._nv_map:
|
||||
return self._nv_map['']
|
||||
|
||||
with InsertionPoint(self._b), Location.name("onnx_importer.none"):
|
||||
nne = Operation.create(
|
||||
name="torch.constant.none",
|
||||
results=[self._cc.get_none_type()],
|
||||
operands=[],
|
||||
attributes={},
|
||||
).results[0]
|
||||
self._nv_map[''] = nne
|
||||
return nne
|
||||
|
||||
def import_node(self, node: onnx.NodeProto):
|
||||
with InsertionPoint(self._b), Location.name(node.name):
|
||||
op_type = node.op_type
|
||||
|
@ -283,7 +299,6 @@ class NodeImporter:
|
|||
was_handled = getattr(self, special_key)(node)
|
||||
if was_handled:
|
||||
return
|
||||
|
||||
# General node import.
|
||||
input_values = []
|
||||
for input_name in node.input:
|
||||
|
@ -449,6 +464,9 @@ class ContextCache:
|
|||
self._elem_type_map[elem_type] = t
|
||||
return t
|
||||
|
||||
def get_none_type(self):
|
||||
return IrType.parse("!torch.none", context=self._c)
|
||||
|
||||
def get_vtensor_type(
|
||||
self, dims: tuple[Optional[int]], element_type: IrType
|
||||
) -> IrType:
|
||||
|
|
|
@ -102,22 +102,12 @@ TEST_CAST_XFAILS = [
|
|||
"node_test_castlike_FLOAT_to_STRING_model",
|
||||
"node_test_castlike_STRING_to_FLOAT_expanded_model",
|
||||
"node_test_castlike_STRING_to_FLOAT_model",
|
||||
"node_test_center_crop_pad_crop_axes_chw_expanded_model",
|
||||
"node_test_center_crop_pad_crop_axes_hwc_expanded_model",
|
||||
"node_test_center_crop_pad_crop_negative_axes_hwc_expanded_model",
|
||||
"node_test_clip_default_inbounds_model",
|
||||
"node_test_clip_default_int8_inbounds_model",
|
||||
"node_test_clip_default_int8_max_model",
|
||||
"node_test_clip_default_max_model",
|
||||
"node_test_constantofshape_float_ones_model",
|
||||
"node_test_constantofshape_int_shape_zero_model",
|
||||
"node_test_constantofshape_int_zeros_model",
|
||||
"node_test_dequantizelinear_e4m3fn_model",
|
||||
"node_test_dequantizelinear_e4m3fn_zero_point_model",
|
||||
"node_test_dequantizelinear_e5m2_model",
|
||||
"node_test_dft_axis_model",
|
||||
"node_test_dft_inverse_model",
|
||||
"node_test_dft_model",
|
||||
"node_test_equal_string_broadcast_model",
|
||||
"node_test_equal_string_model",
|
||||
"node_test_gru_defaults_model",
|
||||
|
@ -175,8 +165,6 @@ TEST_CAST_XFAILS = [
|
|||
"node_test_optional_get_element_optional_sequence_model",
|
||||
"node_test_optional_get_element_optional_tensor_model",
|
||||
"node_test_optional_get_element_sequence_model",
|
||||
"node_test_optional_has_element_empty_no_input_name_optional_input_model",
|
||||
"node_test_optional_has_element_empty_no_input_name_tensor_input_model",
|
||||
"node_test_optional_has_element_empty_optional_input_model",
|
||||
"node_test_optional_has_element_optional_input_model",
|
||||
"node_test_optional_has_element_tensor_input_model",
|
||||
|
@ -187,43 +175,6 @@ TEST_CAST_XFAILS = [
|
|||
"node_test_regex_full_match_basic_model",
|
||||
"node_test_regex_full_match_email_domain_model",
|
||||
"node_test_regex_full_match_empty_model",
|
||||
"node_test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_model",
|
||||
"node_test_resize_downsample_scales_cubic_align_corners_model",
|
||||
"node_test_resize_downsample_scales_cubic_antialias_model",
|
||||
"node_test_resize_downsample_scales_cubic_model",
|
||||
"node_test_resize_downsample_scales_linear_align_corners_model",
|
||||
"node_test_resize_downsample_scales_linear_antialias_model",
|
||||
"node_test_resize_downsample_scales_linear_half_pixel_symmetric_model",
|
||||
"node_test_resize_downsample_scales_linear_model",
|
||||
"node_test_resize_downsample_scales_nearest_model",
|
||||
"node_test_resize_downsample_sizes_cubic_antialias_model",
|
||||
"node_test_resize_downsample_sizes_cubic_model",
|
||||
"node_test_resize_downsample_sizes_linear_antialias_model",
|
||||
"node_test_resize_downsample_sizes_linear_pytorch_half_pixel_model",
|
||||
"node_test_resize_downsample_sizes_nearest_model",
|
||||
"node_test_resize_downsample_sizes_nearest_not_larger_model",
|
||||
"node_test_resize_downsample_sizes_nearest_not_smaller_model",
|
||||
"node_test_resize_tf_crop_and_resize_axes_2_3_model",
|
||||
"node_test_resize_tf_crop_and_resize_axes_3_2_model",
|
||||
"node_test_resize_tf_crop_and_resize_model",
|
||||
"node_test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_model",
|
||||
"node_test_resize_upsample_scales_cubic_align_corners_model",
|
||||
"node_test_resize_upsample_scales_cubic_asymmetric_model",
|
||||
"node_test_resize_upsample_scales_cubic_model",
|
||||
"node_test_resize_upsample_scales_linear_align_corners_model",
|
||||
"node_test_resize_upsample_scales_linear_half_pixel_symmetric_model",
|
||||
"node_test_resize_upsample_scales_linear_model",
|
||||
"node_test_resize_upsample_scales_nearest_axes_2_3_model",
|
||||
"node_test_resize_upsample_scales_nearest_axes_3_2_model",
|
||||
"node_test_resize_upsample_scales_nearest_model",
|
||||
"node_test_resize_upsample_sizes_cubic_model",
|
||||
"node_test_resize_upsample_sizes_nearest_axes_2_3_model",
|
||||
"node_test_resize_upsample_sizes_nearest_axes_3_2_model",
|
||||
"node_test_resize_upsample_sizes_nearest_ceil_half_pixel_model",
|
||||
"node_test_resize_upsample_sizes_nearest_floor_align_corners_model",
|
||||
"node_test_resize_upsample_sizes_nearest_model",
|
||||
"node_test_resize_upsample_sizes_nearest_not_larger_model",
|
||||
"node_test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_model",
|
||||
"node_test_rnn_seq_length_model",
|
||||
"node_test_scan9_sum_model",
|
||||
"node_test_scan_sum_model",
|
||||
|
@ -246,7 +197,6 @@ TEST_CAST_XFAILS = [
|
|||
"node_test_split_to_sequence_1_model",
|
||||
"node_test_split_to_sequence_2_model",
|
||||
"node_test_split_to_sequence_nokeepdims_model",
|
||||
"node_test_stft_model",
|
||||
"node_test_string_concat_broadcasting_model",
|
||||
"node_test_string_concat_empty_string_model",
|
||||
"node_test_string_concat_model",
|
||||
|
@ -281,6 +231,9 @@ TEST_CAST_XFAILS = [
|
|||
]
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class ImportSmokeTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
|
Loading…
Reference in New Issue