mirror of https://github.com/llvm/torch-mlir
47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def recursively_convert_to_numpy(o: Any):
    """Replace every ``torch.Tensor`` nested inside ``o`` with a NumPy array.

    Tuples, lists and dicts are rebuilt with their elements converted
    recursively (for dicts only the values are converted; keys are kept
    as-is). ``str``, ``float`` and ``int`` values are returned unchanged;
    ``bool`` is accepted too because it is a subclass of ``int``.

    Args:
        o: A tensor, a supported scalar, or a tuple/list/dict nesting of them.

    Returns:
        A value with the same nesting as ``o`` in which each tensor has been
        replaced by the result of calling ``.numpy()`` on it. Tuple inputs
        (including tuple subclasses) come back as plain ``tuple`` objects.

    Raises:
        Exception: If ``o`` or any nested element has an unsupported type.
    """
    if isinstance(o, torch.Tensor):
        return o.numpy()
    if isinstance(o, tuple):
        return tuple(recursively_convert_to_numpy(x) for x in o)
    if isinstance(o, list):
        return [recursively_convert_to_numpy(x) for x in o]
    if isinstance(o, dict):
        return {k: recursively_convert_to_numpy(v) for k, v in o.items()}
    # No-op cases. Explicitly enumerated to avoid things sneaking through.
    if isinstance(o, (str, float, int)):
        return o
    raise Exception(f"Unexpected Python function input: {o}")
|
|
|
|
def recursively_convert_from_numpy(o: Any):
    """Replace every ``np.ndarray`` nested inside ``o`` with a torch tensor.

    Tuples, lists and dicts are rebuilt with their elements converted
    recursively (for dicts only the values are converted; keys are kept
    as-is). ``str``, ``float`` and ``int`` values are returned unchanged;
    ``bool`` is accepted too because it is a subclass of ``int``.

    Args:
        o: An array, a supported scalar, or a tuple/list/dict nesting of them.

    Returns:
        A value with the same nesting as ``o`` in which each array has been
        replaced by the result of ``torch.from_numpy``. Tuple inputs
        (including tuple subclasses) come back as plain ``tuple`` objects.

    Raises:
        Exception: If ``o`` or any nested element has an unsupported type.
    """
    if isinstance(o, np.ndarray):
        return torch.from_numpy(o)
    if isinstance(o, tuple):
        return tuple(recursively_convert_from_numpy(x) for x in o)
    if isinstance(o, list):
        return [recursively_convert_from_numpy(x) for x in o]
    if isinstance(o, dict):
        return {k: recursively_convert_from_numpy(v) for k, v in o.items()}
    # No-op cases. Explicitly enumerated to avoid things sneaking through.
    if isinstance(o, (str, float, int)):
        return o
    raise Exception(f"Unexpected Python function output: {o}")
|