torch-mlir/python/npcomp/frontends/pytorch/__init__.py

46 lines
1.1 KiB
Python

# -*- Python -*-
# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
import _torch_mlir
from _torch_mlir import _get_mlir
from _torch_mlir import _op_report
from _torch_mlir import _liveness_report
from _torch_mlir import set_debug
from _torch_mlir import lower_to_std
import json
# Import-time setup of the native extension. These calls run once when this
# package is imported. Keep them in this order: the bindings are initialized
# before debug configuration is applied.
# NOTE(review): presumably _initialize_aten_bindings() registers the ATen
# bindings the other _torch_mlir entry points rely on — confirm in the C++ side.
_torch_mlir._initialize_aten_bindings()
# Turn debug mode off by default. The meaning of the second (string) argument
# is not visible here (presumably a debug output target/filter) — TODO confirm.
_torch_mlir.set_debug(False, "")
def get_mlir(t):
    """Return the MLIR that the native ``_get_mlir`` binding produces for *t*.

    Args:
      t: Either a ``list`` of inputs or a single input. Inputs are presumably
        torch tensors; confirm against ``_get_mlir``. Any value that is not a
        ``list`` (a tuple included) is wrapped in a one-element list.

    Returns:
      Whatever ``_get_mlir`` returns for the normalized list.
    """
    # The binding is always called with a list, so single inputs are boxed here.
    inputs = t if isinstance(t, list) else [t]
    return _get_mlir(inputs)
def op_report(mlir):
    """Build the op report for *mlir* and return it decoded from JSON.

    Args:
      mlir: Passed unchanged to the native ``_op_report`` binding (presumably
        a value obtained from ``get_mlir`` — confirm).

    Returns:
      The Python object produced by ``json.loads`` on the binding's JSON text.
    """
    report_text = _op_report(mlir)
    return json.loads(report_text)
def liveness_report(mlir):
    """Build the liveness report for *mlir* and return it decoded from JSON.

    Args:
      mlir: Passed unchanged to the native ``_liveness_report`` binding
        (presumably a value obtained from ``get_mlir`` — confirm).

    Returns:
      The Python object produced by ``json.loads`` on the binding's JSON text.
    """
    report_text = _liveness_report(mlir)
    return json.loads(report_text)
def get_mlir_supported_devices(devkind=None):
    """Return the device strings usable with the MLIR frontend.

    Args:
      devkind: Currently ignored. The same list is returned for any value.

    Returns:
      A new list holding the single device string ``"xla:0"``.
    """
    # TODO: define our own device and stop hijacking the xla device.
    supported = ["xla:0"]
    return supported
def mlir_device(devkind=None):
    """Return a ``torch.device`` for the first supported MLIR device.

    Args:
      devkind: Forwarded to ``get_mlir_supported_devices``.

    Returns:
      ``torch.device`` built from the first entry of the supported-device list.
    """
    return torch.device(get_mlir_supported_devices(devkind=devkind)[0])
# Public names exported by ``from <this package> import *``.
# ``get_mlir_supported_devices`` is a public helper defined in this module.
# ``set_debug`` and ``lower_to_std`` are imported from ``_torch_mlir`` at the
# top of the module but never referenced by bare name here, so they are
# listed to be re-exported as part of the package API.
__all__ = [
    'get_mlir',
    'mlir_device',
    'op_report',
    'liveness_report',
    'get_mlir_supported_devices',
    'set_debug',
    'lower_to_std',
]