mirror of https://github.com/llvm/torch-mlir
46 lines
1.1 KiB
Python
46 lines
1.1 KiB
Python
# -*- Python -*-
|
|
# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
import torch
|
|
import _torch_mlir
|
|
from _torch_mlir import _get_mlir
|
|
from _torch_mlir import _op_report
|
|
from _torch_mlir import _liveness_report
|
|
from _torch_mlir import set_debug
|
|
from _torch_mlir import lower_to_std
|
|
|
|
import json
|
|
|
|
# Import-time side effect: ask the `_torch_mlir` extension module to
# initialize its ATen bindings. NOTE(review): presumably this must run before
# the wrappers defined in this module are used -- confirm against the extension.
_torch_mlir._initialize_aten_bindings()
|
|
# Import-time side effect: call `set_debug(False, "")` on the extension.
# NOTE(review): presumably this leaves debug output disabled by default; the
# meaning of the two arguments is not visible here -- confirm against
# `_torch_mlir.set_debug`.
_torch_mlir.set_debug(False, "")
|
|
|
|
|
|
def get_mlir(t):
    """Return the result of ``_get_mlir`` for ``t``.

    Args:
        t: A single object or a list of objects. Anything that is not a
            ``list`` instance is wrapped in a one-element list before being
            forwarded; a list is forwarded as-is.

    Returns:
        Whatever ``_get_mlir`` returns for the resulting list.
    """
    values = t if isinstance(t, list) else [t]
    return _get_mlir(values)
|
|
|
|
|
|
def op_report(mlir):
    """Run ``_op_report`` on ``mlir`` and decode its JSON output.

    Args:
        mlir: Value forwarded unchanged to ``_op_report``.

    Returns:
        The Python object produced by ``json.loads`` on the report text.
    """
    raw_report = _op_report(mlir)
    return json.loads(raw_report)
|
|
|
|
|
|
def liveness_report(mlir):
    """Run ``_liveness_report`` on ``mlir`` and decode its JSON output.

    Args:
        mlir: Value forwarded unchanged to ``_liveness_report``.

    Returns:
        The Python object produced by ``json.loads`` on the report text.
    """
    raw_report = _liveness_report(mlir)
    return json.loads(raw_report)
|
|
|
|
|
|
def get_mlir_supported_devices(devkind=None):
    """Return the list of device strings this module reports as supported.

    Args:
        devkind: Accepted for the caller's convenience but not consulted by
            the current implementation.

    Returns:
        A freshly created single-element list, ``["xla:0"]``.
    """
    # TODO: define our own device and stop hijacking the xla device.
    supported = ["xla:0"]
    return supported
|
|
|
|
|
|
def mlir_device(devkind=None):
    """Return a ``torch.device`` for the first supported device.

    Args:
        devkind: Forwarded as the ``devkind`` keyword argument to
            ``get_mlir_supported_devices``.

    Returns:
        A ``torch.device`` constructed from the first entry of the list
        returned by ``get_mlir_supported_devices``.
    """
    first_device = get_mlir_supported_devices(devkind=devkind)[0]
    return torch.device(first_device)
|
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = [
    "get_mlir",
    "mlir_device",
    "op_report",
    "liveness_report",
]
|