[FxImporter] Fix primitive type in return (#3379)

pull/3382/head
penguin_wwy 2024-05-23 09:55:33 +08:00 committed by GitHub
parent 2e194e13d6
commit d924d0047f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 1 deletions

View File

@ -450,7 +450,6 @@ FX_IMPORTER_XFAIL_SET = {
"TensorToBool_basic", "TensorToBool_basic",
"TensorToFloatZeroRank_basic", "TensorToFloatZeroRank_basic",
"TensorToFloat_basic", "TensorToFloat_basic",
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
"ThresholdBackward2dMixedModule_basic", "ThresholdBackward2dMixedModule_basic",
"TorchPrimLoopForLikeModule_basic", "TorchPrimLoopForLikeModule_basic",
"TorchPrimLoopWhileLikeModule_basic", "TorchPrimLoopWhileLikeModule_basic",

View File

@ -848,6 +848,13 @@ class FxImporter:
result_types.append( result_types.append(
self._cc.tensor_to_vtensor_type(result_node) self._cc.tensor_to_vtensor_type(result_node)
) )
elif type(result_node) in SCALAR_TYPE_TO_TORCH_MLIR_TYPE:
result_types.append(
IrType.parse(
SCALAR_TYPE_TO_TORCH_MLIR_TYPE[type(result_node)],
self._c,
)
)
else: else:
result_types.append(self._cc.node_val_to_type(result_node)) result_types.append(self._cc.node_val_to_type(result_node))
return ( return (