mirror of https://github.com/llvm/torch-mlir
Initial commit of python boiler-plate.
commit
9ee2f6ff7f
|
@ -0,0 +1,25 @@
|
|||
# MLIR npcomp project.
|
||||
set(MLIR_NPCOMP_MAIN_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) # --src-root
|
||||
set(MLIR_NPCOMP_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) # --includedir
|
||||
|
||||
set(MLIR_NPCOMP_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
set(MLIR_NPCOMP_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
# TODO(laurenzo): Rationalize with how this is done elsewhere
|
||||
find_package(PythonInterp REQUIRED)
|
||||
find_package(PythonLibs REQUIRED)
|
||||
message(STATUS "Found python include dirs: ${PYTHON_INCLUDE_DIRS}")
|
||||
message(STATUS "Found ppython libraries: ${PYTHON_LIBRARIES}")
|
||||
find_package(pybind11 CONFIG REQUIRED)
|
||||
message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIRS}")
|
||||
|
||||
# TODO(laurenzo): What is the right way to get include directories for
|
||||
# cross project dependencies?
|
||||
include_directories(${MLIR_NPCOMP_INCLUDE_DIR})
|
||||
include_directories(${CMAKE_SOURCE_DIR}/../mlir/include)
|
||||
include_directories(${CMAKE_BINARY_DIR}/tools/mlir/include)
|
||||
|
||||
add_subdirectory(include/npcomp)
|
||||
add_subdirectory(lib)
|
||||
add_subdirectory(python)
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
# npcomp - An aspirational MLIR based numpy compiler
|
||||
|
||||
## Scratch-pad of build configurations that have worked
|
||||
|
||||
### VSCode settings for configuring CMake
|
||||
|
||||
```json
|
||||
"cmake.configureArgs": [
|
||||
"-DLLVM_TARGETS_TO_BUILD=X86",
|
||||
"-DLLVM_ENABLE_PROJECTS=mlir;npcomp",
|
||||
"-DPYTHON_EXECUTABLE=/bin/python3",
|
||||
"-DLLVM_EXTERNAL_PROJECTS=npcomp",
|
||||
"-DLLVM_ENABLE_ASSERTIONS:BOOL=ON"
|
||||
]
|
||||
```
|
||||
|
||||
### Installing pybind11
|
||||
|
||||
The native extension relies on pybind11. In a perfect world, this could just
|
||||
be installed with your system package manager. However, at least on
|
||||
Ubuntu Disco, the system package installed with broken cmake files.
|
||||
|
||||
I built/installed from pybind11 head without issue and put it in /usr/local.
|
||||
There are better ways to do this.
|
||||
|
||||
### Building the python native library
|
||||
|
||||
```shell
|
||||
# From the build directory
|
||||
ninja NPCOMPNativePyExt
|
||||
# Outputs to tools/npcomp/python/npcomp/native...so
|
||||
export PYTHONPATH=$(pwd)/tools/npcomp/python
|
||||
python3 -m npcomp.smoketest
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
* Python sources are symlinked to the output directory at configure time.
|
||||
Adding sources will require a reconfigure. Editing should not.
|
||||
* It is a very common issue to have both python 2.7 (aka. "python") and python
|
||||
3.x (aka. "python3") on a system at a time (and we can only hope that one
|
||||
day this ends). Since the native library at development time binds to a
|
||||
specific version, if you try to run with a different python, you will get
|
||||
an error about the "native" module not being found.
|
|
@ -0,0 +1 @@
|
|||
# Empty file
|
|
@ -0,0 +1 @@
|
|||
// Empty file
|
|
@ -0,0 +1,7 @@
|
|||
add_llvm_tool(npcomp-dummy-runner
|
||||
dummy-runner.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(npcomp-dummy-runner PRIVATE
|
||||
LLVMSupport
|
||||
)
|
|
@ -0,0 +1,12 @@
|
|||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "npcomp/Dummy.h"
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
InitLLVM y(argc, argv);
|
||||
cl::ParseCommandLineOptions(argc, argv, "Dummy program\n");
|
||||
llvm::outs() << "Hello world!\n";
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
add_subdirectory(npcomp)
|
||||
|
||||
################################################################################
|
||||
# Manage python source files
|
||||
################################################################################
|
||||
function (create_symlinks)
|
||||
# Do nothing if building in-source
|
||||
if (${CMAKE_CURRENT_BINARY_DIR} STREQUAL ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
return()
|
||||
endif()
|
||||
|
||||
foreach (path_file ${ARGN})
|
||||
get_filename_component(folder ${path_file} PATH)
|
||||
|
||||
# Create REAL folder
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${folder}")
|
||||
|
||||
# Delete symlink if it exists
|
||||
file(REMOVE "${CMAKE_CURRENT_BINARY_DIR}/${path_file}")
|
||||
|
||||
# Get OS dependent path to use in `execute_process`
|
||||
file(TO_NATIVE_PATH "${CMAKE_CURRENT_BINARY_DIR}/${path_file}" link)
|
||||
file(TO_NATIVE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/${path_file}" target)
|
||||
|
||||
if (UNIX)
|
||||
set(command ln -s ${target} ${link})
|
||||
else()
|
||||
set(command cmd.exe /c mklink ${link} ${target})
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND ${command}
|
||||
RESULT_VARIABLE result
|
||||
ERROR_VARIABLE output)
|
||||
|
||||
if (NOT ${result} EQUAL 0)
|
||||
message(FATAL_ERROR "Could not create symbolic link for: ${target} --> ${output}")
|
||||
endif()
|
||||
|
||||
endforeach(path_file)
|
||||
endfunction(create_symlinks)
|
||||
|
||||
file(GLOB_RECURSE python_files RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.py)
|
||||
create_symlinks(${python_files})
|
|
@ -0,0 +1,84 @@
|
|||
################################################################################
|
||||
# Native extensions
|
||||
################################################################################
|
||||
|
||||
# Normally on unix-like platforms, extensions are built as "MODULE" libraries
|
||||
# and do not explicitly link to the python shared object. This allows for
|
||||
# come greater deployment flexibility since the extension will bind to
|
||||
# symbols in the python interpreter on load. However, it also keeps the
|
||||
# linker from erroring on undefined symbols, leaving this to (usually obtuse)
|
||||
# runtime errors. Building in "SHARED" mode with an explicit link to the
|
||||
# python libraries allows us to build with the expectation of no undefined
|
||||
# symbols, which is better for development.
|
||||
# TODO(laurenzo): Windows requires linking against the PYTHON_LIBRARIES
|
||||
# TODO(laurenzo): OSX requires allowing undefined (-undefined dynamic_lookup)
|
||||
set(NPCOMP_PYEXT_LINK_MODE SHARED)
|
||||
set(NPCOMP_PYEXT_LIBADD ${PYTHON_LIBRARIES})
|
||||
|
||||
# TODO(laurenzo): Add a config setting to control this.
|
||||
# set(NPCOMP_PYEXT_LINK_MODE MODULE)
|
||||
# set(NPCOMP_PYEXT_LIBADD "")
|
||||
|
||||
# When building the extension, distinguish between those sources that use
|
||||
# pybind (and need rtti/exceptions) and those that only use LLVM/MLIR.
|
||||
# Some of the low-level components do not support mixing RTTI modes and are
|
||||
# compiled separately for now.
|
||||
set(extension_target NPCOMPNativePyExt)
|
||||
set(extension_pybind_sources
|
||||
native.cpp
|
||||
mlir_edsc.cpp
|
||||
)
|
||||
set(extension_llvm_sources
|
||||
mlir_init.cpp
|
||||
)
|
||||
set_source_files_properties(
|
||||
${extension_pybind_sources}
|
||||
PROPERTIES COMPILE_FLAGS
|
||||
"-frtti -fexceptions")
|
||||
add_library(${extension_target} ${NPCOMP_PYEXT_LINK_MODE}
|
||||
${extension_pybind_sources}
|
||||
${extension_llvm_sources}
|
||||
)
|
||||
|
||||
set_target_properties(${extension_target} PROPERTIES LIBRARY_OUTPUT_DIRECTORY
|
||||
"${CMAKE_CURRENT_BINARY_DIR}")
|
||||
set_target_properties(${extension_target} PROPERTIES OUTPUT_NAME native)
|
||||
set_target_properties(${extension_target} PROPERTIES PREFIX
|
||||
"${PYTHON_MODULE_PREFIX}")
|
||||
set_target_properties(${extension_target} PROPERTIES SUFFIX
|
||||
"${PYTHON_MODULE_EXTENSION}")
|
||||
|
||||
# pybind requires binding code to be compiled with -fvisibility=hidden
|
||||
# Better code can be generated if the entire project compiles that way, but
|
||||
# that is not enforced here. Instead, include a linker script that explicitly
|
||||
# hides anything but the PyInit_* symbols, allowing gc to take place.
|
||||
# TODO(laurenzo): Windows needs a .def file and different flags.
|
||||
set_target_properties(${extension_target} PROPERTIES CXX_VISIBILITY_PRESET "hidden")
|
||||
set_target_properties(${extension_target} PROPERTIES LINK_FLAGS
|
||||
"-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/unix_version.script")
|
||||
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
|
||||
|
||||
llvm_update_compile_flags(${extension_target})
|
||||
target_link_libraries(${extension_target}
|
||||
PRIVATE
|
||||
${dialect_libs}
|
||||
${conversion_libs}
|
||||
|
||||
pybind11::module
|
||||
LLVMSupport
|
||||
MLIRAffineToStandard
|
||||
MLIRAffineTransforms
|
||||
MLIRDialect
|
||||
MLIREDSC
|
||||
MLIREDSCInterface
|
||||
MLIRIR
|
||||
MLIRLoopToStandard
|
||||
MLIRLLVMIR
|
||||
MLIRPass
|
||||
MLIRTargetLLVMIR
|
||||
MLIRTransforms
|
||||
|
||||
${NPCOMP_PYEXT_LIBADD}
|
||||
)
|
|
@ -0,0 +1,4 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
|
@ -0,0 +1,536 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
"""Test for the MLIR EDSC Python bindings"""
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
|
||||
from npcomp.native.mlir import edsc as E
|
||||
from npcomp.utils import test_utils
|
||||
|
||||
|
||||
# Prints `str` prefixed by the current test function name so we can use it in
|
||||
# Filecheck label directives.
|
||||
# This is achieved by inspecting the stack and getting the parent name.
|
||||
def printWithCurrentFunctionName(str):
|
||||
print(inspect.stack()[1][3])
|
||||
print(str)
|
||||
|
||||
|
||||
class EdscTest:
|
||||
|
||||
def setUp(self):
|
||||
self.module = E.MLIRModule()
|
||||
self.boolType = self.module.make_type("i1")
|
||||
self.i32Type = self.module.make_type("i32")
|
||||
self.f32Type = self.module.make_type("f32")
|
||||
self.indexType = self.module.make_index_type()
|
||||
|
||||
def testBlockArguments(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []) as fun:
|
||||
E.constant_index(42)
|
||||
with E.BlockContext([self.f32Type, self.f32Type]) as b:
|
||||
b.arg(0) + b.arg(1)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBlockArguments
|
||||
# CHECK: %{{.*}} = constant 42 : index
|
||||
# CHECK: ^bb{{.*}}(%{{.*}}: f32, %{{.*}}: f32):
|
||||
# CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
|
||||
|
||||
def testBlockContext(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []) as fun:
|
||||
cst = E.constant_index(42)
|
||||
with E.BlockContext():
|
||||
cst + cst
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBlockContext
|
||||
# CHECK: %{{.*}} = constant 42 : index
|
||||
# CHECK: ^bb
|
||||
# CHECK: %{{.*}} = affine.apply affine_map<() -> (84)>()
|
||||
|
||||
def testBlockContextAppend(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []) as fun:
|
||||
E.constant_index(41)
|
||||
with E.BlockContext() as b:
|
||||
blk = b # save block handle for later
|
||||
E.constant_index(0)
|
||||
E.constant_index(42)
|
||||
with E.BlockContext(E.appendTo(blk)):
|
||||
E.constant_index(1)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBlockContextAppend
|
||||
# CHECK: %{{.*}} = constant 41 : index
|
||||
# CHECK: %{{.*}} = constant 42 : index
|
||||
# CHECK: ^bb
|
||||
# CHECK: %{{.*}} = constant 0 : index
|
||||
# CHECK: %{{.*}} = constant 1 : index
|
||||
|
||||
def testBlockContextStandalone(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []) as fun:
|
||||
blk1 = E.BlockContext()
|
||||
blk2 = E.BlockContext()
|
||||
with blk1:
|
||||
E.constant_index(0)
|
||||
with blk2:
|
||||
E.constant_index(56)
|
||||
E.constant_index(57)
|
||||
E.constant_index(41)
|
||||
with blk1:
|
||||
E.constant_index(1)
|
||||
E.constant_index(42)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBlockContextStandalone
|
||||
# CHECK: %{{.*}} = constant 41 : index
|
||||
# CHECK: %{{.*}} = constant 42 : index
|
||||
# CHECK: ^bb
|
||||
# CHECK: %{{.*}} = constant 0 : index
|
||||
# CHECK: %{{.*}} = constant 1 : index
|
||||
# CHECK: ^bb
|
||||
# CHECK: %{{.*}} = constant 56 : index
|
||||
# CHECK: %{{.*}} = constant 57 : index
|
||||
|
||||
def testBooleanOps(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("booleans",
|
||||
[self.boolType for _ in range(4)],
|
||||
[]) as fun:
|
||||
i, j, k, l = (fun.arg(x) for x in range(4))
|
||||
stmt1 = (i < j) & (j >= k)
|
||||
stmt2 = ~(stmt1 | (k == l))
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBooleanOps
|
||||
# CHECK: %{{.*}} = cmpi "slt", %{{.*}}, %{{.*}} : i1
|
||||
# CHECK: %{{.*}} = cmpi "sge", %{{.*}}, %{{.*}} : i1
|
||||
# CHECK: %{{.*}} = and %{{.*}}, %{{.*}} : i1
|
||||
# CHECK: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : i1
|
||||
# CHECK: %{{.*}} = or %{{.*}}, %{{.*}} : i1
|
||||
# CHECK: %{{.*}} = constant 1 : i1
|
||||
# CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
|
||||
|
||||
def testBr(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []) as fun:
|
||||
with E.BlockContext() as b:
|
||||
blk = b
|
||||
E.ret()
|
||||
E.br(blk)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBr
|
||||
# CHECK: br ^bb
|
||||
# CHECK: ^bb
|
||||
# CHECK: return
|
||||
|
||||
def testBrArgs(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []) as fun:
|
||||
# Create an infinite loop.
|
||||
with E.BlockContext([self.indexType, self.indexType]) as b:
|
||||
E.br(b, [b.arg(1), b.arg(0)])
|
||||
E.br(b, [E.constant_index(0), E.constant_index(1)])
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBrArgs
|
||||
# CHECK: %{{.*}} = constant 0 : index
|
||||
# CHECK: %{{.*}} = constant 1 : index
|
||||
# CHECK: br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
|
||||
# CHECK: ^bb{{.*}}(%{{.*}}: index, %{{.*}}: index):
|
||||
# CHECK: br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
|
||||
|
||||
def testBrDeclaration(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []) as fun:
|
||||
blk = E.BlockContext()
|
||||
E.br(blk.handle())
|
||||
with blk:
|
||||
E.ret()
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testBrDeclaration
|
||||
# CHECK: br ^bb
|
||||
# CHECK: ^bb
|
||||
# CHECK: return
|
||||
|
||||
def testCallOp(self):
|
||||
self.setUp()
|
||||
callee = self.module.declare_function("sqrtf", [self.f32Type],
|
||||
[self.f32Type])
|
||||
with self.module.new_function_context("call", [self.f32Type], []) as fun:
|
||||
funCst = E.constant_function(callee)
|
||||
funCst([fun.arg(0)]) + E.constant_float(42., self.f32Type)
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testCallOp
|
||||
# CHECK: func @sqrtf(f32) -> f32
|
||||
# CHECK: %{{.*}} = constant @sqrtf : (f32) -> f32
|
||||
# CHECK: %{{.*}} = call_indirect %{{.*}}(%{{.*}}) : (f32) -> f32
|
||||
|
||||
def testCondBr(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [self.boolType], []) as fun:
|
||||
with E.BlockContext() as blk1:
|
||||
E.ret([])
|
||||
with E.BlockContext([self.indexType]) as blk2:
|
||||
E.ret([])
|
||||
cst = E.constant_index(0)
|
||||
E.cond_br(fun.arg(0), blk1, [], blk2, [cst])
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testCondBr
|
||||
# CHECK: cond_br %{{.*}}, ^bb{{.*}}, ^bb{{.*}}(%{{.*}} : index)
|
||||
|
||||
def testConstantAffineExpr(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("constant_affine", [], []) as fun:
|
||||
a1 = self.module.affine_dim_expr(0)
|
||||
a2 = self.module.affine_dim_expr(1)
|
||||
a3 = a1 + a2 + 3
|
||||
composedExpr = a3.compose(
|
||||
self.module.affine_map(2, 0, [
|
||||
self.module.affine_constant_expr(4),
|
||||
self.module.affine_constant_expr(7)
|
||||
]))
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
print("constant value : %d" % composedExpr.get_constant_value())
|
||||
# CHECK-LABEL: testConstantAffineExpr
|
||||
# CHECK: constant value : 14
|
||||
|
||||
def testConstants(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("constants", [], []) as fun:
|
||||
E.constant_float(1.23, self.module.make_type("bf16"))
|
||||
E.constant_float(1.23, self.module.make_type("f16"))
|
||||
E.constant_float(1.23, self.module.make_type("f32"))
|
||||
E.constant_float(1.23, self.module.make_type("f64"))
|
||||
E.constant_int(1, 1)
|
||||
E.constant_int(123, 8)
|
||||
E.constant_int(123, 16)
|
||||
E.constant_int(123, 32)
|
||||
E.constant_int(123, 64)
|
||||
E.constant_index(123)
|
||||
E.constant_function(fun)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testConstants
|
||||
# CHECK: constant 1.230000e+00 : bf16
|
||||
# CHECK: constant 1.230470e+00 : f16
|
||||
# CHECK: constant 1.230000e+00 : f32
|
||||
# CHECK: constant 1.230000e+00 : f64
|
||||
# CHECK: constant 1 : i1
|
||||
# CHECK: constant 123 : i8
|
||||
# CHECK: constant 123 : i16
|
||||
# CHECK: constant 123 : i32
|
||||
# CHECK: constant 123 : index
|
||||
# CHECK: constant @constants : () -> ()
|
||||
|
||||
def testCustom(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("custom", [self.indexType, self.f32Type],
|
||||
[]) as fun:
|
||||
E.op("foo", [fun.arg(0)], [self.f32Type]) + fun.arg(1)
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testCustom
|
||||
# CHECK: %{{.*}} = "foo"(%{{.*}}) : (index) -> f32
|
||||
# CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
|
||||
|
||||
def testDictionaryAttributes(self):
|
||||
self.setUp()
|
||||
dictionaryAttr = self.module.dictionaryAttr({
|
||||
"int_0": self.module.integerAttr(self.i32Type, 43),
|
||||
"int_1": self.module.integerAttr(self.i32Type, 33),
|
||||
})
|
||||
f = self.module.declare_function("foo", [], [], dict_attr=dictionaryAttr)
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testDictionaryAttributes
|
||||
# CHECK: func @foo() attributes {dict_attr = {int_0 = 43 : i32, int_1 = 33 : i32}}
|
||||
|
||||
def testDivisions(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context(
|
||||
"division", [self.indexType, self.i32Type, self.i32Type], []) as fun:
|
||||
# indices only support floor division
|
||||
fun.arg(0) // E.constant_index(42)
|
||||
# regular values only support regular division
|
||||
fun.arg(1) / fun.arg(2)
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testDivisions
|
||||
# CHECK: floordiv 42
|
||||
# CHECK: divi_signed %{{.*}}, %{{.*}} : i32
|
||||
|
||||
def testFunctionArgs(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [self.f32Type, self.f32Type],
|
||||
[self.indexType]) as fun:
|
||||
pass
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testFunctionArgs
|
||||
# CHECK: func @foo(%{{.*}}: f32, %{{.*}}: f32) -> index
|
||||
|
||||
def testFunctionContext(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []):
|
||||
pass
|
||||
printWithCurrentFunctionName(self.module.get_function("foo"))
|
||||
# CHECK-LABEL: testFunctionContext
|
||||
# CHECK: func @foo() {
|
||||
|
||||
def testFunctionDeclaration(self):
|
||||
self.setUp()
|
||||
boolAttr = self.module.boolAttr(True)
|
||||
t = self.module.make_memref_type(self.f32Type, [10])
|
||||
t_llvm_noalias = t({"llvm.noalias": boolAttr})
|
||||
t_readonly = t({"readonly": boolAttr})
|
||||
f = self.module.declare_function("foo", [t, t_llvm_noalias, t_readonly], [])
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testFunctionDeclaration
|
||||
# CHECK: func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias = true}, memref<10xf32> {readonly = true})
|
||||
|
||||
def testFunctionDeclarationWithAffineAttr(self):
|
||||
self.setUp()
|
||||
a1 = self.module.affine_constant_expr(23)
|
||||
a2 = self.module.affine_constant_expr(44)
|
||||
a3 = self.module.affine_dim_expr(1)
|
||||
s0 = self.module.affine_symbol_expr(0)
|
||||
aMap1 = self.module.affine_map(2, 0, [a1, a2, s0])
|
||||
aMap2 = self.module.affine_constant_map(42)
|
||||
aMap3 = self.module.affine_map(
|
||||
2, 0,
|
||||
[a1 + a2 * a3, a1 // a3 % a2,
|
||||
a1.ceildiv(a2), a1 - 2, a2 * 2, -a3])
|
||||
|
||||
affineAttr1 = self.module.affineMapAttr(aMap1)
|
||||
affineAttr2 = self.module.affineMapAttr(aMap2)
|
||||
affineAttr3 = self.module.affineMapAttr(aMap3)
|
||||
|
||||
t = self.module.make_memref_type(self.f32Type, [10])
|
||||
t_with_attr = t({
|
||||
"affine_attr_1": affineAttr1,
|
||||
"affine_attr_2": affineAttr2,
|
||||
"affine_attr_3": affineAttr3,
|
||||
})
|
||||
|
||||
f = self.module.declare_function("foo", [t, t_with_attr], [])
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testFunctionDeclarationWithAffineAttr
|
||||
# CHECK: func @foo(memref<10xf32>, memref<10xf32> {affine_attr_1 = affine_map<(d0, d1) -> (23, 44, s0)>, affine_attr_2 = affine_map<() -> (42)>, affine_attr_3 = affine_map<(d0, d1) -> (d1 * 44 + 23, (23 floordiv d1) mod 44, 1, 21, 88, -d1)>})
|
||||
|
||||
def testFunctionDeclarationWithArrayAttr(self):
|
||||
self.setUp()
|
||||
arrayAttr = self.module.arrayAttr([
|
||||
self.module.integerAttr(self.i32Type, 43),
|
||||
self.module.integerAttr(self.i32Type, 33),
|
||||
])
|
||||
t = self.module.make_memref_type(self.f32Type, [10])
|
||||
t_with_attr = t({"array_attr": arrayAttr})
|
||||
|
||||
f = self.module.declare_function("foo", [t, t_with_attr], [])
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testFunctionDeclarationWithArrayAttr
|
||||
# CHECK: func @foo(memref<10xf32>, memref<10xf32> {array_attr = [43 : i32, 33 : i32]})
|
||||
|
||||
def testFunctionDeclarationWithFloatAndStringAttr(self):
|
||||
self.setUp()
|
||||
float_attr = self.module.floatAttr(23.3)
|
||||
string_attr = self.module.stringAttr("TEST_STRING")
|
||||
|
||||
f = self.module.declare_function(
|
||||
"foo", [], [], float_attr=float_attr, string_attr=string_attr)
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testFunctionDeclarationWithFloatAndStringAttr
|
||||
# CHECK: func @foo() attributes {float_attr = 2.330000e+01 : f32, string_attr = "TEST_STRING"}
|
||||
|
||||
def testFunctionMultiple(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []):
|
||||
pass
|
||||
with self.module.new_function_context("foo", [], []):
|
||||
E.constant_index(0)
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testFunctionMultiple
|
||||
# CHECK: func @foo()
|
||||
# CHECK: func @foo_0()
|
||||
# CHECK: %{{.*}} = constant 0 : index
|
||||
|
||||
def testIndexCast(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("testIndexCast", [], []):
|
||||
index = E.constant_index(0)
|
||||
E.index_cast(index, self.i32Type)
|
||||
printWithCurrentFunctionName(str(self.module))
|
||||
# CHECK-LABEL: testIndexCast
|
||||
# CHECK: index_cast %{{.*}} : index to i32
|
||||
|
||||
def testIndexedValue(self):
|
||||
self.setUp()
|
||||
memrefType = self.module.make_memref_type(self.f32Type, [10, 42])
|
||||
with self.module.new_function_context("indexed", [memrefType],
|
||||
[memrefType]) as fun:
|
||||
A = E.IndexedValue(fun.arg(0))
|
||||
cst = E.constant_float(1., self.f32Type)
|
||||
with E.LoopNestContext(
|
||||
[E.constant_index(0), E.constant_index(0)],
|
||||
[E.constant_index(10), E.constant_index(42)], [1, 1]) as (i, j):
|
||||
A.store([i, j], A.load([i, j]) + cst)
|
||||
E.ret([fun.arg(0)])
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testIndexedValue
|
||||
# CHECK: affine.for
|
||||
# CHECK: affine.for
|
||||
# CHECK: %{{.*}} affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x42xf32>
|
||||
# CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
|
||||
# CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x42xf32>
|
||||
|
||||
def testLoopContext(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []) as fun:
|
||||
lhs = E.constant_index(0)
|
||||
rhs = E.constant_index(42)
|
||||
with E.LoopContext(lhs, rhs, 1) as i:
|
||||
lhs + rhs + i
|
||||
with E.LoopContext(rhs, rhs + rhs, 2) as j:
|
||||
x = i + j
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testLoopContext
|
||||
# CHECK: affine.for
|
||||
# CHECK: {{.*}} = affine.apply affine_map<() -> (42)>()
|
||||
# CHECK: {{.*}} = affine.apply affine_map<(d0) -> (d0 + 42)>({{.*}})
|
||||
# CHECK: {{.*}} = affine.apply affine_map<() -> (84)>()
|
||||
# CHECK: affine.for {{.*}} = affine_map<(d0) -> (d0)>(%c42) to affine_map<(d0) -> (d0)>({{.*}}) step 2 {
|
||||
# CHECK: {{.*}} = affine.apply affine_map<(d0, d1) -> (d0 + d1)>({{.*}}, {{.*}})
|
||||
|
||||
def testLoopNestContext(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [], []) as fun:
|
||||
lbs = [E.constant_index(i) for i in range(4)]
|
||||
ubs = [E.constant_index(10 * i + 5) for i in range(4)]
|
||||
with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l):
|
||||
i + j + k + l
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testLoopNestContext
|
||||
# CHECK: affine.for
|
||||
# CHECK: affine.for
|
||||
# CHECK: affine.for
|
||||
# CHECK: affine.for
|
||||
# CHECK: {{.*}} = affine.apply affine_map<(d0, d1) -> (d0 + d1)>({{.*}}, {{.*}})
|
||||
# CHECK: {{.*}} = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>({{.*}}, {{.*}}, {{.*}})
|
||||
# CHECK: {{.*}} = affine.apply affine_map<(d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)>({{.*}}, {{.*}}, {{.*}}, {{.*}})
|
||||
|
||||
def testMLIRFunctionCreation(self):
|
||||
self.setUp()
|
||||
module = E.MLIRModule()
|
||||
t = module.make_type("f32")
|
||||
m = module.make_memref_type(t, [3, 4, -1, 5])
|
||||
printWithCurrentFunctionName(str(t))
|
||||
print(str(m))
|
||||
print(str(module.make_function("copy", [m, m], [])))
|
||||
print(str(module.make_function("sqrtf", [t], [t])))
|
||||
# CHECK-LABEL: testMLIRFunctionCreation
|
||||
# CHECK: f32
|
||||
# CHECK: memref<3x4x?x5xf32>
|
||||
# CHECK: func @copy(%{{.*}}: memref<3x4x?x5xf32>, %{{.*}}: memref<3x4x?x5xf32>) {
|
||||
# CHECK: func @sqrtf(%{{.*}}: f32) -> f32
|
||||
|
||||
def testMLIRScalarTypes(self):
|
||||
self.setUp()
|
||||
module = E.MLIRModule()
|
||||
printWithCurrentFunctionName(str(module.make_type("bf16")))
|
||||
print(str(module.make_type("f16")))
|
||||
print(str(module.make_type("f32")))
|
||||
print(str(module.make_type("f64")))
|
||||
print(str(module.make_type("i1")))
|
||||
print(str(module.make_type("i8")))
|
||||
print(str(module.make_type("i32")))
|
||||
print(str(module.make_type("i123")))
|
||||
print(str(module.make_type("index")))
|
||||
# CHECK-LABEL: testMLIRScalarTypes
|
||||
# CHECK: bf16
|
||||
# CHECK: f16
|
||||
# CHECK: f32
|
||||
# CHECK: f64
|
||||
# CHECK: i1
|
||||
# CHECK: i8
|
||||
# CHECK: i32
|
||||
# CHECK: i123
|
||||
# CHECK: index
|
||||
|
||||
def testMatrixMultiply(self):
|
||||
self.setUp()
|
||||
memrefType = self.module.make_memref_type(self.f32Type, [32, 32])
|
||||
with self.module.new_function_context("matmul",
|
||||
[memrefType, memrefType, memrefType],
|
||||
[]) as fun:
|
||||
A = E.IndexedValue(fun.arg(0))
|
||||
B = E.IndexedValue(fun.arg(1))
|
||||
C = E.IndexedValue(fun.arg(2))
|
||||
c0 = E.constant_index(0)
|
||||
c32 = E.constant_index(32)
|
||||
with E.LoopNestContext([c0, c0, c0], [c32, c32, c32],
|
||||
[1, 1, 1]) as (i, j, k):
|
||||
C.store([i, j], A.load([i, k]) * B.load([k, j]))
|
||||
E.ret([])
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testMatrixMultiply
|
||||
# CHECK: affine.for
|
||||
# CHECK: affine.for
|
||||
# CHECK: affine.for
|
||||
# CHECK-DAG: %{{.*}} = affine.load
|
||||
# CHECK-DAG: %{{.*}} = affine.load
|
||||
# CHECK: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
|
||||
# CHECK: affine.store
|
||||
# CHECK-SAME: memref<32x32xf32>
|
||||
|
||||
def testRet(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [],
|
||||
[self.indexType, self.indexType]) as fun:
|
||||
c42 = E.constant_index(42)
|
||||
c0 = E.constant_index(0)
|
||||
E.ret([c42, c0])
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testRet
|
||||
# CHECK: %{{.*}} = constant 42 : index
|
||||
# CHECK: %{{.*}} = constant 0 : index
|
||||
# CHECK: return %{{.*}}, %{{.*}} : index, index
|
||||
|
||||
def testSelectOp(self):
|
||||
self.setUp()
|
||||
with self.module.new_function_context("foo", [self.boolType],
|
||||
[self.i32Type]) as fun:
|
||||
a = E.constant_int(42, 32)
|
||||
b = E.constant_int(0, 32)
|
||||
E.ret([E.select(fun.arg(0), a, b)])
|
||||
printWithCurrentFunctionName(str(fun))
|
||||
# CHECK-LABEL: testSelectOp
|
||||
# CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : i32
|
||||
|
||||
def testType(self):
|
||||
self.setUp()
|
||||
printWithCurrentFunctionName("")
|
||||
with self.module.new_function_context(
|
||||
"foo", [self.module.make_memref_type(self.f32Type, [10])], []) as fun:
|
||||
c42 = E.constant_int(42, 32)
|
||||
print(str(c42.type()))
|
||||
print(str(fun.arg(0).type()))
|
||||
# CHECK-LABEL: testType
|
||||
# CHECK: i32
|
||||
# CHECK: memref<10xf32>
|
||||
|
||||
|
||||
# Until python 3.6 this cannot be used because the order in the dict is not the
|
||||
# order of method declaration.
|
||||
def runTests():
|
||||
|
||||
def isTest(attr):
|
||||
return inspect.ismethod(attr) and "EdscTest.setUp " not in str(attr)
|
||||
|
||||
edscTest = EdscTest()
|
||||
tests = sorted(
|
||||
filter(isTest, (getattr(edscTest, attr) for attr in dir(edscTest))),
|
||||
key=lambda x: str(x))
|
||||
for test in tests:
|
||||
print("--> Running test:", test.__name__, file=sys.stderr)
|
||||
test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_utils.run_under_filecheck(__file__, runTests)
|
|
@ -0,0 +1,199 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
|
||||
from ..native.mlir import edsc
|
||||
from ..exporter import *
|
||||
from ..types import *
|
||||
|
||||
|
||||
class TracingError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class EmitterRegistry:
|
||||
def __init__(self):
|
||||
self._func_emitters = {}
|
||||
|
||||
def register(self, func, emitter):
|
||||
self._func_emitters[func] = emitter
|
||||
|
||||
def lookup(self, func):
|
||||
return self._func_emitters.get(func)
|
||||
|
||||
def register_ufunc(self, ufunc, function_name):
|
||||
def emitter(pft, method, *inputs, **kwargs):
|
||||
if method == "__call__":
|
||||
if kwargs:
|
||||
raise TracingError("Generic ufunc with kwargs not supported %r" % (
|
||||
ufunc,))
|
||||
|
||||
# Map inputs to TracedArrays.
|
||||
# TODO: Process captures, promotions, etc.
|
||||
op_inputs = []
|
||||
for py_input in inputs:
|
||||
if not isinstance(py_input, TracedArray):
|
||||
raise TracingError("Unsupported ufunc input: %r", (py_input,))
|
||||
op_input = pft.get_traced_array_value(py_input)
|
||||
if op_input is None:
|
||||
raise TracingError("Unregistered traced array: %r", (py_input,))
|
||||
op_inputs.append(op_input)
|
||||
|
||||
# Emit op.
|
||||
mlir_m = pft.mlir_module
|
||||
op_result_types = [mlir_m.make_type("tensor<*x!numpy.any_dtype>")]
|
||||
op_result = edsc.op("numpy.generic_ufunc", op_inputs, op_result_types,
|
||||
ufunc_name=mlir_m.stringAttr(function_name))
|
||||
|
||||
# Wrap returns.
|
||||
return_array = TracedArray(pft)
|
||||
pft.set_traced_array(return_array, op_result)
|
||||
return return_array
|
||||
|
||||
raise TracingError("Unsupported ufunc method %r:%r" % (ufunc, method,))
|
||||
|
||||
self.register(ufunc, emitter)
|
||||
|
||||
|
||||
EMITTER_REGISTRY = EmitterRegistry()
|
||||
EMITTER_REGISTRY.register_ufunc(np.multiply, "numpy.multiply")
|
||||
EMITTER_REGISTRY.register_ufunc(np.add, "numpy.add")
|
||||
|
||||
|
||||
class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
||||
"""An array that traces its operations."""
|
||||
def __init__(self, pft: "PyFuncTrace"):
|
||||
self._pft = pft
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
def __repr__(self):
|
||||
return "<TracedArray %d>" % id(self)
|
||||
|
||||
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
|
||||
emitter = EMITTER_REGISTRY.lookup(ufunc)
|
||||
if emitter is None:
|
||||
return NotImplemented
|
||||
result = emitter(self._pft, method, *inputs, **kwargs)
|
||||
return result
|
||||
|
||||
|
||||
class PyFuncTrace:
|
||||
r"""Creates an MLIR function from an unwrapped python function.
|
||||
|
||||
# TODO: These constraints are too verbose and should be coming in by
|
||||
# example.
|
||||
>>> def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
... return a * b + a
|
||||
>>> exp = Exporter()
|
||||
>>> exp.simple_mul = simple_mul
|
||||
>>> exp.simple_mul.sig.args["a"] += Shape(1, 4)
|
||||
>>> exp.simple_mul.sig.args["a"] += DynamicDim(0)
|
||||
>>> exp.simple_mul.sig.args["a"] += DType(np.float32)
|
||||
>>> exp.simple_mul.sig.args["b"] += Shape(1)
|
||||
>>> exp.simple_mul.sig.args["b"] += DType(np.float32)
|
||||
>>> exp.simple_mul.sig.result += Shape(1, 4)
|
||||
>>> exp.simple_mul.sig.result += DynamicDim(0)
|
||||
>>> exp.simple_mul.sig.result += DType(np.float32)
|
||||
>>> pft = PyFuncTrace(exp.simple_mul)
|
||||
>>> pft.trace()
|
||||
>>> print(pft.mlir_module.get_ir().strip())
|
||||
module {
|
||||
func @simple_mul(%arg0: tensor<?x4xf32>, %arg1: tensor<1xf32>) -> tensor<?x4xf32> {
|
||||
%0 = "numpy.generic_ufunc"(%arg0, %arg1) {ufunc_name = "numpy.multiply"} : (tensor<?x4xf32>, tensor<1xf32>) -> tensor<*x!numpy.any_dtype>
|
||||
%1 = "numpy.generic_ufunc"(%0, %arg0) {ufunc_name = "numpy.add"} : (tensor<*x!numpy.any_dtype>, tensor<?x4xf32>) -> tensor<*x!numpy.any_dtype>
|
||||
%2 = "numpy.narrow"(%1) : (tensor<*x!numpy.any_dtype>) -> tensor<?x4xf32>
|
||||
return %2 : tensor<?x4xf32>
|
||||
}
|
||||
}
|
||||
"""
|
||||
__slots__ = [
|
||||
"epf",
|
||||
"mlir_ctx",
|
||||
"mlir_fun",
|
||||
"mlir_module",
|
||||
"mlir_result_types",
|
||||
"_args_array_params",
|
||||
"_traced_arrays",
|
||||
"_python_args",
|
||||
"_result_array_params",
|
||||
]
|
||||
def __init__(self, epf: ExportPyFunction):
|
||||
self.mlir_module = edsc.MLIRModule()
|
||||
self.epf = epf
|
||||
self._traced_arrays = {} # Mapping of TracedArray to current consumer value
|
||||
self._validate()
|
||||
|
||||
# Extract ArrayParams for all args and results.
|
||||
self._args_array_params = [
|
||||
ArrayParams.from_constraints(arg.constraints)
|
||||
for arg in self.epf.sig.args]
|
||||
self._python_args = [None] * len(self._args_array_params)
|
||||
self._result_array_params = ArrayParams.from_constraints(
|
||||
self.epf.sig.result.constraints)
|
||||
|
||||
# Create the MLIR function.
|
||||
self.mlir_fun, self.mlir_result_types = self._create_mlir_function()
|
||||
self.mlir_ctx = self.mlir_module.function_context(self.mlir_fun)
|
||||
self._create_trace_roots()
|
||||
|
||||
def set_traced_array(self, traced_array, value_handle):
|
||||
"""Sets the current SSA value for a traced_array."""
|
||||
assert isinstance(traced_array, TracedArray)
|
||||
self._traced_arrays[traced_array] = value_handle
|
||||
|
||||
def get_traced_array_value(self, traced_array):
|
||||
return self._traced_arrays.get(traced_array)
|
||||
|
||||
def trace(self):
|
||||
# TODO: General argument merging
|
||||
with self.mlir_ctx:
|
||||
py_results = (self.epf.pyfunc(*self._python_args),)
|
||||
if len(py_results) != len(self.mlir_result_types):
|
||||
raise TracingError(
|
||||
"Traced function returned != %d results: %r" % (
|
||||
len(self.mlir_result_types), py_results,))
|
||||
|
||||
# Narrow all results to the declared return types.
|
||||
return_operands = []
|
||||
for py_result, mlir_result_type in zip(py_results, self.mlir_result_types):
|
||||
mlir_result = self.get_traced_array_value(py_result)
|
||||
if mlir_result is None:
|
||||
raise TracingError("Unregistered traced array: %r", (py_input,))
|
||||
# narrow to declared result type.
|
||||
return_operands.append(edsc.op(
|
||||
"numpy.narrow", [mlir_result], [mlir_result_type]))
|
||||
edsc.ret(return_operands)
|
||||
|
||||
def _validate(self):
|
||||
if not all(arg.type_class == TypeClass.NdArray
|
||||
for arg in self.epf.sig.args):
|
||||
raise NotImplementedError("Non NdArray args: %r" % (self.epf.sig.args,))
|
||||
if not self.epf.sig.result.type_class == TypeClass.NdArray:
|
||||
raise NotImplementedError("Non NdArray result: %r" % (
|
||||
self.epf.sig.result,))
|
||||
|
||||
def _create_mlir_function(self):
|
||||
mlir_m = self.mlir_module
|
||||
epf = self.epf
|
||||
f_args = [mlir_m.make_type(ap.mlir_tensor_type_asm)
|
||||
for ap in self._args_array_params]
|
||||
f_results = [mlir_m.make_type(
|
||||
self._result_array_params.mlir_tensor_type_asm)]
|
||||
return mlir_m.make_function(epf.__name__, f_args, f_results), f_results
|
||||
|
||||
def _create_trace_roots(self):
|
||||
for index, ap in enumerate(self._args_array_params):
|
||||
if ap is not None:
|
||||
ta = TracedArray(self)
|
||||
self.set_traced_array(ta, self.mlir_fun.arg(index))
|
||||
self._python_args[index] = ta
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
|
@ -0,0 +1,274 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import inspect
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
from .types import *
|
||||
|
||||
__all__ = [
|
||||
"Exporter",
|
||||
"ExportFunction",
|
||||
"ExportPyFunction",
|
||||
]
|
||||
|
||||
|
||||
def _value_type_from_annotation(annotation):
|
||||
# TODO: This is just enough to recognize ndarrays.
|
||||
if annotation is np.ndarray:
|
||||
return ValueType(TypeClass.NdArray)
|
||||
else:
|
||||
return ValueType()
|
||||
|
||||
|
||||
def _signature_from_pyfunc(pyfunc):
|
||||
pysig = inspect.signature(pyfunc)
|
||||
sig = Signature(len(pysig.parameters))
|
||||
# Arguments
|
||||
for i, param in enumerate(pysig.parameters.values()):
|
||||
if param.kind not in (
|
||||
param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
|
||||
raise ValueError(
|
||||
"Currently only positional function signature are supported")
|
||||
|
||||
sig.arg_names[i] = param.name
|
||||
annot = param.annotation
|
||||
if annot is param.empty: continue
|
||||
sig.args[i] = _value_type_from_annotation(annot)
|
||||
|
||||
# Result
|
||||
if pysig.return_annotation is not pysig.empty:
|
||||
sig.result = _value_type_from_annotation(pysig.return_annotation)
|
||||
|
||||
return sig
|
||||
|
||||
|
||||
class ExportFunction:
|
||||
"""Base class for functions that can be exported."""
|
||||
__slots__ = ["_sig"]
|
||||
def __init__(self, sig=None):
|
||||
self._sig = sig if sig else Signature()
|
||||
|
||||
@property
|
||||
def sig(self):
|
||||
return self._sig
|
||||
|
||||
def __repr__(self):
|
||||
return "def %r" % self._sig
|
||||
|
||||
|
||||
class ExportPyFunction(ExportFunction):
|
||||
"""Wraps a fully specialized python function that is staged for export.
|
||||
|
||||
At different phases of compilation, the wrapped function will be
|
||||
treated differently. At the initial phase, it is just a pass-through
|
||||
and provides introspection capabilities.
|
||||
|
||||
Basic access:
|
||||
>>> def simple(a, b): return a + b
|
||||
>>> ExportPyFunction(simple)
|
||||
pydef simple(a: Any, b: Any) -> Any
|
||||
>>> def mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
... return a * b
|
||||
>>> ExportPyFunction(mul)
|
||||
pydef mul(a: NdArray, b: NdArray) -> NdArray
|
||||
>>> ExportPyFunction(mul).sig
|
||||
(a: NdArray, b: NdArray) -> NdArray
|
||||
|
||||
Manipulating the signature:
|
||||
>>> f = ExportPyFunction(mul)
|
||||
>>> f.sig.args["a"] += Rank(2)
|
||||
>>> f.sig.args["b"] = "Any"
|
||||
>>> f.sig.result += Shape(1, 2)
|
||||
>>> f
|
||||
pydef mul(a: NdArray[Rank(2)], b: Any) -> NdArray[Shape(1, 2)]
|
||||
"""
|
||||
__slots__ = ExportFunction.__slots__ + ["_pyfunc", "__name__"]
|
||||
|
||||
def __init__(self, pyfunc, name=None):
|
||||
super().__init__(sig=_signature_from_pyfunc(pyfunc))
|
||||
assert (hasattr(pyfunc, "__call__")
|
||||
and hasattr(pyfunc, "__name__")), "Not a python function"
|
||||
self._pyfunc = pyfunc
|
||||
self.__name__ = name if name else pyfunc.__name__
|
||||
|
||||
@property
|
||||
def pyfunc(self):
|
||||
return self._pyfunc
|
||||
|
||||
def __repr__(self):
|
||||
return "pydef %s%r" % (self.__name__, self._sig)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._pyfunc(*args, **kwargs)
|
||||
|
||||
|
||||
class _ExpandoNode:
|
||||
"""Expando object that can be indexed into to construct a namespace."""
|
||||
__slots__ = [
|
||||
"_parent", "_services", "_local_name", "_parent_name",
|
||||
"_children", "_attached"]
|
||||
def __init__(self, parent: Optional["_ExpandoNode"],
|
||||
services: "_Services",
|
||||
local_name: str):
|
||||
super().__init__()
|
||||
object.__setattr__(self, "_parent", parent)
|
||||
object.__setattr__(self, "_services", services)
|
||||
object.__setattr__(self, "_local_name", local_name)
|
||||
object.__setattr__(self, "_parent_name",
|
||||
parent._get_full_name() if parent else "")
|
||||
object.__setattr__(self, "_children", {})
|
||||
object.__setattr__(self, "_attached", parent is None)
|
||||
|
||||
def _attach(self):
|
||||
if self._attached: return
|
||||
if self._local_name in self._parent._children:
|
||||
raise KeyError("Cannot re-assign '%s'" % (self._get_full_name(),))
|
||||
self._parent._attach()
|
||||
self._parent._children[self._local_name] = self
|
||||
object.__setattr__(self, "_attached", True)
|
||||
|
||||
def _get_full_name(self):
|
||||
if not self._parent: return "" # Root is always empty name.
|
||||
full_name = (self._parent_name + "." + self._local_name
|
||||
if self._parent_name else self._local_name)
|
||||
return full_name
|
||||
|
||||
def _get_child_name(self, child_local_name):
|
||||
full_name = self._get_full_name()
|
||||
if not full_name: return child_local_name
|
||||
else: return full_name + "." + child_local_name
|
||||
|
||||
def __repr__(self):
|
||||
return "Namespace(\"%s\")" % (self._get_full_name())
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._children
|
||||
|
||||
def __getitem__(self, key):
|
||||
key = str(key)
|
||||
existing = self._children.get(key)
|
||||
if existing is not None: return existing
|
||||
# Speculatively create a child expando.
|
||||
child = _ExpandoNode(self, self._services, key)
|
||||
return child
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if not inspect.isfunction(value):
|
||||
raise TypeError("Cannot assign value to an exporter: %r" % (value,))
|
||||
child_name = self._get_child_name(key)
|
||||
if key in self._children:
|
||||
# TODO: Relax this once __delitem__ is implemented.
|
||||
raise KeyError("Cannot re-assign '%s'" % (child_name))
|
||||
self._attach()
|
||||
self._children[key] = self._services.wrap_function(value, child_name)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self[name]
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
try:
|
||||
self[name] = value
|
||||
except KeyError as e:
|
||||
raise AttributeError(str(e)) from None
|
||||
|
||||
def __dir__(self):
|
||||
return self._children.keys()
|
||||
|
||||
|
||||
class _Services:
|
||||
"""Services and support for the Exporter.
|
||||
|
||||
Exporters are user objects, so most of the functional components are
|
||||
contained in the associated _Services object.
|
||||
"""
|
||||
def wrap_function(self, f, full_name):
|
||||
if isinstance(f, ExportFunction): return f
|
||||
# TODO: Need to scan through providers and choose.
|
||||
return ExportPyFunction(f, name=full_name)
|
||||
|
||||
|
||||
class Exporter:
|
||||
"""Top-level UI object for assembling a program for export.
|
||||
|
||||
The exporter defines an open namespace of functions to be exported.
|
||||
Logically, it can be thought of as a dict-of-dicts that is populated
|
||||
by assignment of functions to leaves. The act of assigning a function
|
||||
captures it as an ExportFunction and binds it to the exporter. This
|
||||
ExportFunction exposes the object model that can be manipulated to
|
||||
refine the compiled form. By default, any calls to such functions will
|
||||
delegate to the original function, capturing examples that constrain
|
||||
and allow further optimizations on the compiled form.
|
||||
|
||||
There are several reserved names that can not have functions bound
|
||||
to them with the dot notation, but can still be referenced by subscripting
|
||||
if necessary:
|
||||
TODO: Reserved names. 'captures', etc.
|
||||
|
||||
>>> exp = Exporter()
|
||||
>>> exp
|
||||
Exporter()
|
||||
|
||||
Creating namespaces and functions with attribute access:
|
||||
>>> exp = Exporter()
|
||||
>>> exp.ns1
|
||||
Namespace("ns1")
|
||||
>>> "ns1" in exp # Not yet attached
|
||||
False
|
||||
>>> exp.ns1.ns2.f = lambda x: x
|
||||
>>> exp.ns1.ns2 # Should be attached
|
||||
Namespace("ns1.ns2")
|
||||
>>> exp.ns1.ns2.f
|
||||
pydef ns1.ns2.f(x: Any) -> Any
|
||||
|
||||
Via index access:
|
||||
>>> exp = Exporter()
|
||||
>>> exp["ns1"]["f"] = lambda x: x
|
||||
>>> dir(exp["ns1"])
|
||||
['f']
|
||||
>>> exp["ns1"]["f"]
|
||||
pydef ns1.f(x: Any) -> Any
|
||||
|
||||
Illegal access:
|
||||
>>> exp = Exporter()
|
||||
>>> exp.ns1.ns2.f = lambda x: x
|
||||
>>> exp.ns1.ns2.f = lambda x: x
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AttributeError: "Cannot re-assign 'ns1.ns2.f'"
|
||||
>>> exp.ns1 = lambda x: x
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AttributeError: "Cannot re-assign 'ns1'"
|
||||
"""
|
||||
__slots__ = ["_root", "_services"]
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
services = _Services()
|
||||
object.__setattr__(self, "_root", _ExpandoNode(None, services, ""))
|
||||
object.__setattr__(self, "_services", services)
|
||||
|
||||
def __repr__(self):
|
||||
return "Exporter()"
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._root
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._root[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._root[key] = value
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._root, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
setattr(self._root, name, value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,19 @@
|
|||
//===- mlir_init.cpp ------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/InitAllDialects.h"
|
||||
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
|
||||
void npcompMlirInitialize() {
|
||||
::mlir::registerAllDialects();
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namesapce npcomp
|
|
@ -0,0 +1,35 @@
|
|||
//===- native.cpp - MLIR Python bindings ----------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <cstddef>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace npcomp {
|
||||
namespace python {
|
||||
|
||||
// Externs
|
||||
void npcompMlirInitialize();
|
||||
void defineMlirEdscModule(py::module m);
|
||||
|
||||
PYBIND11_MODULE(native, m) {
|
||||
npcompMlirInitialize();
|
||||
m.doc() = "Npcomp native python bindings";
|
||||
|
||||
auto mlir_m = m.def_submodule("mlir", "MLIR interop");
|
||||
auto mlir_edsc_m = mlir_m.def_submodule("edsc");
|
||||
defineMlirEdscModule(mlir_edsc_m);
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
} // namespace npcomp
|
|
@ -0,0 +1,2 @@
|
|||
from . import native
|
||||
print(native.__doc__)
|
|
@ -0,0 +1,119 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import contextlib
|
||||
import threading
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TraceContext:
|
||||
"""Context for intercepting array traces.
|
||||
|
||||
Context manager:
|
||||
----------------
|
||||
Instances act as context managers, the inner-most of which can be
|
||||
queried with current() or optional_current().
|
||||
|
||||
>>> with TraceContext(desc=1) as tc:
|
||||
... print(tc)
|
||||
... print(TraceContext.current())
|
||||
<TraceContext 1>
|
||||
<TraceContext 1>
|
||||
>>> print(TraceContext.optional_current())
|
||||
None
|
||||
>>> TraceContext.current()
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
RuntimeError: No active TraceContext
|
||||
|
||||
Unique ids:
|
||||
-----------
|
||||
Many things in tracing require a context-local id.
|
||||
|
||||
>>> with TraceContext() as tc:
|
||||
... print(tc.get_next_id())
|
||||
... print(tc.get_next_id())
|
||||
1
|
||||
2
|
||||
|
||||
"""
|
||||
_local = threading.local()
|
||||
|
||||
def __init__(self, desc=None):
|
||||
self._desc = desc
|
||||
self._next_id = 1
|
||||
|
||||
def get_next_id(self):
|
||||
"""Gets the next unique id for the context."""
|
||||
rv = self._next_id
|
||||
self._next_id += 1
|
||||
return rv
|
||||
|
||||
@classmethod
|
||||
def _get_context_stack(cls):
|
||||
try:
|
||||
return cls._local.s
|
||||
except AttributeError:
|
||||
cls._local.s = []
|
||||
return cls._local.s
|
||||
|
||||
@classmethod
|
||||
def optional_current(cls) -> Optional["TraceContext"]:
|
||||
s = cls._get_context_stack()
|
||||
if s:
|
||||
return s[-1]
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def current(cls) -> "TraceContext":
|
||||
c = cls.optional_current()
|
||||
if c is None:
|
||||
raise RuntimeError("No active TraceContext")
|
||||
return c
|
||||
|
||||
def __enter__(self):
|
||||
s = self._get_context_stack()
|
||||
s.append(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
s = self._get_context_stack()
|
||||
s.pop()
|
||||
|
||||
def __repr__(self):
|
||||
return "<TraceContext %r>" % self._desc
|
||||
|
||||
class TracedArray(np.lib.mixins.NDArrayOperatorsMixin):
|
||||
"""An array that traces its operations.
|
||||
|
||||
Unique ids:
|
||||
-----------
|
||||
>>> tc = TraceContext()
|
||||
>>> TracedArray(tc=tc)
|
||||
<TracedArray 1>
|
||||
>>> TracedArray(tc=tc)
|
||||
<TracedArray 2>
|
||||
"""
|
||||
def __init__(self, tc: Optional[TraceContext] = None):
|
||||
self._tc = tc if tc is not None else TraceContext.current()
|
||||
self._uid = self._tc.get_next_id()
|
||||
|
||||
@property
|
||||
def uid(self):
|
||||
return self._uid
|
||||
|
||||
def __repr__(self):
|
||||
return "<TracedArray %d>" % self._uid
|
||||
|
||||
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
|
||||
return NotImplemented
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
|
@ -0,0 +1,117 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import context
|
||||
from ..native.mlir import edsc
|
||||
|
||||
|
||||
def _map_typing_to_mlir_type(mlir_m, typing_annot):
|
||||
"""Maps a typing annotation to an MLIR type.
|
||||
|
||||
Args:
|
||||
mlir_m: MLIRModule.
|
||||
typing_annot: Value for an __annotations__ entry.
|
||||
Returns:
|
||||
MLIR type or None if not mappable.
|
||||
"""
|
||||
if typing_annot is np.ndarray:
|
||||
return mlir_m.make_type("tensor<*x!numpy.any_dtype>")
|
||||
return None
|
||||
|
||||
|
||||
class GenericFunctionTrace:
|
||||
"""Represents a trace of a 'generic' python function in progress."""
|
||||
|
||||
def __init__(self, mlir_m, mlir_f):
|
||||
self._mlir_m = mlir_m
|
||||
self._mlir_f = mlir_f
|
||||
|
||||
@property
|
||||
def mlir_module(self):
|
||||
return self._mlir_m
|
||||
|
||||
@property
|
||||
def mlir_function(self):
|
||||
return self._mlir_f
|
||||
|
||||
@classmethod
|
||||
def from_typed_pyfunc(cls, mlir_m, pyfunc, name_in_module=None):
|
||||
"""Creates a generic function trace from a pyfunc with type annotations.
|
||||
|
||||
This is a relatively limited mechanism which relies on typing annotations
|
||||
for arguments and results and supports a relatively limited amount of
|
||||
variation.
|
||||
|
||||
Examples:
|
||||
|
||||
* Generic ndarrays:
|
||||
>>> m = edsc.MLIRModule()
|
||||
>>> def simple_mul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
... return a * b
|
||||
>>> gft = GenericFunctionTrace.from_typed_pyfunc(m, simple_mul)
|
||||
>>> ir = gft.mlir_module.get_ir()
|
||||
>>> print(re.findall("func @simple_mul.+", ir)[0])
|
||||
func @simple_mul$$generic(%arg0: tensor<*x!numpy.any_dtype> {py_name = "a"}, %arg1: tensor<*x!numpy.any_dtype> {py_name = "b"}) -> tensor<*x!numpy.any_dtype> attributes {py_ftype = "generic_trace", py_name = "simple_mul"} {
|
||||
|
||||
* None types must be annotated:
|
||||
>>> m = edsc.MLIRModule()
|
||||
>>> def simple_mul(a: np.ndarray, b: np.ndarray) -> None:
|
||||
... return a * b
|
||||
>>> gft = GenericFunctionTrace.from_typed_pyfunc(m, simple_mul)
|
||||
>>> ir = gft.mlir_module.get_ir()
|
||||
>>> print(re.findall("func @simple_mul.+", ir)[0])
|
||||
func @simple_mul$$generic(%arg0: tensor<*x!numpy.any_dtype> {py_name = "a"}, %arg1: tensor<*x!numpy.any_dtype> {py_name = "b"}) attributes {py_ftype = "generic_trace", py_name = "simple_mul"} {
|
||||
|
||||
Args:
|
||||
mlir_m: An MLIRModule.
|
||||
pyfunc: A python function to transform.
|
||||
Returns:
|
||||
A new GenericFunctionTrace.
|
||||
"""
|
||||
if name_in_module is None:
|
||||
name_in_module = pyfunc.__name__ + "$$generic"
|
||||
code = pyfunc.__code__
|
||||
# Process arguments.
|
||||
f_args = []
|
||||
for i in range(code.co_argcount):
|
||||
arg_name = code.co_varnames[i]
|
||||
arg_annot = pyfunc.__annotations__.get(arg_name)
|
||||
if arg_annot is None:
|
||||
raise ValueError("Function %s arg %d is missing a typing annotation" % (
|
||||
pyfunc.__name__, i))
|
||||
arg_type = _map_typing_to_mlir_type(mlir_m, arg_annot)
|
||||
if arg_type is None:
|
||||
raise ValueError("Function %s arg %d is not a supported type" % (
|
||||
pyfunc.__name__, i))
|
||||
arg_type = arg_type({
|
||||
"py_name": mlir_m.stringAttr(arg_name),
|
||||
})
|
||||
f_args.append(arg_type)
|
||||
|
||||
# Process results.
|
||||
f_results = []
|
||||
if "return" not in pyfunc.__annotations__:
|
||||
raise ValueError("Un-annotated function returns not yet supported")
|
||||
return_annot = pyfunc.__annotations__["return"]
|
||||
if return_annot is not None:
|
||||
return_type = _map_typing_to_mlir_type(mlir_m, return_annot)
|
||||
if return_type is None:
|
||||
raise ValueError("Function %s return type %r is not supported" % (
|
||||
pyfunc.__name__, return_annot))
|
||||
f_results.append(return_type)
|
||||
|
||||
mlir_f = mlir_m.make_function(
|
||||
name_in_module, f_args, f_results,
|
||||
py_ftype=mlir_m.stringAttr("generic_trace"),
|
||||
py_name=mlir_m.stringAttr(pyfunc.__name__))
|
||||
return GenericFunctionTrace(mlir_m, mlir_f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
|
@ -0,0 +1,695 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
"Unspec",
|
||||
"ArrayConstraint",
|
||||
"ArrayParams",
|
||||
"DType",
|
||||
"DimFlag",
|
||||
"DimFlagEnum",
|
||||
"DynamicDim",
|
||||
"Rank",
|
||||
"Shape",
|
||||
"Signature",
|
||||
"TypeClass",
|
||||
"TypeConstraints",
|
||||
"ValueType",
|
||||
]
|
||||
|
||||
# TODO: All supported types
|
||||
_DTYPE_TO_ASM_DICT = {
|
||||
np.bool: "i1", # TODO: May need a custom type to signify 8bit storage
|
||||
np.int8: "s8",
|
||||
np.int16: "s16",
|
||||
np.int32: "s32",
|
||||
np.int64: "s64",
|
||||
np.float32: "f32",
|
||||
np.float64: "f64",
|
||||
}
|
||||
|
||||
|
||||
def _dtype_to_mlir_asm(dtype):
|
||||
return _DTYPE_TO_ASM_DICT.get(dtype)
|
||||
|
||||
|
||||
class _LiterateEnum(Enum):
|
||||
"""An enum that can be parsed/printed based on its name.
|
||||
|
||||
>>> class SampleEnum(_LiterateEnum):
|
||||
... Red = 1
|
||||
... Blue = 2
|
||||
>>> SampleEnum.Red
|
||||
Red
|
||||
>>> SampleEnum.parse("Red")
|
||||
Red
|
||||
>>> SampleEnum.parse("Mauve")
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Cannot parse SampleEnum 'Mauve'
|
||||
>>> SampleEnum.parse("parse")
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Cannot parse SampleEnum 'parse'
|
||||
>>> SampleEnum.parse(None)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Cannot parse SampleEnum None
|
||||
>>> SampleEnum.parse(1.0)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Cannot parse SampleEnum 1.0
|
||||
|
||||
"""
|
||||
@classmethod
|
||||
def parse(cls, v):
|
||||
if isinstance(v, cls): return v
|
||||
if not v or not isinstance(v, str) or v[0] == '_' or not hasattr(cls, v):
|
||||
raise ValueError("Cannot parse %s %r" % (
|
||||
cls.__name__.split(".")[-1], v,))
|
||||
value = getattr(cls, v)
|
||||
if not isinstance(value, cls):
|
||||
raise ValueError("Cannot parse %s %r" % (
|
||||
cls.__name__.split(".")[-1], v,))
|
||||
return value
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
# Special "unspecified" value that we use throughout.
|
||||
class _Unspec:
|
||||
__slots__ = []
|
||||
def __str__(self):
|
||||
return "Unspec"
|
||||
def __repr__(self):
|
||||
return "Unspec"
|
||||
Unspec = _Unspec()
|
||||
|
||||
|
||||
class TypeClass(_LiterateEnum):
|
||||
"""Top level types in the npcomp language."""
|
||||
Any = 0
|
||||
NdArray = 1
|
||||
|
||||
|
||||
class ValueType:
|
||||
"""The type a value can take in the npcomp language.
|
||||
|
||||
Types of values in npcomp are always being refined and are therefore
|
||||
mutable. Instances represent the type derived for a single value, not a
|
||||
concept of "typeness" generally.
|
||||
|
||||
>>> ValueType()
|
||||
Any
|
||||
>>> ValueType('NdArray')
|
||||
NdArray
|
||||
>>> ValueType('NdArray', DType(np.float32), Rank(2))
|
||||
NdArray[DType(float32), Rank(2)]
|
||||
>>> vt = ValueType('NdArray')
|
||||
>>> vt += Rank(3)
|
||||
>>> vt += DynamicDim(1)
|
||||
>>> vt
|
||||
NdArray[Rank(3), DimFlag(Dynamic, (1,))]
|
||||
>>> vt = ValueType()
|
||||
>>> vt.type_class = 'NdArray'
|
||||
>>> vt
|
||||
NdArray
|
||||
"""
|
||||
__slots__ = ["_constraints", "_type_class"]
|
||||
|
||||
def __init__(self, type_class=TypeClass.Any, *constraints):
|
||||
super().__init__()
|
||||
self._type_class = TypeClass.parse(type_class)
|
||||
self._constraints = TypeConstraints(constraints)
|
||||
|
||||
def __iadd__(self, constraint):
|
||||
assert isinstance(constraint, TypeConstraint), (
|
||||
"Can only add constraints to a ValueType")
|
||||
self._constraints.append(constraint)
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
if not self._constraints:
|
||||
return repr(self._type_class)
|
||||
return "%r[%s]" % (self._type_class,
|
||||
", ".join([repr(c) for c in self._constraints]))
|
||||
|
||||
@property
|
||||
def type_class(self):
|
||||
return self._type_class
|
||||
|
||||
@type_class.setter
|
||||
def type_class(self, type_class):
|
||||
self._type_class = TypeClass.parse(type_class)
|
||||
|
||||
@property
|
||||
def constraints(self):
|
||||
return self._constraints
|
||||
|
||||
|
||||
class ValueTypeList:
|
||||
"""Models a list of ValueTypes.
|
||||
|
||||
>>> v3 = ValueTypeList(3)
|
||||
>>> v3
|
||||
(Any, Any, Any)
|
||||
>>> v3[1]
|
||||
Any
|
||||
>>> v3[2] = 'NdArray'
|
||||
>>> v3
|
||||
(Any, Any, NdArray)
|
||||
>>> v3[2] += Rank(2)
|
||||
>>> v3
|
||||
(Any, Any, NdArray[Rank(2)])
|
||||
|
||||
With names:
|
||||
>>> v3 = ValueTypeList(3, [None, "b", None])
|
||||
>>> v3[1] = 'NdArray'
|
||||
>>> v3["b"]
|
||||
NdArray
|
||||
>>> v3["b"] = 'Any'
|
||||
>>> v3
|
||||
(Any, Any, Any)
|
||||
"""
|
||||
__slots__ = ["_list", "_names"]
|
||||
def __init__(self, arity=0, names=None):
|
||||
self._list = [ValueType() for _ in range(arity)]
|
||||
self._names = names
|
||||
|
||||
def _key_to_index(self, key):
|
||||
if isinstance(key, str):
|
||||
# Scan for the index.
|
||||
if self._names:
|
||||
for i, n in enumerate(self._names):
|
||||
if n == key: return i
|
||||
raise KeyError("Unknown key '%s'" % key)
|
||||
return key
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._list[self._key_to_index(key)]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if not isinstance(value, ValueType):
|
||||
value = ValueType(value)
|
||||
self._list[self._key_to_index(key)] = value
|
||||
|
||||
def __iter__(self):
|
||||
return self._list.__iter__()
|
||||
|
||||
def __repr__(self):
|
||||
return "(%s)" % (", ".join(repr(t) for t in self._list),)
|
||||
|
||||
|
||||
class Signature:
|
||||
"""A function signature.
|
||||
|
||||
This currently only models a linear list of positional arguments and
|
||||
assumes that multiple results will be represented by some form of tuple
|
||||
type.
|
||||
|
||||
>>> Signature()
|
||||
() -> Any
|
||||
>>> Signature(2)
|
||||
(Any, Any) -> Any
|
||||
>>> s = Signature(2)
|
||||
>>> s.args[1] = 'NdArray'
|
||||
>>> s.args[1] += Rank(2)
|
||||
>>> s
|
||||
(Any, NdArray[Rank(2)]) -> Any
|
||||
>>> s.result = 'NdArray'
|
||||
>>> s.result += Rank(3)
|
||||
>>> s
|
||||
(Any, NdArray[Rank(2)]) -> NdArray[Rank(3)]
|
||||
>>> s.arg_names[0] = 'a'
|
||||
>>> s.arg_names[1] = 'b'
|
||||
>>> s
|
||||
(a: Any, b: NdArray[Rank(2)]) -> NdArray[Rank(3)]
|
||||
"""
|
||||
__slots__ = ["_args", "_arg_names", "_result"]
|
||||
def __init__(self, arity=0):
|
||||
super().__init__()
|
||||
self._result = ValueType()
|
||||
self._arg_names = [None] * arity
|
||||
self._args = ValueTypeList(arity, names=self._arg_names)
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return self._args
|
||||
|
||||
@property
|
||||
def arg_names(self):
|
||||
return self._arg_names
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return self._result
|
||||
|
||||
@result.setter
|
||||
def result(self, value):
|
||||
if not isinstance(value, ValueType):
|
||||
value = ValueType(value)
|
||||
self._result = value
|
||||
|
||||
def __repr__(self):
|
||||
args_repr = "(%s)" % (
|
||||
", ".join(
|
||||
((n + ": " + repr(t)) if n else repr(t))
|
||||
for t, n in zip(self._args, self._arg_names)),)
|
||||
return "%s -> %r" % (args_repr, self._result)
|
||||
|
||||
class ArrayParams:
|
||||
"""Represents parameters defining how to construct an array.
|
||||
|
||||
>>> ArrayParams()
|
||||
ArrayParams(dtype=Unspec)
|
||||
>>> ArrayParams(np.float32)
|
||||
ArrayParams(dtype=float32)
|
||||
>>> ArrayParams(np.float32, rank=4)
|
||||
ArrayParams(dtype=float32, shape=(-1, -1, -1, -1))
|
||||
>>> ArrayParams(np.float32, shape=(1, 2, 3))
|
||||
ArrayParams(dtype=float32, shape=(1, 2, 3))
|
||||
"""
|
||||
__slots__ = ["dtype", "shape"]
|
||||
|
||||
def __init__(self, dtype=Unspec, shape=Unspec, rank=Unspec):
|
||||
self.dtype = dtype
|
||||
if shape is not Unspec:
|
||||
self.shape = shape
|
||||
elif rank is not Unspec:
|
||||
self.shape = [-1 for _ in range(rank)]
|
||||
else:
|
||||
self.shape = Unspec
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
if self.shape is Unspec: return Unspec
|
||||
return len(self.shape)
|
||||
|
||||
@classmethod
|
||||
def from_constraints(cls, constraints):
|
||||
"""Constructs params for a TypeConstraints list.
|
||||
|
||||
Unconstrained:
|
||||
>>> ArrayParams.from_constraints(TypeConstraints())
|
||||
ArrayParams(dtype=Unspec)
|
||||
|
||||
DType constrained:
|
||||
>>> ArrayParams.from_constraints(TypeConstraints(DType(np.float32)))
|
||||
ArrayParams(dtype=float32)
|
||||
|
||||
Rank constrained:
|
||||
>>> ArrayParams.from_constraints(TypeConstraints(Rank(2)))
|
||||
ArrayParams(dtype=Unspec, shape=(-1, -1))
|
||||
|
||||
Shape constrained:
|
||||
>>> ArrayParams.from_constraints(TypeConstraints(Shape(1, 2, 3)))
|
||||
ArrayParams(dtype=Unspec, shape=(1, 2, 3))
|
||||
>>> ArrayParams.from_constraints(TypeConstraints(
|
||||
... Rank(3), Shape(1, 2, 3)))
|
||||
ArrayParams(dtype=Unspec, shape=(1, 2, 3))
|
||||
|
||||
Shape constrained with dynamic dim constraint:
|
||||
>>> ArrayParams.from_constraints(TypeConstraints(
|
||||
... Shape(1, 2, 3), DynamicDim(1)))
|
||||
ArrayParams(dtype=Unspec, shape=(1, -1, 3))
|
||||
>>> ArrayParams.from_constraints(TypeConstraints(
|
||||
... Shape(1, 2, 3), DynamicDim((0, 2))))
|
||||
ArrayParams(dtype=Unspec, shape=(-1, 2, -1))
|
||||
|
||||
Errors:
|
||||
>>> ArrayParams.from_constraints(TypeConstraints(
|
||||
... Rank(4), Shape(1, 2, 3)))
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Conflicting shape and rank: Rank(4) vs Shape(1, 2, 3)
|
||||
>>> ArrayParams.from_constraints(TypeConstraints(
|
||||
... Shape(1, 2, 3), DynamicDim((0, 5))))
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: Out of range DimFlag(Dynamic, (0, 5)) for shape [-1, 2, 3]
|
||||
"""
|
||||
# TODO: Should have a 'canonicalize' method on TypeConstraints which
|
||||
# reduces and verifies.
|
||||
dtype_c = constraints.one_of(DType)
|
||||
shape_c = constraints.one_of(Shape)
|
||||
rank_c = constraints.one_of(Rank)
|
||||
dim_flags = constraints.all_of(DimFlag)
|
||||
|
||||
dtype = dtype_c.dtype if dtype_c else Unspec
|
||||
shape = Unspec
|
||||
|
||||
# Compute shape
|
||||
if shape_c:
|
||||
# TODO: Should be in canonicalizer
|
||||
if rank_c and rank_c.rank != len(shape_c.dims):
|
||||
raise ValueError("Conflicting shape and rank: %r vs %r" % (
|
||||
rank_c, shape_c))
|
||||
shape = list(shape_c.dims)
|
||||
elif rank_c:
|
||||
shape = [-1 for _ in range(rank_c.rank)]
|
||||
|
||||
# Apply dim flags
|
||||
if shape is not Unspec and dim_flags:
|
||||
for df in dim_flags:
|
||||
flag, for_dims = df.dim_flag
|
||||
for d in for_dims:
|
||||
if d < 0 or d >= len(shape):
|
||||
raise ValueError("Out of range %r for shape %r" % (
|
||||
df, shape))
|
||||
if flag == DimFlagEnum.Dynamic:
|
||||
shape[d] = -1
|
||||
|
||||
return cls(dtype=dtype, shape=shape)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
try:
|
||||
s = "ArrayParams(dtype=%s" % (
|
||||
self.dtype.__name__ if isinstance(self.dtype, type) else self.dtype,)
|
||||
if self.shape is not Unspec:
|
||||
s += ", shape=%r" % (tuple(self.shape),)
|
||||
s += ")"
|
||||
return s
|
||||
except:
|
||||
return "ArrayParams(ERROR)"
|
||||
|
||||
@property
|
||||
def is_concrete(self):
|
||||
"""Returns true if the parameters are sufficient to construct an ndarray.
|
||||
|
||||
>>> ArrayParams().is_concrete
|
||||
False
|
||||
>>> ArrayParams(dtype=np.float32).is_concrete
|
||||
False
|
||||
>>> ArrayParams(dtype=np.float32, rank=1).is_concrete
|
||||
False
|
||||
>>> ArrayParams(dtype=np.float32, shape=(1, 2)).is_concrete
|
||||
True
|
||||
"""
|
||||
if self.dtype is Unspec:
|
||||
return False
|
||||
if self.shape is Unspec:
|
||||
return False
|
||||
if any(d < 0 for d in self.shape):
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def mlir_tensor_type_asm(self):
|
||||
"""Get a corresponding MLIR tensor type.
|
||||
|
||||
Fully Unspecified:
|
||||
>>> ArrayParams().mlir_tensor_type_asm
|
||||
'tensor<*x!numpy.any_dtype>'
|
||||
|
||||
Unranked:
|
||||
>>> ArrayParams(dtype=np.float32).mlir_tensor_type_asm
|
||||
'tensor<*xf32>'
|
||||
|
||||
Ranked:
|
||||
>>> ArrayParams(dtype=np.float32, rank=3).mlir_tensor_type_asm
|
||||
'tensor<?x?x?xf32>'
|
||||
>>> ArrayParams(dtype=np.float32, shape=(-1, -1)).mlir_tensor_type_asm
|
||||
'tensor<?x?xf32>'
|
||||
|
||||
Scalar:
|
||||
>>> ArrayParams(dtype=np.float32, rank=0).mlir_tensor_type_asm
|
||||
'tensor<f32>'
|
||||
>>> ArrayParams(dtype=np.float32, shape=()).mlir_tensor_type_asm
|
||||
'tensor<f32>'
|
||||
|
||||
Shaped:
|
||||
>>> ArrayParams(dtype=np.float32, shape=(2, 3)).mlir_tensor_type_asm
|
||||
'tensor<2x3xf32>'
|
||||
>>> ArrayParams(dtype=np.float32, shape=(-1, 3)).mlir_tensor_type_asm
|
||||
'tensor<?x3xf32>'
|
||||
"""
|
||||
if self.dtype is Unspec:
|
||||
dtype_asm = "!numpy.any_dtype"
|
||||
else:
|
||||
dtype_asm = _dtype_to_mlir_asm(self.dtype)
|
||||
if not dtype_asm:
|
||||
raise ValueError(
|
||||
"Unsupported MLIR tensor element type %r" % (self.dtype,))
|
||||
if self.shape is Unspec:
|
||||
shape_asm = "*"
|
||||
else:
|
||||
shape_asm = "x".join((str(d) if d >= 0 else "?") for d in self.shape)
|
||||
if shape_asm: shape_asm += "x"
|
||||
return "tensor<%s%s>" % (shape_asm, dtype_asm)
|
||||
|
||||
def new_ndarray(self):
|
||||
"""Creates a new ndarray from these params.
|
||||
|
||||
>>> ArrayParams().new_ndarray()
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: ArrayParams(dtype=Unspec) is not concrete
|
||||
>>> ArrayParams(np.float32, (1, 2)).new_ndarray() * 0.0
|
||||
array([[0., 0.]], dtype=float32)
|
||||
"""
|
||||
if not self.is_concrete:
|
||||
raise ValueError("%r is not concrete" % (self,))
|
||||
return np.ndarray(dtype=self.dtype, shape=self.shape)
|
||||
|
||||
|
||||
class TypeConstraint:
|
||||
"""Base class for type constraints."""
|
||||
pass
|
||||
|
||||
|
||||
class TypeConstraints(list):
|
||||
"""Collection of type constraints.
|
||||
|
||||
>>> TypeConstraints([DynamicDim()])
|
||||
TypeConstraints(DimFlag(Dynamic, Unspec))
|
||||
>>> TypeConstraints([DynamicDim(), Rank(4)])
|
||||
TypeConstraints(DimFlag(Dynamic, Unspec), Rank(4))
|
||||
>>> TypeConstraints(DynamicDim(), Rank(4))
|
||||
TypeConstraints(DimFlag(Dynamic, Unspec), Rank(4))
|
||||
>>> TypeConstraints(Rank(4))
|
||||
TypeConstraints(Rank(4))
|
||||
>>> TypeConstraints("foobar")
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError
|
||||
"""
|
||||
def __init__(self, *constraints):
|
||||
if len(constraints) == 1 and not isinstance(
|
||||
constraints[0], ArrayConstraint):
|
||||
constraints = constraints[0]
|
||||
super().__init__(constraints)
|
||||
assert(all(isinstance(c, ArrayConstraint) for c in self))
|
||||
|
||||
def __repr__(self):
|
||||
return "TypeConstraints(%s)" % (
|
||||
", ".join([repr(c) for c in self]))
|
||||
|
||||
def all_of(self, clazz):
|
||||
"""Finds all of the given class."""
|
||||
return [c for c in self if isinstance(c, clazz)]
|
||||
|
||||
def one_of(self, clazz):
|
||||
"""Finds at most one constraint of the given class."""
|
||||
found = [c for c in self if isinstance(c, clazz)]
|
||||
if not found: return None
|
||||
if len(found) > 1:
|
||||
raise ValueError("Conflicting constraints. Expected one of %r. Got %r" % (
|
||||
clazz, found))
|
||||
return found[0]
|
||||
|
||||
|
||||
class ArrayConstraint(TypeConstraint):
|
||||
"""Base class for a constraint on an array's characteristics."""
|
||||
def implies_dtype(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def implies_rank(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def implies_dims(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def dims(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def implies_dim_flag(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def dim_flag(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DType(ArrayConstraint):
|
||||
"""A constraint on a dtype.
|
||||
|
||||
DType constraints are exclusive with only one permitted in a set.
|
||||
|
||||
>>> DType(np.float32)
|
||||
DType(float32)
|
||||
>>> DType("foobar")
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError
|
||||
"""
|
||||
__slots__ = ["_dtype"]
|
||||
|
||||
def __init__(self, dtype):
|
||||
super().__init__()
|
||||
assert isinstance(dtype, type)
|
||||
self._dtype = dtype
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
def implies_dtype(self):
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return "DType(%s)" % (self._dtype.__name__,)
|
||||
|
||||
|
||||
class Rank(ArrayConstraint):
|
||||
"""Establishes a fixed rank for the array.
|
||||
|
||||
>>> Rank(1)
|
||||
Rank(1)
|
||||
>>> Rank(0)
|
||||
Rank(0)
|
||||
>>> Rank(-1)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError
|
||||
>>> Rank("foobar")
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError
|
||||
|
||||
"""
|
||||
__slots__ = ["_rank"]
|
||||
|
||||
def __init__(self, rank):
|
||||
super().__init__()
|
||||
assert(isinstance(rank, int) and rank >= 0)
|
||||
self._rank = rank
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self._rank
|
||||
|
||||
def implies_rank(self):
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return "Rank(%d)" % (self._rank)
|
||||
|
||||
|
||||
class Shape(ArrayConstraint):
|
||||
"""Establishes a static shape for an array.
|
||||
|
||||
All dimensions must be a non-negative integer or Unspec.
|
||||
|
||||
>>> Shape(1, 2, 3)
|
||||
Shape(1, 2, 3)
|
||||
>>> Shape(Unspec, 1)
|
||||
Shape(Unspec, 1)
|
||||
>>> Shape()
|
||||
Shape()
|
||||
>>> Shape(-1, 1)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AssertionError
|
||||
"""
|
||||
__slots__ = ["_dims"]
|
||||
|
||||
def __init__(self, *dims):
|
||||
super().__init__()
|
||||
assert(all(d is Unspec or (isinstance(d, int) and d >= 0) for d in dims))
|
||||
self._dims = tuple(dims)
|
||||
|
||||
@property
|
||||
def dims(self):
|
||||
return self._dims
|
||||
|
||||
def implies_dims(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return len(self._dims)
|
||||
|
||||
def implies_rank(self):
|
||||
return True
|
||||
|
||||
def __repr__(self):
|
||||
return "Shape(%s)" % (", ".join(str(d) for d in self._dims))
|
||||
|
||||
|
||||
class DimFlagEnum(_LiterateEnum):
|
||||
"""Flag for the kind of DimFlag constraint."""
|
||||
Dynamic = 1
|
||||
|
||||
|
||||
class DimFlag(ArrayConstraint):
|
||||
"""Generic flag applying to one or more dimensions.
|
||||
|
||||
If dims is Unspec, the flag applies to all dims.
|
||||
|
||||
>>> DimFlag("Dynamic")
|
||||
DimFlag(Dynamic, Unspec)
|
||||
>>> DimFlag("Dynamic", 1)
|
||||
DimFlag(Dynamic, (1,))
|
||||
>>> DimFlag("Dynamic", (0, 1))
|
||||
DimFlag(Dynamic, (0, 1))
|
||||
"""
|
||||
__slots__ = ["_flag", "_dims"]
|
||||
|
||||
def __init__(self, flag, dims=Unspec):
|
||||
super().__init__()
|
||||
self._flag = DimFlagEnum.parse(flag)
|
||||
if isinstance(dims, int):
|
||||
assert(dims >= 0)
|
||||
self._dims = (dims,)
|
||||
elif dims is Unspec:
|
||||
self._dims = Unspec
|
||||
else:
|
||||
self._dims = tuple(dims)
|
||||
assert(all(isinstance(d, int) and d >= 0 for d in self._dims))
|
||||
|
||||
def implies_dim_flag(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def dim_flag(self):
|
||||
return self._flag, self._dims
|
||||
|
||||
def __repr__(self):
|
||||
return "DimFlag(%r, %r)" % (self._flag, self._dims)
|
||||
|
||||
|
||||
def DynamicDim(dims=Unspec):
|
||||
"""Dim flag that signals a dimension should be considered dynamic."""
|
||||
return DimFlag(DimFlagEnum.Dynamic, dims)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
global: PyInit_native;
|
||||
local: *;
|
||||
};
|
|
@ -0,0 +1,48 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
def run_under_filecheck(main_file, callback, disable_filecheck=False):
|
||||
"""Runs a callback under a FileCheck sub-process.
|
||||
|
||||
This is typically called from a main context and will sys.exit on
|
||||
completion.
|
||||
|
||||
Args:
|
||||
main_file: The file to process filecheck directives on. Typically
|
||||
__file__ from the caller's perspective.
|
||||
callback: The no-argument callback to invoke.
|
||||
disable_filecheck: Whether to disable filecheck.
|
||||
"""
|
||||
disable_var = "NPCOMP_DISABLE_FILECHECK"
|
||||
filecheck_binary_var = "FILECHECK_BINARY"
|
||||
if "NPCOMP_DISABLE_FILECHECK" in os.environ:
|
||||
print("WARNING:FileCheck disabled due to", disable_var,
|
||||
"in the environment", file=sys.stderr)
|
||||
disable_filecheck = True
|
||||
if disable_filecheck:
|
||||
callback()
|
||||
sys.exit(0)
|
||||
|
||||
# Redirect through FileCheck
|
||||
filecheck_capture_io = io.StringIO()
|
||||
with contextlib.redirect_stdout(filecheck_capture_io):
|
||||
callback()
|
||||
filecheck_capture_io.flush()
|
||||
filecheck_input = filecheck_capture_io.getvalue()
|
||||
filecheck_binary = "FileCheck"
|
||||
if filecheck_binary_var in os.environ:
|
||||
filecheck_binary = os.environ[filecheck_binary_var]
|
||||
print("Using FileCheck binary", filecheck_binary,
|
||||
"(customize by setting", filecheck_binary_var, ")", file=sys.stderr)
|
||||
filecheck_args = [filecheck_binary, main_file, "--dump-input=fail"]
|
||||
print("LAUNCHING FILECHECK:", filecheck_args, file=sys.stderr)
|
||||
p = subprocess.Popen(filecheck_args, stdin=subprocess.PIPE)
|
||||
p.communicate(filecheck_input.encode("UTF-8"))
|
||||
sys.exit(p.returncode)
|
|
@ -0,0 +1,53 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
TEST_MODULES = (
|
||||
"npcomp.edsc_test",
|
||||
"npcomp.tracing.context",
|
||||
"npcomp.tracing.mlir_trace",
|
||||
"npcomp.types",
|
||||
"npcomp.exporter",
|
||||
"npcomp.exp.extractor",
|
||||
)
|
||||
|
||||
# Compute PYTHONPATH for sub processes.
|
||||
DIRSEP = ":" if os.path.pathsep == "/" else ";"
|
||||
PYTHONPATH = os.path.dirname(__file__)
|
||||
if "PYTHONPATH" in os.environ:
|
||||
PYTHONPATH = PYTHONPATH + DIRSEP + os.environ["PYTHONPATH"]
|
||||
CHILD_ENVIRON = dict(os.environ)
|
||||
CHILD_ENVIRON["PYTHONPATH"] = PYTHONPATH
|
||||
|
||||
# Configure filecheck.
|
||||
FILECHECK_BINARY = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"..", "..", "..", "bin", "FileCheck"))
|
||||
if os.path.exists(FILECHECK_BINARY):
|
||||
CHILD_ENVIRON["FILECHECK_BINARY"] = FILECHECK_BINARY
|
||||
else:
|
||||
print("WARNING! Built FileCheck not found. Leaving to path resolution")
|
||||
|
||||
passed = []
|
||||
failed = []
|
||||
|
||||
for test_module in TEST_MODULES:
|
||||
print("--------====== RUNNING %s ======--------" % test_module)
|
||||
try:
|
||||
subprocess.check_call([sys.executable, "-m", test_module],
|
||||
env=CHILD_ENVIRON)
|
||||
print("--------====== DONE %s ======--------\n" % test_module)
|
||||
passed.append(test_module)
|
||||
except subprocess.CalledProcessError:
|
||||
print("!!!!!!!!====== ERROR %s ======!!!!!!!!\n" % test_module)
|
||||
failed.append(test_module)
|
||||
|
||||
print("Done: %d passed, %d failed" % (len(passed), len(failed)))
|
||||
if failed:
|
||||
for test_module in failed:
|
||||
print(" %s: FAILED" % test_module)
|
||||
sys.exit(1)
|
Loading…
Reference in New Issue