# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import ctypes

import numpy as np

from torch_mlir.ir import *
from torch_mlir.passmanager import *
from torch_mlir.execution_engine import *
from torch_mlir.runtime import *
# Imported for side effects.
import torch_mlir.all_passes_registration
import torch_mlir.dialects.torch

from .abc import NpcompBackend

__all__ = [
    "RefBackendNpcompBackend",
]


class RefBackendInvoker:
    def __init__(self, module):
        self.ee = ExecutionEngine(module)
        self.result = None

        # Callback invoked by the JITed code to hand back the result as an
        # unranked memref descriptor, which we convert to a numpy array.
        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return(a):
            self.result = unranked_memref_to_numpy(a, np.float32)

        # Register the callback under the symbol that the munged code calls
        # (see `refback-munge-calling-conventions` in the pipeline below).
        self.ee.register_runtime("refbackend_consume_func_return",
                                 consume_return)

    def __getattr__(self, function_name: str):
        # Expose each JITed function as a method. Arguments are passed as
        # pointers to unranked memref descriptors; the result is delivered
        # through the `consume_return` callback registered above.
        def invoke(*args):
            ffi_args = [
                ctypes.pointer(
                    ctypes.pointer(get_unranked_memref_descriptor(arg)))
                for arg in args
            ]
            self.ee.invoke(function_name, *ffi_args)
            result = self.result
            assert result is not None, "Invocation didn't produce a result"
            self.result = None
            return result

        return invoke


LOWERING_PIPELINE = ",".join([
    # Bufferize.
    "tensor-constant-bufferize",
    "builtin.func(scf-bufferize)",
    "builtin.func(linalg-bufferize)",
    "builtin.func(std-bufferize)",
    "builtin.func(tensor-bufferize)",
    "func-bufferize",
    "builtin.func(finalizing-bufferize)",
    # Munge to make it ExecutionEngine compatible.
    # Specifically, we rewrite calling convention boundaries to be in terms
    # of unranked memrefs, and we rewrite the return to actually be a
    # callback that consumes the return (the final munged function always
    # returns void at the C level -- we get the return value by providing the
    # callback).
    "refback-munge-calling-conventions",
    # Lower to LLVM.
    "builtin.func(convert-linalg-to-loops)",
    "builtin.func(lower-affine)",
    "builtin.func(convert-scf-to-std)",
    "builtin.func(refback-expand-ops-for-llvm)",
    "builtin.func(convert-math-to-llvm)",
    "convert-memref-to-llvm",
    "convert-std-to-llvm",
    "reconcile-unrealized-casts",
])


class RefBackendNpcompBackend(NpcompBackend):
    """Main entry-point for the backend."""

    def __init__(self):
        super().__init__()

    def compile(self, imported_module: Module):
        """Compiles an imported module, with a flat list of functions.

        The module is expected to be in linalg-on-tensors + scalar code form.
        TODO: More clearly define the backend contract. Generally this will
        extend to support globals, lists, and other stuff.

        Args:
          imported_module: The MLIR module consisting of funcs in the torch
            dialect.
        Returns:
          An opaque, backend-specific module object that can be passed to
          `load`. The object may actually be something more specific to the
          backend (e.g. for IREE, it is a serialized VM flatbuffer), but the
          contract is that it is operated on by methods on this class.
        """
        # Go through a string because we are bridging two separate C APIs.
        # TODO: Remove after npcomp's MLIR is deleted in favor of torch_mlir.
        with Context() as ctx:
            module = Module.parse(str(imported_module))
            pm = PassManager.parse(LOWERING_PIPELINE)
            pm.run(module)
        return module

    def load(self, module) -> RefBackendInvoker:
        """Loads a compiled artifact into the runtime."""
        return RefBackendInvoker(module)
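

# A minimal usage sketch of the compile/load/invoke flow, not part of the
# backend API. The MLIR assembly, the "forward" entry-point name, and the
# input shape below are assumptions chosen for illustration; a real caller
# would pass a module produced by the torch -> linalg-on-tensors lowering
# instead of the hand-written identity function used here.
if __name__ == "__main__":
    EXAMPLE_ASM = """
    builtin.func @forward(%arg0: tensor<?xf32>) -> tensor<?xf32> {
      return %arg0 : tensor<?xf32>
    }
    """
    # Parse the example module in a fresh context, mirroring what `compile`
    # does internally with the imported module's textual form.
    with Context():
        example_module = Module.parse(EXAMPLE_ASM)

    backend = RefBackendNpcompBackend()
    compiled = backend.compile(example_module)
    jit_module = backend.load(compiled)
    # Attribute access on RefBackendInvoker dispatches to the JITed function
    # of the same name; the result comes back as a numpy array.
    print(jit_module.forward(np.arange(4, dtype=np.float32)))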