mirror of https://github.com/llvm/torch-mlir
Make torch_mlir compatible with binary PyTorch installations.
* This has been anticipated for a long time: it is quite hard to keep C++ binary compatibility across a system landscape as diverse as PyTorch, LLVM, and this project. This is why we based the PyTorch extension on the MLIR and NPCOMP C APIs only: that is the only sane linkage story for the entire matrix.
* Removes the few LLVM'isms in torch_mlir that had snuck in, using either STL or PyTorch support utilities. The new rule here is that LLVM C++ includes are forbidden at this level and (as stated in the design), torch_mlir should use the PyTorch runtime and support libraries (not introduce an incidental C++ dependency on LLVM).
* Also deletes mnist-playground, as it was proving impossible to keep the grid of PyTorch vs system ABI divisions functioning. I am open to a less drastic course here (optional/disabled by default?).
* This gets us pretty close to just using PyTorch's extension builder API, which will be nice for distribution (i.e. it integrates well with the PyTorch ecosystem for deployment). I ended up just simplifying the in-tree CMake support for now.
* Fixes #138
parent b2077738ca
commit f6d7ee06ef
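The ABI hazard motivating this change can be checked directly from Python; a quick hypothetical sanity check (not part of the commit itself) using PyTorch's own introspection:

```shell
# Ask the installed PyTorch which libstdc++ ABI it was built with.
# Official binary wheels have historically reported False (the old ABI),
# which is why extensions compiled against them need
# -D_GLIBCXX_USE_CXX11_ABI=0.
python -c "import torch; print(torch.compiled_with_cxx11_abi())"
```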
@@ -150,14 +150,7 @@ message(STATUS "Found python libraries: ${Python3_LIBRARIES}")
# Pytorch Configuration
#-------------------------------------------------------------------------------

if(NPCOMP_ENABLE_PYTORCH)
  ProbeForPyTorchInstall()
  if(NPCOMP_ENABLE_PYTORCH STREQUAL "OPTIONAL")
    find_package(Torch)
  else()
    find_package(Torch REQUIRED)
  endif()
endif()
NpcompFindPyTorch(${NPCOMP_ENABLE_PYTORCH})

#-------------------------------------------------------------------------------
# Pybind11 Configuration
@@ -186,8 +179,13 @@ add_subdirectory(include/npcomp)
add_subdirectory(lib)
add_subdirectory(python)
add_subdirectory(test)
add_subdirectory(frontends)

# Tools needs to come late to ensure that NPCOMP_ALL_LIBS is populated.
# Generally things after this point may depend on NPCOMP_ALL_LIBS or libNPCOMP.so.
add_subdirectory(tools)

if(${TORCH_FOUND})
  add_subdirectory(frontends/pytorch)
else()
  message("Skipping pytorch frontend, because PyTorch not found!")
endif()
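For reference, the `NPCOMP_ENABLE_PYTORCH` setting consumed by `NpcompFindPyTorch` accepts OFF, ON, or OPTIONAL; a hypothetical configure invocation (paths and generator are placeholders, not from this commit):

```shell
# OPTIONAL probes for PyTorch but tolerates its absence, in which case the
# frontends/pytorch subdirectory is skipped with the message shown above.
cmake -GNinja -S . -B build -DNPCOMP_ENABLE_PYTORCH=OPTIONAL
```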
@@ -94,10 +94,6 @@ git submodule update
LLVM_VERSION=10
export CC=clang-$LLVM_VERSION
export CXX=clang++-$LLVM_VERSION
# If compiling on a new OS that defaults to the CXX11 ABI (i.e. Ubuntu >= 20.04)
# and looking to use binary installs from the PyTorch website, you must build
# without the CXX11 ABI.
export CXXFLAGS="-D_GLIBCXX_USE_CXX11_ABI=0"
export LDFLAGS=-fuse-ld=$(which ld.lld-$LLVM_VERSION)

# Build and install LLVM/MLIR into the ./install-mlir directory
@@ -124,7 +120,6 @@ source .env
### PyTorch Frontend (with PyTorch installed via conda)

```shell
# See note above about -D_GLIBCXX_USE_CXX11_ABI=0
./build_tools/cmake_configure.sh
cmake --build build --target check-npcomp check-frontends-pytorch
```
@@ -1,11 +1,35 @@
function(ProbeForPyTorchInstall)
# NpcompFindPyTorch
# Calls find_package on Torch and does any needed post-processing.
# The enable_pytorch flag can be OFF, ON or OPTIONAL.
macro(NpcompFindPyTorch enable_pytorch)
  if(${enable_pytorch} OR ${enable_pytorch} STREQUAL "OPTIONAL")
    NpcompProbeForPyTorchInstall()
    if(${enable_pytorch} STREQUAL "OPTIONAL")
      find_package(Torch 1.8)
    else()
      find_package(Torch 1.8 REQUIRED)
    endif()

    if(${TORCH_FOUND})
      NpcompConfigurePyTorch()
    endif()
  else()
    message(STATUS "Not configuring PyTorch (disabled)")
  endif()
endmacro()

# NpcompProbeForPyTorchInstall
# Attempts to find a Torch installation and set the Torch_ROOT variable
# based on introspecting the python environment. This allows a subsequent
# call to find_package(Torch) to work.
function(NpcompProbeForPyTorchInstall)
  if(Torch_ROOT)
    message(STATUS "Using cached Torch root = ${Torch_ROOT}")
  else()
    message(STATUS "Checking for PyTorch using ${PYTHON_EXECUTABLE} ...")
    execute_process(
      COMMAND ${PYTHON_EXECUTABLE}
      -c "import os;import torch;print(os.path.dirname(torch.__file__), end='')"
      -c "import os;import torch;print(torch.utils.cmake_prefix_path, end='')"
      WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
      RESULT_VARIABLE PYTORCH_STATUS
      OUTPUT_VARIABLE PYTORCH_PACKAGE_DIR)
@@ -15,8 +39,42 @@ function(ProbeForPyTorchInstall)
    endif()
    message(STATUS "Found PyTorch installation at ${PYTORCH_PACKAGE_DIR}")

    # PyTorch stashes its installed .cmake files under share/cmake/Torch.
    set(Torch_ROOT "${PYTORCH_PACKAGE_DIR}/share/cmake/Torch"
        CACHE STRING "Torch package root")
    set(Torch_ROOT "${PYTORCH_PACKAGE_DIR}" CACHE STRING
        "Torch configure directory" FORCE)
  endif()
endfunction()

# NpcompConfigurePyTorch
# Performs configuration of PyTorch flags after CMake has found it to be
# present. Most of this comes down to detecting whether building against a
# source or official binary and adjusting compiler options in the latter case
# (in the former, we assume that it was built with system defaults). We do this
# conservatively and assume non-binary builds by default.
#
# In the future, we may want to switch away from custom building these
# extensions and instead rely on the Torch machinery directly (definitely want
# to do that for official builds).
function(NpcompConfigurePyTorch)
  if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
    # Linux specific libstdcpp ABI checking.
    message(STATUS "Checking if Torch is an official binary ...")
    execute_process(
      COMMAND ${PYTHON_EXECUTABLE}
      -c "from torch.utils import cpp_extension as c; import sys; sys.exit(0 if c._is_binary_build() else 1)"
      WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
      RESULT_VARIABLE _is_binary_build)
    if(${_is_binary_build} EQUAL 0)
      set(TORCH_CXXFLAGS "")
      if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU")
        set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11")
      elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
        set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=1011 '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
      else()
        message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
        return()
      endif()
      message(STATUS "Detected Torch official binary build. Setting ABI flags: ${TORCH_CXXFLAGS}")
      set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
    endif()
  endif()
endfunction()
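Both python probes these functions shell out to can be reproduced by hand when debugging a configure failure; a sketch of the equivalent commands (note that `_is_binary_build` is a private PyTorch helper the diff above calls, not a stable API):

```shell
# The prefix path a subsequent find_package(Torch) needs; this is the same
# introspection NpcompProbeForPyTorchInstall performs.
python -c "import torch; print(torch.utils.cmake_prefix_path)"

# Whether this is an official binary build; NpcompConfigurePyTorch keys its
# ABI flags off this answer.
python -c "from torch.utils import cpp_extension as c; print(c._is_binary_build())"
```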
@@ -12,7 +12,7 @@ function(npcomp_detect_pybind11_install)
  if(pybind11_DIR)
    message(STATUS "Using explicit pybind11 cmake directory: ${pybind11_DIR} (-Dpybind11_DIR to change)")
  else()
    message(CHECK_START "Checking for pybind11 in python path...")
    message(STATUS "Checking for pybind11 in python path...")
    execute_process(
      COMMAND "${Python3_EXECUTABLE}"
      -c "import pybind11;print(pybind11.get_cmake_dir(), end='')"
@@ -24,7 +24,6 @@ function(npcomp_detect_pybind11_install)
      message(CHECK_FAIL "not found (install via 'pip install pybind11' or set pybind11_DIR)")
      return()
    endif()
    message(CHECK_PASS "found (${PACKAGE_DIR})")
    set(pybind11_DIR "${PACKAGE_DIR}" PARENT_SCOPE)
  endif()
endfunction()
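When this detection fails, the remedy it suggests can be verified directly; a minimal sketch, assuming a pybind11 recent enough to ship `get_cmake_dir` (2.6 or later):

```shell
# Install pybind11 into the active python environment, then print the cmake
# directory that npcomp_detect_pybind11_install would pick up.
pip install pybind11
python -c "import pybind11; print(pybind11.get_cmake_dir())"
```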
@@ -1,5 +0,0 @@
if(${TORCH_FOUND})
  add_subdirectory(pytorch)
else()
  message("Skipping pytorch frontend, because PyTorch not found!")
endif()
@@ -1,51 +1,49 @@
add_subdirectory(builder)

include(NpcompPython)

# Sharp edge: Torch extensions need to use the same pybind11 that torch
# was compiled with, or else there will be issues in cross module exception
# handling (which will abort instead of raise). We circumvent the possibility
# by forcing the torch directories first.
include_directories(BEFORE
  ${TORCH_INCLUDE_DIRS}
  ${TORCH_INSTALL_PREFIX}/include/TH
  ${TORCH_INSTALL_PREFIX}/include/THC/opt/pytorch/pytorch
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${CMAKE_CURRENT_BINARY_DIR}
  ${Python3_INCLUDE_DIRS}
  # TODO: Fix implicit ordering. If PyTorch was build against an external
  # pybind11, then it will not be in the above search path and must be
  # resolved here, in the hope that it is the same one we were configured
  # with (which it should be if installed via pip). This is really fragile,
  # though, causing cast failures at runtime if we get it wrong. Come up with
  # a way to strengthen this.
  ${pybind11_INCLUDE_DIR}
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")

add_library(NPCOMPTorchMLIRExt SHARED
  builder/acap_dispatch.cpp
  builder/debug.cpp
  builder/func_builder.cpp
  builder/graph_importer.cpp
  builder/module_builder.cpp
  builder/python_bindings.cpp
  init_python_bindings.cpp
)

target_link_libraries(NPCOMPTorchMLIRExt
  # NPCOMP shared library.
  # TODO: Debug why order matters here (if NPCOMP is included last a large
  # amount of LLVM/MLIR/NPCOMP ends up compiled into this library).
  NPCOMP

  ${TORCH_LIBRARIES}
  ${Python3_LIBRARIES}
  torch_python
  npcomp_torch_builder_bindings

  # NPCOMP shared library.
  NPCOMP
)
add_dependencies(NPCOMPTorchMLIRExt
  # Uses of the torch_mlir extension also require the npcomp extension to
  # be built.
  NPCOMPNativePyExt
)

message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}")
set_target_properties(NPCOMPTorchMLIRExt PROPERTIES
  LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/python
  OUTPUT_NAME _torch_mlir
  PREFIX "${PYTHON_MODULE_PREFIX}"
  SUFFIX "${PYTHON_MODULE_EXTENSION}"
  CXX_VISIBILITY_PRESET "hidden"
  COMPILE_FLAGS "${TORCH_CXXFLAGS}"
)

npcomp_python_target_compile_options(NPCOMPTorchMLIRExt)
@@ -1,30 +0,0 @@
include_directories(
  ${TORCH_INCLUDE_DIRS}
  ${TORCH_INSTALL_PREFIX}/include/TH
  ${TORCH_INSTALL_PREFIX}/include/THC/opt/pytorch/pytorch
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${CMAKE_CURRENT_BINARY_DIR}
  ${Python3_INCLUDE_DIRS}
  # TODO: Fix implicit ordering. If PyTorch was build against an external
  # pybind11, then it will not be in the above search path and must be
  # resolved here, in the hope that it is the same one we were configured
  # with (which it should be if installed via pip). This is really fragile,
  # though, causing cast failures at runtime if we get it wrong. Come up with
  # a way to strengthen this.
  ${pybind11_INCLUDE_DIR}
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
add_library(npcomp_torch_builder_bindings
  acap_dispatch.cpp
  debug.cpp
  func_builder.cpp
  graph_importer.cpp
  module_builder.cpp
  python_bindings.cpp
)

target_link_libraries(npcomp_torch_builder_bindings
  ${TORCH_LIBRARIES}
  ${Python3_LIBRARIES}
  torch_python
)
@@ -46,7 +46,7 @@ static c10::DispatchKey kAcapGradDispatchKey =
AcapController::TracedKernelCallBuilder::TracedKernelCallBuilder(
    AcapController &parent, MlirContext context, MlirLocation loc,
    const c10::OperatorHandle &opHandle,
    llvm::Optional<std::string> overrideKernelName)
    c10::optional<std::string> overrideKernelName)
    : KernelCallBuilder(context, loc,
                        overrideKernelName ? *overrideKernelName
                                           : opHandle.operator_name().name,
@@ -125,8 +125,8 @@ void AcapController::contextExit(py::object exc_type, py::object exc_val,
void AcapController::returns(std::vector<at::Tensor> tensors) {
  verifyHasNotReturned();

  llvm::SmallVector<MlirType, 4> returnsTypes;
  llvm::SmallVector<MlirValue, 4> returnsValues;
  std::vector<MlirType> returnsTypes;
  std::vector<MlirValue> returnsValues;
  for (auto &tensor : tensors) {
    MlirValue v = funcBuilder->lookupTensor(tensor);
    if (mlirValueIsNull(v)) {
@@ -470,7 +470,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc,
  }
  if (ival.isList()) {
    auto list = ival.toList();
    llvm::SmallVector<MlirValue, 4> elements;
    std::vector<MlirValue> elements;
    for (IValue element : list) {
      elements.push_back(mapIValueToMlirValue(loc, element));
    }
@@ -557,8 +557,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
  } else {
    elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type());
  }
  llvm::SmallVector<int64_t, 4> shape(tensor.sizes().begin(),
                                      tensor.sizes().end());
  std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
  MlirType shapedType = mlirRankedTensorTypeGetChecked(
      shape.size(), shape.data(), elementType, loc);
  if (mlirTypeIsNull(shapedType)) {
@@ -85,7 +85,7 @@ private:
  TracedKernelCallBuilder(
      AcapController &parent, MlirContext context, MlirLocation loc,
      const c10::OperatorHandle &opHandle,
      llvm::Optional<std::string> overrideKernelName = llvm::None);
      c10::optional<std::string> overrideKernelName = c10::nullopt);
  void addOperand(const c10::IValue &value);
  void addResult(const c10::IValue &result);
  MlirOperation create();
@@ -94,7 +94,7 @@ private:
  AcapController &parent;
  const c10::OperatorHandle &opHandle;
  int resultCount = 0;
  llvm::SmallVector<std::pair<size_t, at::Tensor>, 4> resultIndexToTensorMap;
  std::vector<std::pair<size_t, at::Tensor>> resultIndexToTensorMap;
};

MlirLocation getCurrentLocation();
@@ -25,10 +25,10 @@ static MlirOperation createStandardConstant(MlirLocation loc, MlirType type,
}

KernelCallBuilder::KernelCallBuilder(MlirContext context, MlirLocation loc,
                                     llvm::StringRef kernelName,
                                     const std::string &kernelName,
                                     const c10::FunctionSchema &schema)
    : context(context), loc(loc), state("torch.kernel_call", loc),
      kernelName(kernelName), schema(schema) {
      schema(schema) {
  (void)this->context; // Preserve for future.
  MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet(
      toMlirStringRef("kernelName"),
@@ -45,7 +45,7 @@ void KernelCallBuilder::addSchemaAttrs() {
  // sigIsVararg
  // sigIsVarret
  // sigIsMutable
  llvm::SmallVector<MlirNamedAttribute, 8> attrs;
  std::vector<MlirNamedAttribute> attrs;
  attrs.push_back(
      mlirNamedAttributeGet(toMlirStringRef("sigIsMutable"),
                            mlirBoolAttrGet(context, schema.is_mutable())));
@@ -57,7 +57,7 @@ void KernelCallBuilder::addSchemaAttrs() {
                            mlirBoolAttrGet(context, schema.is_varret())));

  // Arg types.
  llvm::SmallVector<MlirAttribute, 4> args;
  std::vector<MlirAttribute> args;
  for (auto &arg : schema.arguments()) {
    const std::string &typeStr = arg.type()->str();
    args.push_back(mlirStringAttrGet(
@@ -68,7 +68,7 @@ void KernelCallBuilder::addSchemaAttrs() {
      mlirArrayAttrGet(context, args.size(), args.data())));

  // Return types.
  llvm::SmallVector<MlirAttribute, 4> returns;
  std::vector<MlirAttribute> returns;
  for (auto &ret : schema.returns()) {
    const std::string &typeStr = ret.type()->str();
    returns.push_back(mlirStringAttrGet(
@@ -170,7 +170,7 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
  }
  // Ranked with possibly dynamic dims.
  auto &symbolicShape = tensorType->symbolic_sizes();
  llvm::SmallVector<int64_t, 4> dims;
  std::vector<int64_t> dims;
  dims.resize(*sizes.rank());
  for (size_t i = 0; i < dims.size(); ++i) {
    auto shapeSymbol = symbolicShape[i];
@@ -204,12 +204,12 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {

std::unique_ptr<FuncBuilder>
FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
                            MlirLocation location, llvm::StringRef name,
                            llvm::SmallVectorImpl<MlirType> &inputTypes) {
                            MlirLocation location, const std::string &name,
                            std::vector<MlirType> &inputTypes) {
  auto context = mlirLocationGetContext(location);
  // TODO: Create a dedicated API upstream for creating/manipulating func ops.
  // (this is fragile and reveals details that are not guaranteed).
  llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs;
  std::vector<MlirNamedAttribute> funcAttrs;
  funcAttrs.push_back(
      mlirNamedAttributeGet(toMlirStringRef("type"),
                            mlirTypeAttrGet(mlirFunctionTypeGet(
@@ -242,8 +242,7 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter,
      context, funcOp, BlockBuilder(entryBlock, /*returnOp=*/{nullptr}, true)));
}

void FuncBuilder::rewriteFuncReturnTypes(
    llvm::SmallVectorImpl<MlirType> &resultTypes) {
void FuncBuilder::rewriteFuncReturnTypes(std::vector<MlirType> &resultTypes) {
  // Get inputs from current function type.
  MlirAttribute funcTypeAttr =
      mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type"));
@@ -252,7 +251,7 @@ void FuncBuilder::rewriteFuncReturnTypes(
  assert(mlirAttributeIsAType(funcTypeAttr) &&
         "function type is not a TypeAttr");
  MlirType funcType = mlirTypeAttrGetValue(funcTypeAttr);
  llvm::SmallVector<MlirType, 4> inputTypes;
  std::vector<MlirType> inputTypes;
  for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(funcType); i < e; ++i) {
    inputTypes.push_back(mlirFunctionTypeGetInput(funcType, i));
  }
@@ -328,7 +327,7 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
}

MlirValue FuncBuilder::buildList(MlirLocation loc,
                                 llvm::SmallVectorImpl<MlirValue> &elements) {
                                 std::vector<MlirValue> &elements) {
  MlirType resultType = npcompListTypeGet(context);
  OperationStateHolder state{"basicpy.build_list", loc};
  mlirOperationStateAddResults(state, 1, &resultType);
@@ -11,8 +11,6 @@
#include "mlir_utils.h"

#include "mlir-c/IR.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"

#include <ATen/Tensor.h>
#include <ATen/core/function_schema.h>
@@ -102,7 +100,7 @@ private:
class KernelCallBuilder {
public:
  KernelCallBuilder(MlirContext context, MlirLocation loc,
                    llvm::StringRef kernelName,
                    const std::string &kernelName,
                    const c10::FunctionSchema &schema);
  void addOperand(MlirValue operand);
  void addResultType(MlirType resultType);
@@ -115,7 +113,6 @@ protected:
private:
  void addSchemaAttrs();
  OperationStateHolder state;
  llvm::StringRef kernelName;
  const c10::FunctionSchema &schema;
};
@@ -131,8 +128,7 @@ public:
  /// to a parent.
  static std::unique_ptr<FuncBuilder>
  createFunction(Inserter &inserter, MlirLocation location,
                 llvm::StringRef name,
                 llvm::SmallVectorImpl<MlirType> &inputTypes);
                 const std::string &name, std::vector<MlirType> &inputTypes);

  MlirContext getContext() { return context; }
  MlirOperation getFuncOp() { return funcOp; }
@@ -143,7 +139,7 @@ public:

  /// Rewrites the function's signature to return the given types. It is
  /// assumed that a compatible terminator has been added.
  void rewriteFuncReturnTypes(llvm::SmallVectorImpl<MlirType> &resultTypes);
  void rewriteFuncReturnTypes(std::vector<MlirType> &resultTypes);

  /// Maps a live Tensor to an MlirValue.
  void mapTensor(at::Tensor tensor, MlirValue value) {
@@ -168,8 +164,7 @@ public:
  MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value);

  /// Builds a list with the given elements
  MlirValue buildList(MlirLocation loc,
                      llvm::SmallVectorImpl<MlirValue> &elements);
  MlirValue buildList(MlirLocation loc, std::vector<MlirValue> &elements);

private:
  FuncBuilder(MlirContext context, MlirOperation funcOp,
@@ -200,7 +195,7 @@ private:
  /// that tensors may be mapped and accessed in proximity.
  /// TODO: Tensors referenced via an IValue support hash code lookup and
  /// identity checks. Switch to this instead of a linear scan.
  llvm::SmallVector<std::pair<at::Tensor, MlirValue>, 16> tensorValueMap;
  std::vector<std::pair<at::Tensor, MlirValue>> tensorValueMap;
};

} // namespace torch_mlir
@@ -7,6 +7,8 @@

#include "graph_importer.h"

#include <unordered_map>

#include "mlir_utils.h"

#include "mlir-c/BuiltinAttributes.h"
@@ -33,7 +35,7 @@ public:
  MlirValue findRequiredValue(MlirLocation loc, torch::jit::Value *torchValue);

private:
  llvm::DenseMap<torch::jit::Value *, MlirValue> valueMap;
  std::unordered_map<torch::jit::Value *, MlirValue> valueMap;
  NodeScope *prev = nullptr;
};
@@ -134,8 +136,8 @@ void GraphImporter::NodeImporter::importNode() {
  mlirBlockInsertOwnedOperationBefore(block, ip, op);

  // Map results.
  for (auto it : llvm::enumerate(node->outputs())) {
    scope->bindValue(it.value(), mlirOperationGetResult(op, it.index()));
  for (size_t i = 0; i < node->outputs().size(); ++i) {
    scope->bindValue(node->outputs()[i], mlirOperationGetResult(op, i));
  }
  return;
}
@@ -151,7 +153,7 @@ void GraphImporter::NodeImporter::importNode() {

void GraphImporter::NodeImporter::importReturnOp() {
  OperationStateHolder s("std.return", loc);
  llvm::SmallVector<MlirValue, 4> returnsValues;
  std::vector<MlirValue> returnsValues;
  for (auto *input : node->inputs()) {
    returnsValues.push_back(scope->findRequiredValue(loc, input));
  }
@@ -271,9 +273,9 @@ void GraphImporter::importGenericFunc() {

  // Bind inputs.
  NodeScope scope;
  for (const auto &it : llvm::enumerate(graph->inputs())) {
    MlirValue value = mlirBlockGetArgument(entryBlock, it.index());
    scope.bindValue(it.value(), value);
  for (size_t i = 0; i < graph->inputs().size(); ++i) {
    MlirValue value = mlirBlockGetArgument(entryBlock, i);
    scope.bindValue(graph->inputs()[i], value);
  }

  // Walk body nodes.
@@ -41,8 +41,8 @@ public:
  /// when to import globals as constants vs shared arrays, etc.
  struct MlirMappingOptions {
    MlirContext context;
    llvm::Optional<std::string> genericFuncName;
    llvm::Optional<std::string> funcName;
    c10::optional<std::string> genericFuncName;
    c10::optional<std::string> funcName;
    TypeMapper &typeMapper;
    FuncBuilder::Inserter &inserter;
  };
@@ -82,8 +82,8 @@ private:
  MlirLocation defaultLoc;

  /// Argument and return types for the generic func.
  llvm::SmallVector<MlirType, 4> genericFuncArgTypes;
  llvm::SmallVector<MlirType, 4> genericFuncReturnTypes;
  std::vector<MlirType> genericFuncArgTypes;
  std::vector<MlirType> genericFuncReturnTypes;
};

} // namespace torch_mlir
@@ -74,7 +74,7 @@ std::shared_ptr<AcapController>
ModuleBuilder::startCaptureFunction(std::string &name,
                                    std::vector<at::Tensor> args) {
  // TODO: Verify that arguments do not alias each other.
  llvm::SmallVector<MlirType, 4> inputTypes;
  std::vector<MlirType> inputTypes;
  for (auto &arg : args) {
    inputTypes.push_back(typeMapper.forwardTensorToType(arg));
  }
@@ -88,9 +88,8 @@ ModuleBuilder::startCaptureFunction(std::string &name,
  assert(mlirBlockGetNumArguments(entryBlock) ==
             static_cast<intptr_t>(args.size()) &&
         "entry block incorrect arg arity");
  for (auto it : llvm::enumerate(args)) {
    funcBuilder->mapTensor(it.value(),
                           mlirBlockGetArgument(entryBlock, it.index()));
  for (size_t i = 0; i < args.size(); ++i) {
    funcBuilder->mapTensor(args[i], mlirBlockGetArgument(entryBlock, i));
  }
  return std::make_shared<AcapController>(typeMapper, std::move(funcBuilder));
}
@@ -100,10 +99,9 @@ ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) {
  auto inserter = createInserter();
  GraphImporter::MlirMappingOptions mappingOptions{
      context,
      llvm::None, // genericFuncName (default to auto)
      llvm::None, // funcName (default to auto)
      typeMapper, inserter,
  };
      c10::nullopt, // genericFuncName (default to auto)
      c10::nullopt, // funcName (default to auto)
      typeMapper, inserter};
  auto graphImporter = GraphImporter::forPythonJitFunc(
      function.function_, std::move(mappingOptions));
  graphImporter->initialize();
@@ -13,7 +13,6 @@
#include "acap_dispatch.h"

#include "mlir-c/IR.h"
#include "llvm/ADT/SmallVector.h"

#include <ATen/Tensor.h>
#include <torch/csrc/jit/api/compilation_unit.h>
@@ -12,137 +12,15 @@
// b) Direct IR translation from PyTorch Graphs (not implemented).
// c) Using the PyTorch JIT facility (not implemented).

#include "llvm/Support/Debug.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
#include "npcomp/Dialect/ATen/Transforms/LivenessReport.h"
#include "npcomp/Dialect/ATen/Transforms/Passes.h"

#include "init_python_bindings.h"

#include <string>

namespace py = pybind11;
using namespace mlir;

namespace llvm {
extern bool DebugFlag;
}

namespace torch_mlir {
namespace {

mlir::OwningModuleRef LoadModule(mlir::MLIRContext &context, std::string mlir) {

  mlir::OwningModuleRef module;

  std::unique_ptr<llvm::MemoryBuffer> membuf =
      llvm::MemoryBuffer::getMemBuffer(mlir);

  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(membuf), llvm::SMLoc());
  module = mlir::parseSourceFile(sourceMgr, &context);

  if (!module) {
    llvm::errs() << "Error can't parse mlir module\n";
    return nullptr;
  }
  if (failed(mlir::verify(*module))) {
    llvm::errs() << "Error verifying MLIR module\n";
    return nullptr;
  }
  if (!module)
    return nullptr;
  return module;
}

void InitModuleBindings(py::module &m) {
  m.def(
      "_op_report",
      [](std::string mlir) -> std::string {
        mlir::MLIRContext context;
        auto module = LoadModule(context, mlir);
        mlir::PassManager pm(module->getContext());

        // our pass
        std::string report;
        pm.addPass(mlir::NPCOMP::aten::createATenLayerNamePass());
        pm.addPass(mlir::NPCOMP::aten::createATenOpReportPass(report));

        if (failed(pm.run(*module))) {
          llvm::errs() << "ATenOpReportPass failed";
          return "<error>";
        }
        return report;
      },
      "run ATenOpReportPass");

  m.def(
      "_liveness_report",
      [](std::string mlir) -> std::string {
        mlir::MLIRContext context;
        auto module = LoadModule(context, mlir);

        mlir::PassManager pm(module->getContext());

        pm.addPass(mlir::NPCOMP::aten::createATenLayerNamePass());
        if (failed(pm.run(*module))) {
          llvm::errs() << "ATen generate liveness report failed";
          return "<error>";
        }

        auto mOp = module.get();
        auto liveness = mlir::NPCOMP::aten::LivenessReport(mOp);
        std::string report = liveness.emitJSONReport();
        return report;
      },
      "generate liveness report");

  // TODO: Could this be implemented with MLIR python bindings?
  m.def(
      "lower_to_std",
      [](std::string mlir) -> std::string {
        mlir::MLIRContext context;
        auto module = LoadModule(context, mlir);

        PassManager pm0(module->getContext());
        pm0.addPass(mlir::NPCOMP::aten::createATenLoweringPass());
        pm0.addPass(mlir::NPCOMP::aten::createReturnEliminationPass());
        pm0.addPass(mlir::createCSEPass());

        if (failed(pm0.run(*module))) {
          llvm::errs() << "aten to loops conversion failed ";
          return "";
        }

        // dump MLIR to string and return
        std::string s;
        llvm::raw_string_ostream ss(s);
        ss << "# Lowered to Std\n";
        module->print(ss);
        return ss.str();
      },
      "lower aten to std dialect");

  m.def(
      "set_debug",
      [](bool b, std::string type) -> void {
        llvm::setCurrentDebugType(type.c_str());
        llvm::DebugFlag = b;
      },
      "enable/disable debug messages");
}
void InitModuleBindings(py::module &m) {}

} // namespace
@@ -17,14 +17,13 @@ mb = torch_mlir.ModuleBuilder()
@mb.import_function
@torch.jit.script
def add3(t0, t1, t2):
  # CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 2]]
  # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
  # TODO: Checks for debug info are quite hard with the new trailing debug
  # attribute print. See if this can be improved.
  # CHECK: loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
  intermediate = t0 + t1
  # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
  # CHECK: loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]]
  final = intermediate + t2
  # CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 3]]
  return final
# CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 5]]

# Verify again with debug info present. Just checking that it makes it in there.
mb.module.operation.print(enable_debug_info=True)
@@ -1,6 +1,3 @@
add_subdirectory(npcomp-opt)
if(${TORCH_FOUND})
  add_subdirectory(mnist-playground)
endif()
add_subdirectory(npcomp-run-mlir)
add_subdirectory(npcomp-shlib)
@@ -1,48 +0,0 @@
# TODO: This is copied from frontends/pytorch/csrc/c10_dispatch/CMakeLists.txt
# What is the idiomatic way of sharing this in CMake?
include_directories(
  ${TORCH_INCLUDE_DIRS}
  ${TORCH_INSTALL_PREFIX}/include/TH
  ${TORCH_INSTALL_PREFIX}/include/THC/opt/pytorch/pytorch
  ${CMAKE_CURRENT_SOURCE_DIR}
  ${CMAKE_CURRENT_BINARY_DIR}
  ${Python3_INCLUDE_DIRS}
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")


set(LLVM_LINK_COMPONENTS
  Core
  Support
  nativecodegen
)

add_npcomp_executable(mnist-playground
  mnist-playground.cpp
)
llvm_update_compile_flags(mnist-playground)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
target_link_libraries(mnist-playground PRIVATE
  # Shared library deps first ensure we get most of what we need from libraries.
  NPCOMP
  MLIR

  MLIRAnalysis
  MLIREDSC
  MLIRExecutionEngine
  MLIRIR
  MLIRJitRunner
  MLIRLLVMIR
  MLIRParser
  MLIRTargetLLVMIR
  MLIRSupport
  NPCOMPInitAll
  NPCOMPRefBackendJITHelpers
  ${conversion_libs}
  ${dialect_libs}
  ${TORCH_LIBRARIES}
)
add_dependencies(mnist-playground
  NPCOMPCompilerRuntimeShlib
)
@@ -1,42 +0,0 @@
# mnist-playground

This is intended to be a short-lived "playground" for doing various experiments, guided by a real model use case, for improving the npcomp reference backend.

It's expected that utilities developed here will graduate to a more general utility or that this utility will be obsoleted by Python-driven flows once those come online.

## Goals:

- Obtain a performance-grounded analysis of the TCF/TCP design + reference backend design, and improve the designs.

- Make forward progress on TCF/TCP + reference backend while the PyTorch frontend is being brought up.

## Rough sketch of how we intend to get there:

1. Link against PyTorch, and write a simple routine to do inference on a simple FC MNIST.

2. Write a similar routine in TCF, extending TCF and the reference backend as needed for functional completeness. The PyTorch code serves as a numerical correctness reference.

3. Run and profile the reference backend and obtain a set of action items for design improvements, both to performance and stability. The PyTorch code serves as a performance baseline.

4. Implement important action items on a priority basis, and document remaining major design issues that don't make sense to address at this time, along with a justification for why the current design doesn't prevent us from eventually solving them. Iterate the previous step and this one as makes sense.

5. (Stretch) Add support for convolutional MNIST and/or training.

## Current Status

Step 1. DONE

Step 2. MOSTLY DONE. Still need to improve the op set to make the FC MNIST more complete. In particular, implementing functionality for reshape and softmax.

Step 3. STARTING. Initial performance on 10x784x100 (10 FC features, batch 100) is 66x off from PyTorch. No profiling done yet.

Example command line (the .mlir file and `-invoke` are similar to npcomp-run-mlir):

```
$ mnist-playground tools/mnist-playground/fc.mlir -invoke fc
PyTorch: numRuns: 16384 nsPerRun: 3.947563e+05
RefBackend: numRuns: 256 nsPerRun: 2.471073e+07
Ratio (RefBackend / PyTorch): 62.5974
```

There is currently a fragile dependency between hardcoded `at::` function calls in the .cpp file and the TCF code in the `.mlir` file. A correctness check is done to make sure they agree. Once we have a PyTorch frontend and/or an ATen roundtrip backend online, we can avoid this fragility.
@@ -1,15 +0,0 @@

func @fc(
  // TODO: Implement "reshape" so that %image can be passed as batch of 2D tensors.
  %image: tensor<?x?xf32>,
  %weights: tensor<?x?xf32>,
  %biases: tensor<?x?xf32>)
    -> (
  tensor<?x?xf32>
) {
  %0 = tcf.matmul %weights, %image : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = tcf.add %0, %biases : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // TODO: Implement softmax for classification.
  // For now, this returns a not-terribly useful number.
  return %1 : tensor<?x?xf32>
}
@@ -1,298 +0,0 @@
//===- mnist-playground.cpp -------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "npcomp/InitAll.h"
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"

#include <torch/torch.h>

#include <chrono>

using namespace mlir;
using llvm::Error;
using llvm::ErrorOr;
using llvm::Expected;
using llvm::StringError;
using llvm::Twine;

//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//

/// Wrap a string into an llvm::StringError.
static Error make_string_error(const Twine &message) {
  return llvm::make_error<StringError>(message.str(),
                                       llvm::inconvertibleErrorCode());
}

Expected<std::unique_ptr<refback::JITModule>>
createJITModule(std::string mlirFile, mlir::DialectRegistry &registry,
                ArrayRef<StringRef> sharedLibs, bool optimize) {
  MLIRContext context;
  registry.loadAll(&context);
  OwningModuleRef moduleRef = parseSourceFile(mlirFile, &context);
  if (!moduleRef)
    return make_string_error(Twine("could not open ") + mlirFile);

  ModuleOp module = *moduleRef;

  // Compile.
  PassManager pm(module.getContext(), OpPassManager::Nesting::Implicit);
  applyPassManagerCLOptions(pm);
  refback::JITModule::buildBackendCompilationPipeline(pm, optimize);
  if (failed(pm.run(module)))
    return make_string_error(Twine("error compiling to jit backend"));

  return refback::JITModule::fromCompiledModule(module, sharedLibs);
}

//===----------------------------------------------------------------------===//
// Benchmarking / correctness-testing code.
//===----------------------------------------------------------------------===//

static Expected<std::vector<at::Tensor>>
invokeJITModuleWithATenTensors(refback::JITModule &jitModule,
                               StringRef invokeFunction,
                               std::vector<at::Tensor> &args) {

  // Do a bit of checking. We don't handle all possible tensors right now.
  std::vector<at::TensorArg> tensorArgs;
  for (auto arg : llvm::enumerate(args))
    tensorArgs.push_back(at::TensorArg(arg.value(), "arg", arg.index()));
  at::CheckedFrom c = "converting to refbackrt::Tensor";
  for (auto &tensorArg : tensorArgs)
    at::checkScalarType(c, tensorArg, at::ScalarType::Float);
  at::checkAllContiguous(c, tensorArgs);

  SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> refbackInputs;
  for (at::Tensor arg : args) {
    SmallVector<int32_t, 6> extents(arg.sizes().begin(), arg.sizes().end());
    float *data = arg.storage().data<float>();
    // This does a deep copy of the data. Let's see if it shows up on the
    // profile.
    refbackInputs.push_back(refbackrt::Tensor::create(
        refbackrt::ArrayRef<int32_t>(extents.data(), extents.size()),
        refbackrt::ElementType::F32, data));
  }

  // Invoke the RefBackend function.
  auto expectedOutputs = jitModule.invoke(invokeFunction, refbackInputs);
  if (!expectedOutputs)
    return expectedOutputs.takeError();
  auto refbackrtOutputs = std::move(*expectedOutputs);

  std::vector<at::Tensor> results;
  for (auto output : refbackrtOutputs) {
    std::vector<int64_t> sizes(output->getExtents().data(),
                               output->getExtents().data() +
                                   output->getExtents().size());
    // Make a copy for passing to at::from_blob, which does its own internal
    // reference counting.
    auto *dataCopy = std::malloc(output->getDataByteSize());
    std::memcpy(dataCopy, output->getData(), output->getDataByteSize());
    results.push_back(at::from_blob(
        dataCopy, sizes, [](void *p) { std::free(p); }, at::kFloat));
  }
  return results;
}

using InvocationFunction =
    std::function<Expected<std::vector<at::Tensor>>(std::vector<at::Tensor>)>;

struct BenchmarkResult {
  int numRuns;
  float nsPerRun;
};

std::ostream &operator<<(std::ostream &os, const BenchmarkResult &result) {
  os << "numRuns: " << result.numRuns << " nsPerRun: " << std::scientific
     << result.nsPerRun << std::defaultfloat;
  return os;
}

Expected<BenchmarkResult> benchmark(std::function<Error()> f) {
  for (int itersAtATime = 1;; itersAtATime *= 2) {
    auto start = std::chrono::steady_clock::now();
    for (int i = 0; i < itersAtATime; i++) {
      auto error = f();
      if (error)
        return std::move(error);
    }
    auto end = std::chrono::steady_clock::now();
    std::chrono::duration<float> elapsed = end - start;

    // If the runtime is longer than 0.5 seconds, it's reliable enough.
    if (elapsed.count() > 0.5f) {
      BenchmarkResult result;
      result.numRuns = itersAtATime;
      result.nsPerRun = elapsed.count() * 10e9 / itersAtATime;
      return result;
    }
  }
  return make_string_error("too short running to benchmark!");
}

static Error doIt(InvocationFunction ptFunc, InvocationFunction refBackendFunc,
                  bool doBenchmark, int numCorrectnessTests) {

  torch::manual_seed(42);
  torch::set_num_threads(1);

  std::vector<at::Tensor> args;
  args.push_back(at::rand({784, 100}));
  args.push_back(at::rand({10, 784}));
  args.push_back(at::rand({10, 1}));

  // Initial correctness check of the two functions.
  for (int correctnessTest = 0; correctnessTest < numCorrectnessTests;
       correctnessTest++) {
    auto expectedPt = ptFunc(args);
    auto expectedRefBackend = refBackendFunc(args);
    if (!expectedPt)
      return expectedPt.takeError();
    if (!expectedRefBackend)
      return expectedRefBackend.takeError();
    auto pt = std::move(*expectedPt);
    auto refBackend = std::move(*expectedRefBackend);
    if (pt.size() != refBackend.size())
      return make_string_error("mismatch in result arity!");
    for (int i = 0, e = pt.size(); i < e; i++) {
      if (!at::allclose(pt[i], refBackend[i])) {
        std::cout << "PyTorch:\n" << pt[i] << "\n";
        std::cout << "RefBackend:\n" << refBackend[i] << "\n";
        return make_string_error(Twine("mismatch in result contents ") +
                                 Twine(i) + Twine(" on correctness test #") +
                                 Twine(correctnessTest));
      }
    }
  }
  if (!doBenchmark)
    return Error::success();

  // Benchmark the two against each other.
  BenchmarkResult ptBenchmarkResult;
  BenchmarkResult refBackendBenchmarkResult;
  {
    auto expectedResult =
        benchmark([&]() -> Error { return ptFunc(args).takeError(); });
    if (!expectedResult)
      return expectedResult.takeError();
    ptBenchmarkResult = std::move(*expectedResult);
  }

  {
    auto expectedResult =
        benchmark([&]() -> Error { return refBackendFunc(args).takeError(); });
    if (!expectedResult)
      return expectedResult.takeError();
    refBackendBenchmarkResult = std::move(*expectedResult);
  }
  std::cout << "PyTorch: " << ptBenchmarkResult << "\n";
  std::cout << "RefBackend: " << refBackendBenchmarkResult << "\n";
  std::cout << "Ratio (RefBackend / PyTorch): "
            << refBackendBenchmarkResult.nsPerRun / ptBenchmarkResult.nsPerRun
            << "\n";

  // TODO: Check for memory leaks?

  return Error::success();
}

//===----------------------------------------------------------------------===//
// Main-related init and option parsing.
//===----------------------------------------------------------------------===//

namespace {
namespace cl = llvm::cl;
struct Options {
  cl::opt<std::string> inputFile{
      cl::Positional, cl::desc("the input .mlir file"), cl::init("-")};
  cl::opt<std::string> invokeFunction{"invoke", cl::Required,
                                      cl::desc("function to invoke")};

  cl::list<std::string> sharedLibs{"shared-libs", cl::ZeroOrMore,
                                   cl::MiscFlags::CommaSeparated,
                                   cl::desc("Libraries to link dynamically")};
  cl::opt<bool> optimize{
      "optimize", cl::Optional,
      cl::desc("whether the refback pass pipeline should run optimizations"),
      cl::init(false)};

  cl::opt<bool> benchmark{"benchmark", cl::Optional,
                          cl::desc("whether to do a benchmark comparison"),
                          cl::init(true)};

  cl::opt<uint32_t> numCorrectnessTests{
      "num-correctness-tests", cl::Optional,
      cl::desc("how many correctness tests to run (useful for nondeterministic "
               "correctness failures"),
      cl::init(1)};
};
} // namespace

int main(int argc, char **argv) {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  mlir::registerAllPasses();
  mlir::NPCOMP::registerAllDialects(registry);
  mlir::NPCOMP::registerAllPasses();

  llvm::InitLLVM y(argc, argv);
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
  mlir::initializeLLVMPasses();

  mlir::registerAsmPrinterCLOptions();
  mlir::registerPassManagerCLOptions();

  Options options;
  llvm::cl::ParseCommandLineOptions(argc, argv, "mnist playground utility\n");

  SmallVector<StringRef, 6> sharedLibs(options.sharedLibs.begin(),
                                       options.sharedLibs.end());
  auto expectedJITModule = createJITModule(options.inputFile, registry,
                                           sharedLibs, options.optimize);
  if (Error error = expectedJITModule.takeError())
    llvm::report_fatal_error(llvm::toString(std::move(error)),
                             /*gen_crash_diag=*/false);
  auto jitModule = std::move(*expectedJITModule);

  Error error = doIt(
      [](std::vector<at::Tensor> args) {
        auto image = args[0];
        auto weights = args[1];
        auto biases = args[2];
        auto v0 = at::matmul(weights, image);
        auto v1 = at::add(v0, biases);
        return std::vector<at::Tensor>{v1};
      },
      [&](std::vector<at::Tensor> args) {
        return invokeJITModuleWithATenTensors(*jitModule,
                                              options.invokeFunction, args);
      },
      options.benchmark, options.numCorrectnessTests);

  int exitCode = EXIT_SUCCESS;
  llvm::handleAllErrors(std::move(error),
                        [&exitCode](const llvm::ErrorInfoBase &info) {
                          llvm::errs() << "Error: ";
                          info.log(llvm::errs());
                          llvm::errs() << '\n';
                          exitCode = EXIT_FAILURE;
                        });
  return exitCode;
}