# torch-mlir/python/torch_mlir_e2e_test/configs/utils.py
#
# 47 lines
# 1.6 KiB
# Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from typing import Any
import numpy as np
import torch
def recursively_convert_to_numpy(o: Any):
    """Recursively replace every ``torch.Tensor`` inside ``o`` with a numpy array.

    Tuples, lists and dicts are rebuilt with the same container type. Their
    elements are converted recursively; for dicts only the values are
    converted and the keys are kept as-is. ``str``, ``float`` and ``int``
    pass through unchanged. ``bool`` is an ``int`` subclass, so it passes
    through as well.

    Args:
        o: The value to convert.

    Returns:
        A structure mirroring ``o`` in which each tensor has been replaced by
        a ``numpy.ndarray``.

    Raises:
        TypeError: If ``o``, or any nested element, has a type not listed
            above. ``TypeError`` subclasses ``Exception``, so callers that
            catch ``Exception`` keep working.
    """
    if isinstance(o, torch.Tensor):
        # Tensor.numpy() refuses tensors that require grad or that are not on
        # the CPU. detach() shares storage, and cpu() returns the tensor
        # itself when it is already on the CPU. A plain CPU tensor therefore
        # yields an array that still shares memory with ``o``, exactly as
        # before.
        return o.detach().cpu().numpy()
    if isinstance(o, tuple):
        return tuple(recursively_convert_to_numpy(x) for x in o)
    if isinstance(o, list):
        return [recursively_convert_to_numpy(x) for x in o]
    if isinstance(o, dict):
        return {k: recursively_convert_to_numpy(v) for k, v in o.items()}
    # No-op cases. Explicitly enumerated to avoid things sneaking through.
    if isinstance(o, (str, float, int)):
        return o
    raise TypeError(f"Unexpected Python function input: {o}")
def recursively_convert_from_numpy(o: Any):
    """Recursively replace every ``numpy.ndarray`` inside ``o`` with a torch tensor.

    Tuples, lists and dicts are rebuilt with the same container type. Their
    elements are converted recursively; for dicts only the values are
    converted and the keys are kept as-is. ``str``, ``float`` and ``int``
    pass through unchanged. ``bool`` is an ``int`` subclass, so it passes
    through as well.

    Args:
        o: The value to convert.

    Returns:
        A structure mirroring ``o`` in which each array has been replaced by
        a ``torch.Tensor``.

    Raises:
        TypeError: If ``o``, or any nested element, has a type not listed
            above. ``TypeError`` subclasses ``Exception``, so callers that
            catch ``Exception`` keep working.
    """
    if isinstance(o, np.ndarray):
        # torch.from_numpy rejects arrays with any negative stride, such as
        # views produced by ``arr[::-1]`` or ``np.flip``. Copy those into a
        # fresh array first. Every other array is wrapped without a copy, so
        # the resulting tensor shares memory with ``o`` as before.
        if any(stride < 0 for stride in o.strides):
            o = o.copy()
        return torch.from_numpy(o)
    if isinstance(o, tuple):
        return tuple(recursively_convert_from_numpy(x) for x in o)
    if isinstance(o, list):
        return [recursively_convert_from_numpy(x) for x in o]
    if isinstance(o, dict):
        return {k: recursively_convert_from_numpy(v) for k, v in o.items()}
    # No-op cases. Explicitly enumerated to avoid things sneaking through.
    if isinstance(o, (str, float, int)):
        return o
    raise TypeError(f"Unexpected Python function output: {o}")