mirror of https://github.com/llvm/torch-mlir
109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
|
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
|
||
|
import ctypes
|
||
|
import numpy as np
|
||
|
|
||
|
from torch_mlir.ir import *
|
||
|
from torch_mlir.passmanager import *
|
||
|
from torch_mlir.execution_engine import *
|
||
|
from torch_mlir.runtime import *
|
||
|
# Imported for side effects.
|
||
|
import torch_mlir.all_passes_registration
|
||
|
import torch_mlir.dialects.torch
|
||
|
|
||
|
from .abc import NpcompBackend
|
||
|
|
||
|
__all__ = [
|
||
|
"RefBackendNpcompBackend",
|
||
|
]
|
||
|
|
||
|
|
||
|
class RefBackendInvoker:
    """Calls functions of a compiled module through an ExecutionEngine.

    Any non-dunder attribute that normal lookup does not find is treated as
    the name of a function to invoke, so ``invoker.forward(a, b)`` invokes
    ``forward`` and returns whatever the registered return-consuming callback
    stored in ``self.result`` during that invocation.
    """

    def __init__(self, module):
        """Create the engine for `module` and register the return callback.

        Args:
            module: The compiled module passed to ``ExecutionEngine``.
        """
        self.ee = ExecutionEngine(module)
        # Filled in by `consume_return` while an invocation is running; read
        # and cleared again by the closure returned from `__getattr__`.
        self.result = None

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return(a):
            # The returned memref is always converted as float32.
            self.result = unranked_memref_to_numpy(a, np.float32)

        # ctypes does not keep CFUNCTYPE objects alive on its own; hold a
        # reference here so the callback lives as long as this invoker.
        self._consume_return_callback = consume_return
        self.ee.register_runtime("refbackend_consume_func_return",
                                 consume_return)

    def __getattr__(self, function_name: str):
        """Return a callable that invokes `function_name` on the engine.

        Only reached when regular attribute lookup fails.

        Args:
            function_name: Name of the function to invoke.

        Returns:
            A function taking the call arguments; each argument is converted
            with ``get_unranked_memref_descriptor`` and passed as a pointer
            to a pointer to that descriptor.

        Raises:
            AttributeError: For dunder names. Those are protocol probes
                (``hasattr``, ``copy``, ``pickle``, ...) rather than function
                names, and answering them with a callable makes such probes
                misbehave.
        """
        if function_name.startswith("__") and function_name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute "
                f"{function_name!r}")

        def invoke(*args):
            ffi_args = [
                ctypes.pointer(
                    ctypes.pointer(get_unranked_memref_descriptor(arg)))
                for arg in args
            ]
            # Drop anything left over from an earlier call that failed after
            # the callback had already fired, so it cannot be returned here.
            self.result = None
            self.ee.invoke(function_name, *ffi_args)
            result = self.result
            self.result = None
            # Raised explicitly (instead of a bare `assert`) so the check
            # still runs under `python -O`; exception type and message are
            # the same as before.
            if result is None:
                raise AssertionError("Invocation didn't produce a result")
            return result

        return invoke
|
||
|
|
||
|
|
||
|
# Comma-separated textual pass pipeline, handed to `PassManager.parse`.
LOWERING_PIPELINE = ",".join((
    # Stage 1: bufferization.
    "tensor-constant-bufferize",
    "builtin.func(scf-bufferize)",
    "builtin.func(linalg-bufferize)",
    "builtin.func(std-bufferize)",
    "builtin.func(tensor-bufferize)",
    "func-bufferize",
    "builtin.func(finalizing-bufferize)",
    # Stage 2: make the module callable through ExecutionEngine.
    # Calling-convention boundaries are rewritten in terms of unranked
    # memrefs, and the return is turned into a call to a callback that
    # consumes the returned value. At the C level the munged function always
    # returns void; the value is obtained by supplying that callback.
    "refback-munge-calling-conventions",
    # Stage 3: lowering to LLVM.
    "builtin.func(convert-linalg-to-loops)",
    "builtin.func(lower-affine)",
    "builtin.func(convert-scf-to-std)",
    "builtin.func(refback-expand-ops-for-llvm)",
    "builtin.func(convert-math-to-llvm)",
    "convert-memref-to-llvm",
    "convert-std-to-llvm",
    "reconcile-unrealized-casts",
))
|
||
|
|
||
|
|
||
|
class RefBackendNpcompBackend(NpcompBackend):
    """Backend entry point: compiles modules and loads them for invocation."""

    def __init__(self):
        super().__init__()

    def compile(self, imported_module: Module):
        """Lower an imported module (a flat list of functions).

        The input is expected to be in linalg-on-tensors + scalar code form.

        TODO: Define the backend contract more clearly. In general it will
        grow to cover globals, lists, and other constructs.

        Args:
          imported_module: The MLIR module to compile. It is converted to
            text and re-parsed in a fresh context before the passes run.
        Returns:
          An opaque, backend-specific module object suitable for `load`.
          It may really be something more specific to the backend (for IREE,
          for example, a serialized VM flatbuffer), but the contract is only
          that methods of this class operate on it.
        """
        # Round-trip through the textual form because two separate CAPIs are
        # being bridged here.
        # TODO: Remove after npcomp's mlir is deleted in favor of torch_mlir.
        with Context():
            lowered = Module.parse(str(imported_module))
            PassManager.parse(LOWERING_PIPELINE).run(lowered)
            return lowered

    def load(self, module) -> RefBackendInvoker:
        """Wrap a compiled artifact in an invoker bound to the runtime."""
        invoker = RefBackendInvoker(module)
        return invoker
|