mirror of https://github.com/llvm/torch-mlir
Update sample for refjit invocation.
parent
0356f65dcd
commit
29da57e631
|
@ -39,13 +39,6 @@ public:
|
|||
fromCompiledModule(mlir::ModuleOp module,
|
||||
llvm::ArrayRef<llvm::StringRef> sharedLibs);
|
||||
|
||||
/// All in one factory function for compiling from an input module.
|
||||
/// This method will construct a PassManager, perform backend compilation
|
||||
/// and construct a JITModule all in one.
|
||||
// static llvm::Expected<std::unique_ptr<JITModule>>
|
||||
// fromMLIR(mlir::ModuleOp module, llvm::ArrayRef<llvm::StringRef>
|
||||
// sharedLibs);
|
||||
|
||||
llvm::Expected<llvm::SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6>>
|
||||
invoke(llvm::StringRef functionName,
|
||||
llvm::ArrayRef<npcomprt::Ref<npcomprt::Tensor>> inputs);
|
||||
|
|
|
@ -19,8 +19,8 @@ def compile_function(f):
|
|||
target_factory=GenericTarget32))
|
||||
fe.import_global_function(f)
|
||||
compiler = refjit.CompilerBackend()
|
||||
vm_blob = compiler.compile(fe.ir_module)
|
||||
loaded_m = compiler.load(vm_blob)
|
||||
blob = compiler.compile(fe.ir_module)
|
||||
loaded_m = compiler.load(blob)
|
||||
return loaded_m[f.__name__]
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,20 @@
|
|||
# Run full pipeline with:
|
||||
# -npcomp-cpa-type-inference -numpy-public-functions-to-tensor -convert-numpy-to-tcf -canonicalize
|
||||
|
||||
import numpy as np
|
||||
from npcomp.compiler import test_config
|
||||
|
||||
import_global = test_config.create_import_dump_decorator()
|
||||
from npcomp.compiler import test_config
|
||||
from npcomp.compiler.backend import refjit
|
||||
from npcomp.compiler.frontend import *
|
||||
from npcomp.compiler.target import *
|
||||
|
||||
|
||||
def compile_function(f):
|
||||
fe = ImportFrontend(config=test_config.create_test_config(
|
||||
target_factory=GenericTarget32))
|
||||
fe.import_global_function(f)
|
||||
compiler = refjit.CompilerBackend()
|
||||
vm_blob = compiler.compile(fe.ir_module)
|
||||
loaded_m = compiler.load(vm_blob)
|
||||
return loaded_m[f.__name__]
|
||||
|
||||
|
||||
global_data = (np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape([1.0, 2.0],
|
||||
(2, 1)))
|
||||
|
@ -13,6 +23,13 @@ a = np.asarray([1.0, 2.0], dtype=np.float32)
|
|||
b = np.asarray([3.0, 4.0], dtype=np.float32)
|
||||
|
||||
|
||||
@import_global
|
||||
@compile_function
|
||||
def global_add():
|
||||
return np.add(a, np.add(b, a))
|
||||
|
||||
|
||||
assert global_add.__isnpcomp__
|
||||
|
||||
# CHECK: GLOBAL_ADD: [5. 8.]
|
||||
result = global_add()
|
||||
print("GLOBAL_ADD:", result)
|
||||
|
|
Loading…
Reference in New Issue