Update sample for refjit invocation.

pull/1/head
Stella Laurenzo 2020-07-10 22:50:24 -07:00 committed by Stella Laurenzo
parent 0356f65dcd
commit 29da57e631
3 changed files with 25 additions and 15 deletions

View File

@ -39,13 +39,6 @@ public:
fromCompiledModule(mlir::ModuleOp module,
llvm::ArrayRef<llvm::StringRef> sharedLibs);
/// All in one factory function for compiling from an input module.
/// This method will construct a PassManager, perform backend compilation
/// and construct a JITModule all in one.
// static llvm::Expected<std::unique_ptr<JITModule>>
// fromMLIR(mlir::ModuleOp module, llvm::ArrayRef<llvm::StringRef>
// sharedLibs);
llvm::Expected<llvm::SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6>>
invoke(llvm::StringRef functionName,
llvm::ArrayRef<npcomprt::Ref<npcomprt::Tensor>> inputs);

View File

@ -19,8 +19,8 @@ def compile_function(f):
target_factory=GenericTarget32))
fe.import_global_function(f)
compiler = refjit.CompilerBackend()
vm_blob = compiler.compile(fe.ir_module)
loaded_m = compiler.load(vm_blob)
blob = compiler.compile(fe.ir_module)
loaded_m = compiler.load(blob)
return loaded_m[f.__name__]

View File

@ -1,10 +1,20 @@
# Run full pipeline with:
# -npcomp-cpa-type-inference -numpy-public-functions-to-tensor -convert-numpy-to-tcf -canonicalize
import numpy as np
from npcomp.compiler import test_config
import_global = test_config.create_import_dump_decorator()
from npcomp.compiler import test_config
from npcomp.compiler.backend import refjit
from npcomp.compiler.frontend import *
from npcomp.compiler.target import *
def compile_function(f):
fe = ImportFrontend(config=test_config.create_test_config(
target_factory=GenericTarget32))
fe.import_global_function(f)
compiler = refjit.CompilerBackend()
vm_blob = compiler.compile(fe.ir_module)
loaded_m = compiler.load(vm_blob)
return loaded_m[f.__name__]
global_data = (np.zeros((2, 3)) + [1.0, 2.0, 3.0] * np.reshape([1.0, 2.0],
(2, 1)))
@ -13,6 +23,13 @@ a = np.asarray([1.0, 2.0], dtype=np.float32)
b = np.asarray([3.0, 4.0], dtype=np.float32)
@import_global
@compile_function
def global_add():
return np.add(a, np.add(b, a))
assert global_add.__isnpcomp__
# CHECK: GLOBAL_ADD: [5. 8.]
result = global_add()
print("GLOBAL_ADD:", result)