mirror of https://github.com/llvm/torch-mlir
[FxImporter] Fix primitive type in return (#3379)
parent
2e194e13d6
commit
d924d0047f
|
@ -450,7 +450,6 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"TensorToBool_basic",
|
||||
"TensorToFloatZeroRank_basic",
|
||||
"TensorToFloat_basic",
|
||||
"TestMultipleTensorAndPrimitiveTypesReturn_basic",
|
||||
"ThresholdBackward2dMixedModule_basic",
|
||||
"TorchPrimLoopForLikeModule_basic",
|
||||
"TorchPrimLoopWhileLikeModule_basic",
|
||||
|
|
|
@ -848,6 +848,13 @@ class FxImporter:
|
|||
result_types.append(
|
||||
self._cc.tensor_to_vtensor_type(result_node)
|
||||
)
|
||||
elif type(result_node) in SCALAR_TYPE_TO_TORCH_MLIR_TYPE:
|
||||
result_types.append(
|
||||
IrType.parse(
|
||||
SCALAR_TYPE_TO_TORCH_MLIR_TYPE[type(result_node)],
|
||||
self._c,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result_types.append(self._cc.node_val_to_type(result_node))
|
||||
return (
|
||||
|
|
Loading…
Reference in New Issue