torch-mlir/python/npcomp/frontends/pytorch/__init__.py

Add pytorch interface to ATen Dialect (#30)

This patch adds a PyTorch interface to npcomp. The interface is modeled after pytorch_xla and exposes the MLIR-based flow as a virtual device (similar to a GPU device or the XLA backend). Usage is intended to be something like:

    dev = torch_mlir.mlir_device()
    t0 = torch.randn((4,4), device=dev)
    t1 = torch.randn((4,4), device=dev)
    t2 = t0 + t1
    t2_mlir = torch_mlir.get_mlir(t2)
    t2_cpu = t2.to('cpu')

In this case t2_cpu would contain the result of the computation, and t2_mlir contains the MLIR description of the computation. Note that this also properly returns backward paths synthesized by PyTorch.

There are several parts to this:
1) A tensor type (implemented by tensor.* and tensor_impl.*)
2) The device modeling (aten_mlir_bridge.*, aten_mlir_device.*, aten_mlir_type*)
3) A temporary IR (implemented by ir.cpp)

There is also a reference lowering directly from the ATen dialect to C function calls, consisting of two parts:
1) The driver that uses the IR to generate MLIR, run passes, and compile the result using mlir::ExecutionEngine (implemented by jit.cpp and mlir_gen.cpp)
2) A runtime library implemented by lib/aten_ops.cpp. Most of the operations are implemented by callbacks into the torch C++ libraries.

Some aspects of this are known to be less than optimal, in particular:
1) Some function definitions don't live in the file corresponding to their declaration.
2) More aspects of this (e.g. the IR) seem like they should be automatically generated.
3) It's unclear to me how much of the 'IR' is actually necessary, or whether MLIR could be created on the fly.

Note that this code is licensed in a way similar to PyTorch, with the intention that it should eventually be pushed there once npcomp reaches some maturity (see frontends/pytorch/LICENSE). The code is also structured much closer to the PyTorch coding style than to the LLVM coding style.
2020-08-22 02:22:47 +08:00
# -*- Python -*-
# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
import _torch_mlir

from _torch_mlir import _get_mlir
from _torch_mlir import _op_report
from _torch_mlir import _liveness_report
from _torch_mlir import set_debug
from _torch_mlir import lower_to_std

import json

# Initialize the ATen bindings provided by the native extension and
# start with debug output disabled.
_torch_mlir._initialize_aten_bindings()
_torch_mlir.set_debug(False, "")


def get_mlir(t):
    """Return the MLIR description for a tensor or a list of tensors."""
    # The native entry point takes a list, so wrap a single tensor.
    if not isinstance(t, list):
        t = [t]
    return _get_mlir(t)


def op_report(mlir):
    # The native report is JSON text; decode it into Python objects.
    return json.loads(_op_report(mlir))


def liveness_report(mlir):
    # Same decoding as op_report, for the liveness analysis report.
    return json.loads(_liveness_report(mlir))


def get_mlir_supported_devices(devkind=None):
    # TODO: define our own device and stop hijacking the xla device.
    # devkind is currently unused; the MLIR flow reuses the XLA device.
    return ["xla:0"]


def mlir_device(devkind=None):
    # Use the first supported device as the MLIR virtual device.
    devices = get_mlir_supported_devices(devkind=devkind)
    device = devices[0]
    return torch.device(device)


__all__ = ['get_mlir', 'mlir_device', 'op_report', 'liveness_report']
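The two wrapper patterns in this module, wrapping a lone tensor in a list before calling the native `_get_mlir` and decoding the native report strings with `json.loads`, can be exercised without the native extension. The following is a minimal sketch under stated assumptions: `_fake_get_mlir`, `_fake_op_report`, the `backend` parameter, and the strings they return are hypothetical stand-ins, not the real `_torch_mlir` API or its output format.

```python
import json


# Hypothetical stand-in for the native _torch_mlir._get_mlir.
# Assumption: like the real call, it always receives a list of tensors.
def _fake_get_mlir(tensors):
    assert isinstance(tensors, list)
    # Placeholder text only; not real MLIR output.
    return "module { /* %d result(s) */ }" % len(tensors)


# Hypothetical stand-in for the native _torch_mlir._op_report.
# Assumption: the native report is a JSON-encoded string.
def _fake_op_report(mlir):
    return json.dumps({"module_chars": len(mlir)})


def get_mlir(t, backend=_fake_get_mlir):
    # Same normalization as get_mlir above: wrap a lone object in a list.
    if not isinstance(t, list):
        t = [t]
    return backend(t)


def op_report(mlir, backend=_fake_op_report):
    # Same decoding as op_report above: JSON string -> Python dict.
    return json.loads(backend(mlir))


single = get_mlir("t2")            # wrapped as ["t2"] before the call
several = get_mlir(["t0", "t1"])   # a list is passed through unchanged
report = op_report(single)         # decoded into a plain dict
```

Because the check is `isinstance(t, list)`, a tuple of tensors is treated as a single object and wrapped as one element, so callers with several tensors should pass a list.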