mirror of https://github.com/llvm/torch-mlir
[FxImporter] Fix primitive type in return (#3379)
parent
2e194e13d6
commit
d924d0047f
|
@ -450,7 +450,6 @@ FX_IMPORTER_XFAIL_SET = {
|
||||||
"TensorToBool_basic",
|
"TensorToBool_basic",
|
||||||
"TensorToFloatZeroRank_basic",
|
"TensorToFloatZeroRank_basic",
|
||||||
"TensorToFloat_basic",
|
"TensorToFloat_basic",
|
||||||
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
|
||||||
"ThresholdBackward2dMixedModule_basic",
|
"ThresholdBackward2dMixedModule_basic",
|
||||||
"TorchPrimLoopForLikeModule_basic",
|
"TorchPrimLoopForLikeModule_basic",
|
||||||
"TorchPrimLoopWhileLikeModule_basic",
|
"TorchPrimLoopWhileLikeModule_basic",
|
||||||
|
|
|
@ -848,6 +848,13 @@ class FxImporter:
|
||||||
result_types.append(
|
result_types.append(
|
||||||
self._cc.tensor_to_vtensor_type(result_node)
|
self._cc.tensor_to_vtensor_type(result_node)
|
||||||
)
|
)
|
||||||
|
elif type(result_node) in SCALAR_TYPE_TO_TORCH_MLIR_TYPE:
|
||||||
|
result_types.append(
|
||||||
|
IrType.parse(
|
||||||
|
SCALAR_TYPE_TO_TORCH_MLIR_TYPE[type(result_node)],
|
||||||
|
self._c,
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result_types.append(self._cc.node_val_to_type(result_node))
|
result_types.append(self._cc.node_val_to_type(result_node))
|
||||||
return (
|
return (
|
||||||
|
|
Loading…
Reference in New Issue