From f6d7ee06ef22d42ace603df6442d1c13c83a90a5 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Mon, 14 Dec 2020 08:42:42 -0800
Subject: [PATCH] Make torch_mlir compatible with binary PyTorch installations.

* This has been anticipated for a long time: it is quite hard to keep C++
  binary compatibility across a system landscape as diverse as PyTorch, LLVM,
  and this project. This is why we based the PyTorch extension on the MLIR and
  NPCOMP C APIs only: that is the only sane linkage story for the entire
  matrix.
* Removes the few LLVM-isms in torch_mlir that had snuck in, replacing them
  with either STL or PyTorch support utilities. The new rule here is that LLVM
  C++ includes are forbidden at this level: as stated in the design, torch_mlir
  should use the PyTorch runtime and support libraries rather than introduce an
  incidental C++ dependency on LLVM.
* Also deletes mnist-playground, as it was proving impossible to keep the grid
  of PyTorch vs. system ABI combinations working. I am open to a less drastic
  course here (optional/disabled by default?).
* This gets us pretty close to just using PyTorch's extension builder API,
  which will be nice for distribution (i.e. it integrates well with the PyTorch
  ecosystem for deployment). I ended up just simplifying the in-tree CMake
  support for now.
* Fixes #138
---
 CMakeLists.txt                                |  16 +-
 README.md                                     |   5 -
 cmake/modules/ConfigurePyTorch.cmake          |  68 +++-
 cmake/modules/NpcompDetectPythonEnv.cmake     |   3 +-
 frontends/CMakeLists.txt                      |   5 -
 frontends/pytorch/csrc/CMakeLists.txt         |  32 +-
 frontends/pytorch/csrc/builder/CMakeLists.txt |  30 --
 .../pytorch/csrc/builder/acap_dispatch.cpp    |  11 +-
 .../pytorch/csrc/builder/acap_dispatch.h      |   4 +-
 .../pytorch/csrc/builder/func_builder.cpp     |  25 +-
 frontends/pytorch/csrc/builder/func_builder.h |  15 +-
 .../pytorch/csrc/builder/graph_importer.cpp   |  16 +-
 .../pytorch/csrc/builder/graph_importer.h     |   8 +-
 .../pytorch/csrc/builder/module_builder.cpp   |  14 +-
 .../pytorch/csrc/builder/module_builder.h     |   1 -
 .../pytorch/csrc/init_python_bindings.cpp     | 124 +-------
 .../graph_export/test_script_debug_info.py    |   9 +-
 tools/CMakeLists.txt                          |   3 -
 tools/mnist-playground/CMakeLists.txt         |  48 ---
 tools/mnist-playground/README.md              |  42 ---
 tools/mnist-playground/fc.mlir                |  15 -
 tools/mnist-playground/mnist-playground.cpp   | 298 ------------------
 22 files changed, 134 insertions(+), 658 deletions(-)
 delete mode 100644 frontends/CMakeLists.txt
 delete mode 100644 frontends/pytorch/csrc/builder/CMakeLists.txt
 delete mode 100644 tools/mnist-playground/CMakeLists.txt
 delete mode 100644 tools/mnist-playground/README.md
 delete mode 100644 tools/mnist-playground/fc.mlir
 delete mode 100644 tools/mnist-playground/mnist-playground.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7505fc22e..558d4a8e2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -150,14 +150,7 @@ message(STATUS "Found python libraries: ${Python3_LIBRARIES}")
 # Pytorch Configuration
 #-------------------------------------------------------------------------------
 
-if(NPCOMP_ENABLE_PYTORCH)
-  ProbeForPyTorchInstall()
-  if(NPCOMP_ENABLE_PYTORCH STREQUAL "OPTIONAL")
-    find_package(Torch)
-  else()
-    find_package(Torch REQUIRED)
-  endif()
-endif()
+NpcompFindPyTorch(${NPCOMP_ENABLE_PYTORCH})
 
 #-------------------------------------------------------------------------------
 # Pybind11 Configuration
@@ -186,8 +179,13 @@ add_subdirectory(include/npcomp)
 add_subdirectory(lib)
 add_subdirectory(python)
 add_subdirectory(test)
-add_subdirectory(frontends)
 
 # Tools needs to come late to ensure that NPCOMP_ALL_LIBS is populated.
 # Generally things after this point may depend on NPCOMP_ALL_LIBS or libNPCOMP.so.
 add_subdirectory(tools)
+
+if(${TORCH_FOUND})
+  add_subdirectory(frontends/pytorch)
+else()
+  message("Skipping pytorch frontend, because PyTorch not found!")
+endif()
diff --git a/README.md b/README.md
index 7f7501195..41ff71b20 100644
--- a/README.md
+++ b/README.md
@@ -94,10 +94,6 @@ git submodule update
 LLVM_VERSION=10
 export CC=clang-$LLVM_VERSION
 export CXX=clang++-$LLVM_VERSION
-# If compiling on a new OS that defaults to the CXX11 ABI (i.e. Ubuntu >= 20.04)
-# and looking to use binary installs from the PyTorch website, you must build
-# without the CXX11 ABI.
-export CXXFLAGS="-D_GLIBCXX_USE_CXX11_ABI=0"
 export LDFLAGS=-fuse-ld=$(which ld.lld-$LLVM_VERSION)
 
 # Build and install LLVM/MLIR into the ./install-mlir directory
@@ -124,7 +120,6 @@ source .env
 
 ### PyTorch Frontend (with PyTorch installed via conda)
 ```shell
-# See note above about -D_GLIBCXX_USE_CXX11_ABI=0
 ./build_tools/cmake_configure.sh
 cmake --build build --target check-npcomp check-frontends-pytorch
 ```
diff --git a/cmake/modules/ConfigurePyTorch.cmake b/cmake/modules/ConfigurePyTorch.cmake
index 8a28820b8..109028c93 100644
--- a/cmake/modules/ConfigurePyTorch.cmake
+++ b/cmake/modules/ConfigurePyTorch.cmake
@@ -1,11 +1,35 @@
-function(ProbeForPyTorchInstall)
+# NpcompFindPyTorch
+# Calls find_package on Torch and does any needed post-processing.
+# The enable_pytorch flag can be OFF, ON or OPTIONAL.
+macro(NpcompFindPyTorch enable_pytorch)
+  if(${enable_pytorch} OR ${enable_pytorch} STREQUAL "OPTIONAL")
+    NpcompProbeForPyTorchInstall()
+    if(${enable_pytorch} STREQUAL "OPTIONAL")
+      find_package(Torch 1.8)
+    else()
+      find_package(Torch 1.8 REQUIRED)
+    endif()
+
+    if(${TORCH_FOUND})
+      NpcompConfigurePyTorch()
+    endif()
+  else()
+    message(STATUS "Not configuring PyTorch (disabled)")
+  endif()
+endmacro()
+
+# NpcompProbeForPyTorchInstall
+# Attempts to find a Torch installation and set the Torch_ROOT variable
+# based on introspecting the python environment. This allows a subsequent
+# call to find_package(Torch) to work.
+function(NpcompProbeForPyTorchInstall)
   if(Torch_ROOT)
     message(STATUS "Using cached Torch root = ${Torch_ROOT}")
   else()
     message(STATUS "Checking for PyTorch using ${PYTHON_EXECUTABLE} ...")
     execute_process(
       COMMAND ${PYTHON_EXECUTABLE}
-      -c "import os;import torch;print(os.path.dirname(torch.__file__), end='')"
+      -c "import os;import torch;print(torch.utils.cmake_prefix_path, end='')"
       WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
       RESULT_VARIABLE PYTORCH_STATUS
       OUTPUT_VARIABLE PYTORCH_PACKAGE_DIR)
@@ -15,8 +39,42 @@ function(ProbeForPyTorchInstall)
   endif()
   message(STATUS "Found PyTorch installation at ${PYTORCH_PACKAGE_DIR}")
 
-  # PyTorch stashes its installed .cmake files under share/cmake/Torch.
-  set(Torch_ROOT "${PYTORCH_PACKAGE_DIR}/share/cmake/Torch"
-      CACHE STRING "Torch package root")
+  set(Torch_ROOT "${PYTORCH_PACKAGE_DIR}" CACHE STRING
+      "Torch configure directory" FORCE)
   endif()
 endfunction()
+
+# NpcompConfigurePyTorch
+# Performs configuration of PyTorch flags after CMake has found it to be
+# present. Most of this comes down to detecting whether building against a
+# source or official binary and adjusting compiler options in the latter case
+# (in the former, we assume that it was built with system defaults). We do this
+# conservatively and assume non-binary builds by default.
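For reference, the introspection that `NpcompProbeForPyTorchInstall` and `NpcompConfigurePyTorch` shell out to can be reproduced standalone. A minimal sketch, assuming only a Python environment with torch installed (the build itself runs these probes as `execute_process` one-liners):

```python
# Standalone sketch of the probes the CMake above shells out to
# (illustrative only; not part of the build).
import torch
from torch.utils import cpp_extension

# NpcompProbeForPyTorchInstall: the directory handed to find_package(Torch).
print("Torch_ROOT:", torch.utils.cmake_prefix_path)

# NpcompConfigurePyTorch: official binary wheels are built against the
# pre-CXX11 libstdc++ ABI, so a local GCC/Clang build of the extension
# must match it (hence -D_GLIBCXX_USE_CXX11_ABI=0 below).
print("official binary build:", cpp_extension._is_binary_build())
print("compiled with CXX11 ABI:", torch.compiled_with_cxx11_abi())
```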
+# +# In the future, we may want to switch away from custom building these +# extensions and instead rely on the Torch machinery directly (definitely want +# to do that for official builds). +function(NpcompConfigurePyTorch) + if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") + # Linux specific libstdcpp ABI checking. + message(STATUS "Checking if Torch is an official binary ...") + execute_process( + COMMAND ${PYTHON_EXECUTABLE} + -c "from torch.utils import cpp_extension as c; import sys; sys.exit(0 if c._is_binary_build() else 1)" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE _is_binary_build) + if(${_is_binary_build} EQUAL 0) + set(TORCH_CXXFLAGS "") + if(${CMAKE_CXX_COMPILER_ID} STREQUAL "GNU") + set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11") + elseif(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang") + set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=1011 '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'") + else() + message(WARNING "Unrecognized compiler. Cannot determine ABI flags.") + return() + endif() + message(STATUS "Detected Torch official binary build. Setting ABI flags: ${TORCH_CXXFLAGS}") + set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE) + endif() endif() endfunction() diff --git a/cmake/modules/NpcompDetectPythonEnv.cmake b/cmake/modules/NpcompDetectPythonEnv.cmake index 9367d0de5..1edae312c 100644 --- a/cmake/modules/NpcompDetectPythonEnv.cmake +++ b/cmake/modules/NpcompDetectPythonEnv.cmake @@ -12,7 +12,7 @@ function(npcomp_detect_pybind11_install) if(pybind11_DIR) message(STATUS "Using explicit pybind11 cmake directory: ${pybind11_DIR} (-Dpybind11_DIR to change)") else() - message(CHECK_START "Checking for pybind11 in python path...") + message(STATUS "Checking for pybind11 in python path...") execute_process( COMMAND "${Python3_EXECUTABLE}" -c "import pybind11;print(pybind11.get_cmake_dir(), end='')" @@ -24,7 +24,6 @@ function(npcomp_detect_pybind11_install) message(CHECK_FAIL "not found (install via 'pip install pybind11' or set pybind11_DIR)") return() endif() - message(CHECK_PASS "found (${PACKAGE_DIR})") set(pybind11_DIR "${PACKAGE_DIR}" PARENT_SCOPE) endif() endfunction() diff --git a/frontends/CMakeLists.txt b/frontends/CMakeLists.txt deleted file mode 100644 index 1c7105443..000000000 --- a/frontends/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -if(${TORCH_FOUND}) - add_subdirectory(pytorch) -else() - message("Skipping pytorch frontend, because PyTorch not found!") -endif() diff --git a/frontends/pytorch/csrc/CMakeLists.txt b/frontends/pytorch/csrc/CMakeLists.txt index 1fad264d8..f9320686f 100644 --- a/frontends/pytorch/csrc/CMakeLists.txt +++ b/frontends/pytorch/csrc/CMakeLists.txt @@ -1,51 +1,49 @@ -add_subdirectory(builder) - -include(NpcompPython) - # Sharp edge: Torch extensions need to use the same pybind11 that torch # was compiled with, or else there will be issues in cross module exception # handling (which will abort instead of raise). We circumvent the possibility # by forcing the torch directories first. include_directories(BEFORE ${TORCH_INCLUDE_DIRS} - ${TORCH_INSTALL_PREFIX}/include/TH - ${TORCH_INSTALL_PREFIX}/include/THC/opt/pytorch/pytorch ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR} ${Python3_INCLUDE_DIRS} - # TODO: Fix implicit ordering. If PyTorch was build against an external - # pybind11, then it will not be in the above search path and must be - # resolved here, in the hope that it is the same one we were configured - # with (which it should be if installed via pip). 
This is really fragile, - # though, causing cast failures at runtime if we get it wrong. Come up with - # a way to strengthen this. - ${pybind11_INCLUDE_DIR} ) link_directories("${TORCH_INSTALL_PREFIX}/lib") add_library(NPCOMPTorchMLIRExt SHARED + builder/acap_dispatch.cpp + builder/debug.cpp + builder/func_builder.cpp + builder/graph_importer.cpp + builder/module_builder.cpp + builder/python_bindings.cpp init_python_bindings.cpp ) + target_link_libraries(NPCOMPTorchMLIRExt + # NPCOMP shared library. + # TODO: Debug why order matters here (if NPCOMP is included last a large + # amount of LLVM/MLIR/NPCOMP ends up compiled into this library). + NPCOMP + ${TORCH_LIBRARIES} ${Python3_LIBRARIES} torch_python - npcomp_torch_builder_bindings - - # NPCOMP shared library. - NPCOMP ) add_dependencies(NPCOMPTorchMLIRExt # Uses of the torch_mlir extension also require the npcomp extension to # be built. NPCOMPNativePyExt ) + +message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS}") set_target_properties(NPCOMPTorchMLIRExt PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/python OUTPUT_NAME _torch_mlir PREFIX "${PYTHON_MODULE_PREFIX}" SUFFIX "${PYTHON_MODULE_EXTENSION}" CXX_VISIBILITY_PRESET "hidden" + COMPILE_FLAGS "${TORCH_CXXFLAGS}" ) npcomp_python_target_compile_options(NPCOMPTorchMLIRExt) diff --git a/frontends/pytorch/csrc/builder/CMakeLists.txt b/frontends/pytorch/csrc/builder/CMakeLists.txt deleted file mode 100644 index c89104d1f..000000000 --- a/frontends/pytorch/csrc/builder/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -include_directories( - ${TORCH_INCLUDE_DIRS} - ${TORCH_INSTALL_PREFIX}/include/TH - ${TORCH_INSTALL_PREFIX}/include/THC/opt/pytorch/pytorch - ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_BINARY_DIR} - ${Python3_INCLUDE_DIRS} - # TODO: Fix implicit ordering. If PyTorch was build against an external - # pybind11, then it will not be in the above search path and must be - # resolved here, in the hope that it is the same one we were configured - # with (which it should be if installed via pip). This is really fragile, - # though, causing cast failures at runtime if we get it wrong. Come up with - # a way to strengthen this. - ${pybind11_INCLUDE_DIR} - ) -link_directories("${TORCH_INSTALL_PREFIX}/lib") -add_library(npcomp_torch_builder_bindings - acap_dispatch.cpp - debug.cpp - func_builder.cpp - graph_importer.cpp - module_builder.cpp - python_bindings.cpp -) - -target_link_libraries(npcomp_torch_builder_bindings - ${TORCH_LIBRARIES} - ${Python3_LIBRARIES} - torch_python - ) diff --git a/frontends/pytorch/csrc/builder/acap_dispatch.cpp b/frontends/pytorch/csrc/builder/acap_dispatch.cpp index 718466881..d14d848dd 100644 --- a/frontends/pytorch/csrc/builder/acap_dispatch.cpp +++ b/frontends/pytorch/csrc/builder/acap_dispatch.cpp @@ -46,7 +46,7 @@ static c10::DispatchKey kAcapGradDispatchKey = AcapController::TracedKernelCallBuilder::TracedKernelCallBuilder( AcapController &parent, MlirContext context, MlirLocation loc, const c10::OperatorHandle &opHandle, - llvm::Optional overrideKernelName) + c10::optional overrideKernelName) : KernelCallBuilder(context, loc, overrideKernelName ? 
*overrideKernelName : opHandle.operator_name().name, @@ -125,8 +125,8 @@ void AcapController::contextExit(py::object exc_type, py::object exc_val, void AcapController::returns(std::vector tensors) { verifyHasNotReturned(); - llvm::SmallVector returnsTypes; - llvm::SmallVector returnsValues; + std::vector returnsTypes; + std::vector returnsValues; for (auto &tensor : tensors) { MlirValue v = funcBuilder->lookupTensor(tensor); if (mlirValueIsNull(v)) { @@ -470,7 +470,7 @@ MlirValue AcapController::mapIValueToMlirValue(MlirLocation loc, } if (ival.isList()) { auto list = ival.toList(); - llvm::SmallVector elements; + std::vector elements; for (IValue element : list) { elements.push_back(mapIValueToMlirValue(loc, element)); } @@ -557,8 +557,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) { } else { elementType = typeMapper.mapFromTorchScalarType(tensor.scalar_type()); } - llvm::SmallVector shape(tensor.sizes().begin(), - tensor.sizes().end()); + std::vector shape(tensor.sizes().begin(), tensor.sizes().end()); MlirType shapedType = mlirRankedTensorTypeGetChecked( shape.size(), shape.data(), elementType, loc); if (mlirTypeIsNull(shapedType)) { diff --git a/frontends/pytorch/csrc/builder/acap_dispatch.h b/frontends/pytorch/csrc/builder/acap_dispatch.h index 0a9443e0b..37d4a261d 100644 --- a/frontends/pytorch/csrc/builder/acap_dispatch.h +++ b/frontends/pytorch/csrc/builder/acap_dispatch.h @@ -85,7 +85,7 @@ private: TracedKernelCallBuilder( AcapController &parent, MlirContext context, MlirLocation loc, const c10::OperatorHandle &opHandle, - llvm::Optional overrideKernelName = llvm::None); + c10::optional overrideKernelName = c10::nullopt); void addOperand(const c10::IValue &value); void addResult(const c10::IValue &result); MlirOperation create(); @@ -94,7 +94,7 @@ private: AcapController &parent; const c10::OperatorHandle &opHandle; int resultCount = 0; - llvm::SmallVector, 4> resultIndexToTensorMap; + std::vector> resultIndexToTensorMap; }; MlirLocation getCurrentLocation(); diff --git a/frontends/pytorch/csrc/builder/func_builder.cpp b/frontends/pytorch/csrc/builder/func_builder.cpp index 1254a1b16..56332e837 100644 --- a/frontends/pytorch/csrc/builder/func_builder.cpp +++ b/frontends/pytorch/csrc/builder/func_builder.cpp @@ -25,10 +25,10 @@ static MlirOperation createStandardConstant(MlirLocation loc, MlirType type, } KernelCallBuilder::KernelCallBuilder(MlirContext context, MlirLocation loc, - llvm::StringRef kernelName, + const std::string &kernelName, const c10::FunctionSchema &schema) : context(context), loc(loc), state("torch.kernel_call", loc), - kernelName(kernelName), schema(schema) { + schema(schema) { (void)this->context; // Preserve for future. MlirNamedAttribute kernelNameAttr = mlirNamedAttributeGet( toMlirStringRef("kernelName"), @@ -45,7 +45,7 @@ void KernelCallBuilder::addSchemaAttrs() { // sigIsVararg // sigIsVarret // sigIsMutable - llvm::SmallVector attrs; + std::vector attrs; attrs.push_back( mlirNamedAttributeGet(toMlirStringRef("sigIsMutable"), mlirBoolAttrGet(context, schema.is_mutable()))); @@ -57,7 +57,7 @@ void KernelCallBuilder::addSchemaAttrs() { mlirBoolAttrGet(context, schema.is_varret()))); // Arg types. - llvm::SmallVector args; + std::vector args; for (auto &arg : schema.arguments()) { const std::string &typeStr = arg.type()->str(); args.push_back(mlirStringAttrGet( @@ -68,7 +68,7 @@ void KernelCallBuilder::addSchemaAttrs() { mlirArrayAttrGet(context, args.size(), args.data()))); // Return types. 
- llvm::SmallVector returns; + std::vector returns; for (auto &ret : schema.returns()) { const std::string &typeStr = ret.type()->str(); returns.push_back(mlirStringAttrGet( @@ -170,7 +170,7 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc, } // Ranked with possibly dynamic dims. auto &symbolicShape = tensorType->symbolic_sizes(); - llvm::SmallVector dims; + std::vector dims; dims.resize(*sizes.rank()); for (size_t i = 0; i < dims.size(); ++i) { auto shapeSymbol = symbolicShape[i]; @@ -204,12 +204,12 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) { std::unique_ptr FuncBuilder::createFunction(FuncBuilder::Inserter &inserter, - MlirLocation location, llvm::StringRef name, - llvm::SmallVectorImpl &inputTypes) { + MlirLocation location, const std::string &name, + std::vector &inputTypes) { auto context = mlirLocationGetContext(location); // TODO: Create a dedicated API upstream for creating/manipulating func ops. // (this is fragile and reveals details that are not guaranteed). - llvm::SmallVector funcAttrs; + std::vector funcAttrs; funcAttrs.push_back( mlirNamedAttributeGet(toMlirStringRef("type"), mlirTypeAttrGet(mlirFunctionTypeGet( @@ -242,8 +242,7 @@ FuncBuilder::createFunction(FuncBuilder::Inserter &inserter, context, funcOp, BlockBuilder(entryBlock, /*returnOp=*/{nullptr}, true))); } -void FuncBuilder::rewriteFuncReturnTypes( - llvm::SmallVectorImpl &resultTypes) { +void FuncBuilder::rewriteFuncReturnTypes(std::vector &resultTypes) { // Get inputs from current function type. MlirAttribute funcTypeAttr = mlirOperationGetAttributeByName(funcOp, toMlirStringRef("type")); @@ -252,7 +251,7 @@ void FuncBuilder::rewriteFuncReturnTypes( assert(mlirAttributeIsAType(funcTypeAttr) && "function type is not a TypeAttr"); MlirType funcType = mlirTypeAttrGetValue(funcTypeAttr); - llvm::SmallVector inputTypes; + std::vector inputTypes; for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(funcType); i < e; ++i) { inputTypes.push_back(mlirFunctionTypeGetInput(funcType, i)); } @@ -328,7 +327,7 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc, } MlirValue FuncBuilder::buildList(MlirLocation loc, - llvm::SmallVectorImpl &elements) { + std::vector &elements) { MlirType resultType = npcompListTypeGet(context); OperationStateHolder state{"basicpy.build_list", loc}; mlirOperationStateAddResults(state, 1, &resultType); diff --git a/frontends/pytorch/csrc/builder/func_builder.h b/frontends/pytorch/csrc/builder/func_builder.h index e8c7a42c2..9eed0235d 100644 --- a/frontends/pytorch/csrc/builder/func_builder.h +++ b/frontends/pytorch/csrc/builder/func_builder.h @@ -11,8 +11,6 @@ #include "mlir_utils.h" #include "mlir-c/IR.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/StringRef.h" #include #include @@ -102,7 +100,7 @@ private: class KernelCallBuilder { public: KernelCallBuilder(MlirContext context, MlirLocation loc, - llvm::StringRef kernelName, + const std::string &kernelName, const c10::FunctionSchema &schema); void addOperand(MlirValue operand); void addResultType(MlirType resultType); @@ -115,7 +113,6 @@ protected: private: void addSchemaAttrs(); OperationStateHolder state; - llvm::StringRef kernelName; const c10::FunctionSchema &schema; }; @@ -131,8 +128,7 @@ public: /// to a parent. 
static std::unique_ptr createFunction(Inserter &inserter, MlirLocation location, - llvm::StringRef name, - llvm::SmallVectorImpl &inputTypes); + const std::string &name, std::vector &inputTypes); MlirContext getContext() { return context; } MlirOperation getFuncOp() { return funcOp; } @@ -143,7 +139,7 @@ public: /// Rewrites the function's signature to return the given types. It is /// assumed that a compatible terminator has been added. - void rewriteFuncReturnTypes(llvm::SmallVectorImpl &resultTypes); + void rewriteFuncReturnTypes(std::vector &resultTypes); /// Maps a live Tensor to an MlirValue. void mapTensor(at::Tensor tensor, MlirValue value) { @@ -168,8 +164,7 @@ public: MlirValue getGeneralConstant(MlirLocation loc, MlirAttribute value); /// Builds a list with the given elements - MlirValue buildList(MlirLocation loc, - llvm::SmallVectorImpl &elements); + MlirValue buildList(MlirLocation loc, std::vector &elements); private: FuncBuilder(MlirContext context, MlirOperation funcOp, @@ -200,7 +195,7 @@ private: /// that tensors may be mapped and accessed in proximity. /// TODO: Tensors referenced via an IValue support hash code lookup and /// identity checks. Switch to this instead of a linear scan. - llvm::SmallVector, 16> tensorValueMap; + std::vector> tensorValueMap; }; } // namespace torch_mlir diff --git a/frontends/pytorch/csrc/builder/graph_importer.cpp b/frontends/pytorch/csrc/builder/graph_importer.cpp index 590096c73..29d5958c8 100644 --- a/frontends/pytorch/csrc/builder/graph_importer.cpp +++ b/frontends/pytorch/csrc/builder/graph_importer.cpp @@ -7,6 +7,8 @@ #include "graph_importer.h" +#include + #include "mlir_utils.h" #include "mlir-c/BuiltinAttributes.h" @@ -33,7 +35,7 @@ public: MlirValue findRequiredValue(MlirLocation loc, torch::jit::Value *torchValue); private: - llvm::DenseMap valueMap; + std::unordered_map valueMap; NodeScope *prev = nullptr; }; @@ -134,8 +136,8 @@ void GraphImporter::NodeImporter::importNode() { mlirBlockInsertOwnedOperationBefore(block, ip, op); // Map results. - for (auto it : llvm::enumerate(node->outputs())) { - scope->bindValue(it.value(), mlirOperationGetResult(op, it.index())); + for (size_t i = 0; i < node->outputs().size(); ++i) { + scope->bindValue(node->outputs()[i], mlirOperationGetResult(op, i)); } return; } @@ -151,7 +153,7 @@ void GraphImporter::NodeImporter::importNode() { void GraphImporter::NodeImporter::importReturnOp() { OperationStateHolder s("std.return", loc); - llvm::SmallVector returnsValues; + std::vector returnsValues; for (auto *input : node->inputs()) { returnsValues.push_back(scope->findRequiredValue(loc, input)); } @@ -271,9 +273,9 @@ void GraphImporter::importGenericFunc() { // Bind inputs. NodeScope scope; - for (const auto &it : llvm::enumerate(graph->inputs())) { - MlirValue value = mlirBlockGetArgument(entryBlock, it.index()); - scope.bindValue(it.value(), value); + for (size_t i = 0; i < graph->inputs().size(); ++i) { + MlirValue value = mlirBlockGetArgument(entryBlock, i); + scope.bindValue(graph->inputs()[i], value); } // Walk body nodes. diff --git a/frontends/pytorch/csrc/builder/graph_importer.h b/frontends/pytorch/csrc/builder/graph_importer.h index 02e738369..fa449a91c 100644 --- a/frontends/pytorch/csrc/builder/graph_importer.h +++ b/frontends/pytorch/csrc/builder/graph_importer.h @@ -41,8 +41,8 @@ public: /// when to import globals as constants vs shared arrays, etc. 
struct MlirMappingOptions { MlirContext context; - llvm::Optional genericFuncName; - llvm::Optional funcName; + c10::optional genericFuncName; + c10::optional funcName; TypeMapper &typeMapper; FuncBuilder::Inserter &inserter; }; @@ -82,8 +82,8 @@ private: MlirLocation defaultLoc; /// Argument and return types for the generic func. - llvm::SmallVector genericFuncArgTypes; - llvm::SmallVector genericFuncReturnTypes; + std::vector genericFuncArgTypes; + std::vector genericFuncReturnTypes; }; } // namespace torch_mlir diff --git a/frontends/pytorch/csrc/builder/module_builder.cpp b/frontends/pytorch/csrc/builder/module_builder.cpp index 3e1b37424..efb38f295 100644 --- a/frontends/pytorch/csrc/builder/module_builder.cpp +++ b/frontends/pytorch/csrc/builder/module_builder.cpp @@ -74,7 +74,7 @@ std::shared_ptr ModuleBuilder::startCaptureFunction(std::string &name, std::vector args) { // TODO: Verify that arguments do not alias each other. - llvm::SmallVector inputTypes; + std::vector inputTypes; for (auto &arg : args) { inputTypes.push_back(typeMapper.forwardTensorToType(arg)); } @@ -88,9 +88,8 @@ ModuleBuilder::startCaptureFunction(std::string &name, assert(mlirBlockGetNumArguments(entryBlock) == static_cast(args.size()) && "entry block incorrect arg arity"); - for (auto it : llvm::enumerate(args)) { - funcBuilder->mapTensor(it.value(), - mlirBlockGetArgument(entryBlock, it.index())); + for (size_t i = 0; i < args.size(); ++i) { + funcBuilder->mapTensor(args[i], mlirBlockGetArgument(entryBlock, i)); } return std::make_shared(typeMapper, std::move(funcBuilder)); } @@ -100,10 +99,9 @@ ModuleBuilder::importFunction(torch::jit::StrongFunctionPtr function) { auto inserter = createInserter(); GraphImporter::MlirMappingOptions mappingOptions{ context, - llvm::None, // genericFuncName (default to auto) - llvm::None, // funcName (default to auto) - typeMapper, inserter, - }; + c10::nullopt, // genericFuncName (default to auto) + c10::nullopt, // funcName (default to auto) + typeMapper, inserter}; auto graphImporter = GraphImporter::forPythonJitFunc( function.function_, std::move(mappingOptions)); graphImporter->initialize(); diff --git a/frontends/pytorch/csrc/builder/module_builder.h b/frontends/pytorch/csrc/builder/module_builder.h index bff96a2ab..46e0e5316 100644 --- a/frontends/pytorch/csrc/builder/module_builder.h +++ b/frontends/pytorch/csrc/builder/module_builder.h @@ -13,7 +13,6 @@ #include "acap_dispatch.h" #include "mlir-c/IR.h" -#include "llvm/ADT/SmallVector.h" #include #include diff --git a/frontends/pytorch/csrc/init_python_bindings.cpp b/frontends/pytorch/csrc/init_python_bindings.cpp index 83098577e..df5116c62 100644 --- a/frontends/pytorch/csrc/init_python_bindings.cpp +++ b/frontends/pytorch/csrc/init_python_bindings.cpp @@ -12,137 +12,15 @@ // b) Direct IR translation from PyTorch Graphs (not implemented). // c) Using the PyTorch JIT facility (not implemented). 
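For orientation, after this change the extension's Python-facing surface is the `ModuleBuilder` API from `csrc/builder`, rather than the pass-pipeline helpers deleted below. A minimal usage sketch, mirroring the `test_script_debug_info.py` test updated later in this patch (assumes the repo's `torch_mlir` package is importable):

```python
# Minimal driver for the surviving torch_mlir binding surface; mirrors
# frontends/pytorch/test/graph_export/test_script_debug_info.py.
import torch
import torch_mlir

mb = torch_mlir.ModuleBuilder()

@mb.import_function
@torch.jit.script
def add3(t0, t1, t2):
    intermediate = t0 + t1
    return intermediate + t2

# Print the imported MLIR, with source-level locations attached.
mb.module.operation.print(enable_debug_info=True)
```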
-#include "llvm/Support/Debug.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/raw_ostream.h" - -#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Verifier.h" -#include "mlir/Parser.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" - -#include "npcomp/Dialect/ATen/IR/ATenDialect.h" -#include "npcomp/Dialect/ATen/Transforms/LivenessReport.h" -#include "npcomp/Dialect/ATen/Transforms/Passes.h" - #include "init_python_bindings.h" #include namespace py = pybind11; -using namespace mlir; - -namespace llvm { -extern bool DebugFlag; -} - namespace torch_mlir { namespace { -mlir::OwningModuleRef LoadModule(mlir::MLIRContext &context, std::string mlir) { - - mlir::OwningModuleRef module; - - std::unique_ptr membuf = - llvm::MemoryBuffer::getMemBuffer(mlir); - - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(membuf), llvm::SMLoc()); - module = mlir::parseSourceFile(sourceMgr, &context); - - if (!module) { - llvm::errs() << "Error can't parse mlir module\n"; - return nullptr; - } - if (failed(mlir::verify(*module))) { - llvm::errs() << "Error verifying MLIR module\n"; - return nullptr; - } - if (!module) - return nullptr; - return module; -} - -void InitModuleBindings(py::module &m) { - m.def( - "_op_report", - [](std::string mlir) -> std::string { - mlir::MLIRContext context; - auto module = LoadModule(context, mlir); - mlir::PassManager pm(module->getContext()); - - // our pass - std::string report; - pm.addPass(mlir::NPCOMP::aten::createATenLayerNamePass()); - pm.addPass(mlir::NPCOMP::aten::createATenOpReportPass(report)); - - if (failed(pm.run(*module))) { - llvm::errs() << "ATenOpReportPass failed"; - return ""; - } - return report; - }, - "run ATenOpReportPass"); - - m.def( - "_liveness_report", - [](std::string mlir) -> std::string { - mlir::MLIRContext context; - auto module = LoadModule(context, mlir); - - mlir::PassManager pm(module->getContext()); - - pm.addPass(mlir::NPCOMP::aten::createATenLayerNamePass()); - if (failed(pm.run(*module))) { - llvm::errs() << "ATen generate liveness report failed"; - return ""; - } - - auto mOp = module.get(); - auto liveness = mlir::NPCOMP::aten::LivenessReport(mOp); - std::string report = liveness.emitJSONReport(); - return report; - }, - "generate liveness report"); - - // TODO: Could this be implemented with MLIR python bindings? 
- m.def( - "lower_to_std", - [](std::string mlir) -> std::string { - mlir::MLIRContext context; - auto module = LoadModule(context, mlir); - - PassManager pm0(module->getContext()); - pm0.addPass(mlir::NPCOMP::aten::createATenLoweringPass()); - pm0.addPass(mlir::NPCOMP::aten::createReturnEliminationPass()); - pm0.addPass(mlir::createCSEPass()); - - if (failed(pm0.run(*module))) { - llvm::errs() << "aten to loops conversion failed "; - return ""; - } - - // dump MLIR to string and return - std::string s; - llvm::raw_string_ostream ss(s); - ss << "# Lowered to Std\n"; - module->print(ss); - return ss.str(); - }, - "lower aten to std dialect"); - - m.def( - "set_debug", - [](bool b, std::string type) -> void { - llvm::setCurrentDebugType(type.c_str()); - llvm::DebugFlag = b; - }, - "enable/disable debug messages"); -} +void InitModuleBindings(py::module &m) {} } // namespace diff --git a/frontends/pytorch/test/graph_export/test_script_debug_info.py b/frontends/pytorch/test/graph_export/test_script_debug_info.py index e44830009..391d5be92 100644 --- a/frontends/pytorch/test/graph_export/test_script_debug_info.py +++ b/frontends/pytorch/test/graph_export/test_script_debug_info.py @@ -17,14 +17,13 @@ mb = torch_mlir.ModuleBuilder() @mb.import_function @torch.jit.script def add3(t0, t1, t2): - # CHECK: constant 1{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 2]] - # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]] + # TODO: Checks for debug info are quite hard with the new trailing debug + # attribute print. See if this can be improved. + # CHECK: loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]] intermediate = t0 + t1 - # CHECK: aten::add{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]] + # CHECK: loc({{.*}}test_script_debug_info.py":[[# @LINE + 1]] final = intermediate + t2 - # CHECK: return{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 3]] return final - # CHECK: }{{.*}}loc({{.*}}test_script_debug_info.py":[[# @LINE - 5]] # Verify again with debug info present. Just checking that it makes it in there. mb.module.operation.print(enable_debug_info=True) diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 4dea2ccca..693b60687 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,6 +1,3 @@ add_subdirectory(npcomp-opt) -if(${TORCH_FOUND}) - add_subdirectory(mnist-playground) -endif() add_subdirectory(npcomp-run-mlir) add_subdirectory(npcomp-shlib) diff --git a/tools/mnist-playground/CMakeLists.txt b/tools/mnist-playground/CMakeLists.txt deleted file mode 100644 index e6c460b69..000000000 --- a/tools/mnist-playground/CMakeLists.txt +++ /dev/null @@ -1,48 +0,0 @@ -# TODO: This is copied from frontends/pytorch/csrc/c10_dispatch/CMakeLists.txt -# What is the idiomatic way of sharing this in CMake? -include_directories( - ${TORCH_INCLUDE_DIRS} - ${TORCH_INSTALL_PREFIX}/include/TH - ${TORCH_INSTALL_PREFIX}/include/THC/opt/pytorch/pytorch - ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_BINARY_DIR} - ${Python3_INCLUDE_DIRS} - ) -link_directories("${TORCH_INSTALL_PREFIX}/lib") - - -set(LLVM_LINK_COMPONENTS - Core - Support - nativecodegen - ) - -add_npcomp_executable(mnist-playground - mnist-playground.cpp - ) -llvm_update_compile_flags(mnist-playground) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -target_link_libraries(mnist-playground PRIVATE - # Shared library deps first ensure we get most of what we need from libraries. 
- NPCOMP - MLIR - - MLIRAnalysis - MLIREDSC - MLIRExecutionEngine - MLIRIR - MLIRJitRunner - MLIRLLVMIR - MLIRParser - MLIRTargetLLVMIR - MLIRSupport - NPCOMPInitAll - NPCOMPRefBackendJITHelpers - ${conversion_libs} - ${dialect_libs} - ${TORCH_LIBRARIES} - ) -add_dependencies(mnist-playground - NPCOMPCompilerRuntimeShlib - ) diff --git a/tools/mnist-playground/README.md b/tools/mnist-playground/README.md deleted file mode 100644 index bebdaec39..000000000 --- a/tools/mnist-playground/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# mnist-playground - -This is intended to be a short-lived "playground" for doing various experiments, guided by a real model use case, for improving the npcomp reference backend. - -It's expected that utilities developed here will graduate to a more general utility or that this utility will be obsoleted by Python-driven flows once those come online. - -## Goals: - -- Obtain a performance-grounded analysis of the TCF/TCP design + reference backend design, and improve the designs. - -- Make forward progress on TCF/TCP + reference backend while the PyTorch frontend is being brought up. - -## Rough sketch of how we intend to get there: - -1. Link against PyTorch, and write a simple routine to do inference on a simple FC MNIST. - -2. Write a similar routine in TCF, extending TCF and the reference backend as needed for functional completeness. The PyTorch code serves as a numerical correctness reference. - -3. Run and profile the reference backend and obtain a set of action items for design improvements, both to performance and stability. The PyTorch code serves as a performance baseline. - -4. Implement important action items on a priority basis, and document remaining major design issues that don't make sense to address at this time, along with a justification for why the current design doesn't prevent us from eventually solving them. Iterate the previous step and this one as makes sense. - -5. (Stretch) Add support for convolutional MNIST and/or training. - -## Current Status - -Step 1. DONE - -Step 2. MOSTLY DONE. Still need to improve the op set to make the FC MNIST more complete. In particular, implementing functionality for reshape and softmax. - -Step 3. STARTING. Initial performance on 10x784x100 (10 FC feature, batch 100) is 66x off from PyTorch. No profiling done yet. - -Example command line (the .mlir file and `-invoke` are similar to npcomp-run-mlir): - -``` -$ mnist-playground tools/mnist-playground/fc.mlir -invoke fc -PyTorch: numRuns: 16384 nsPerRun: 3.947563e+05 -RefBackend: numRuns: 256 nsPerRun: 2.471073e+07 -Ratio (RefBackend / PyTorch): 62.5974 -``` - -There is currently a fragile dependency between hardcoded `at::` function calls in the .cpp file and the TCF code in the `.mlir` file. A correctness check is done to make sure they agree. Once we have a PyTorch frontend and/or ATen roundrip ATen backend oneline, we can avoid this fragility. diff --git a/tools/mnist-playground/fc.mlir b/tools/mnist-playground/fc.mlir deleted file mode 100644 index 0bee43506..000000000 --- a/tools/mnist-playground/fc.mlir +++ /dev/null @@ -1,15 +0,0 @@ - -func @fc( - // TODO: Implement "reshape" so that %image can be passed as batch of 2D tensors. - %image: tensor, - %weights: tensor, - %biases: tensor) --> ( - tensor -) { - %0 = tcf.matmul %weights, %image : (tensor, tensor) -> tensor - %1 = tcf.add %0, %biases : (tensor, tensor) -> tensor - // TODO: Implement softmax for classification. - // For now, this returns a not-terribly useful number. 
- return %1 : tensor -} diff --git a/tools/mnist-playground/mnist-playground.cpp b/tools/mnist-playground/mnist-playground.cpp deleted file mode 100644 index 94e3db456..000000000 --- a/tools/mnist-playground/mnist-playground.cpp +++ /dev/null @@ -1,298 +0,0 @@ -//===- mnist-playground.cpp -------------------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/ExecutionEngine/OptUtils.h" -#include "mlir/IR/AsmState.h" -#include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" -#include "mlir/Parser.h" -#include "mlir/Pass/PassManager.h" -#include "npcomp/InitAll.h" -#include "npcomp/RefBackend/JITHelpers/JITModule.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/TargetSelect.h" - -#include - -#include - -using namespace mlir; -using llvm::Error; -using llvm::ErrorOr; -using llvm::Expected; -using llvm::StringError; -using llvm::Twine; - -//===----------------------------------------------------------------------===// -// Utilities -//===----------------------------------------------------------------------===// - -/// Wrap a string into an llvm::StringError. -static Error make_string_error(const Twine &message) { - return llvm::make_error(message.str(), - llvm::inconvertibleErrorCode()); -} - -Expected> -createJITModule(std::string mlirFile, mlir::DialectRegistry ®istry, - ArrayRef sharedLibs, bool optimize) { - MLIRContext context; - registry.loadAll(&context); - OwningModuleRef moduleRef = parseSourceFile(mlirFile, &context); - if (!moduleRef) - return make_string_error(Twine("could not open ") + mlirFile); - - ModuleOp module = *moduleRef; - - // Compile. - PassManager pm(module.getContext(), OpPassManager::Nesting::Implicit); - applyPassManagerCLOptions(pm); - refback::JITModule::buildBackendCompilationPipeline(pm, optimize); - if (failed(pm.run(module))) - return make_string_error(Twine("error compiling to jit backend")); - - return refback::JITModule::fromCompiledModule(module, sharedLibs); -} - -//===----------------------------------------------------------------------===// -// Benchmarking / correctness-testing code. -//===----------------------------------------------------------------------===// - -static Expected> -invokeJITModuleWithATenTensors(refback::JITModule &jitModule, - StringRef invokeFunction, - std::vector &args) { - - // Do a bit of checking. We don't handle all possible tensors right now. - std::vector tensorArgs; - for (auto arg : llvm::enumerate(args)) - tensorArgs.push_back(at::TensorArg(arg.value(), "arg", arg.index())); - at::CheckedFrom c = "converting to refbackrt::Tensor"; - for (auto &tensorArg : tensorArgs) - at::checkScalarType(c, tensorArg, at::ScalarType::Float); - at::checkAllContiguous(c, tensorArgs); - - SmallVector, 6> refbackInputs; - for (at::Tensor arg : args) { - SmallVector extents(arg.sizes().begin(), arg.sizes().end()); - float *data = arg.storage().data(); - // This does a deep copy of the data. Let's see if it shows up on the - // profile. - refbackInputs.push_back(refbackrt::Tensor::create( - refbackrt::ArrayRef(extents.data(), extents.size()), - refbackrt::ElementType::F32, data)); - } - - // Invoke the RefBackend function. 
- auto expectedOutputs = jitModule.invoke(invokeFunction, refbackInputs); - if (!expectedOutputs) - return expectedOutputs.takeError(); - auto refbackrtOutputs = std::move(*expectedOutputs); - - std::vector results; - for (auto output : refbackrtOutputs) { - std::vector sizes(output->getExtents().data(), - output->getExtents().data() + - output->getExtents().size()); - // Make a copy for passing to at::from_blob, which does its own internal - // reference counting. - auto *dataCopy = std::malloc(output->getDataByteSize()); - std::memcpy(dataCopy, output->getData(), output->getDataByteSize()); - results.push_back(at::from_blob( - dataCopy, sizes, [](void *p) { std::free(p); }, at::kFloat)); - } - return results; -} - -using InvocationFunction = - std::function>(std::vector)>; - -struct BenchmarkResult { - int numRuns; - float nsPerRun; -}; - -std::ostream &operator<<(std::ostream &os, const BenchmarkResult &result) { - os << "numRuns: " << result.numRuns << " nsPerRun: " << std::scientific - << result.nsPerRun << std::defaultfloat; - return os; -} - -Expected benchmark(std::function f) { - for (int itersAtATime = 1;; itersAtATime *= 2) { - auto start = std::chrono::steady_clock::now(); - for (int i = 0; i < itersAtATime; i++) { - auto error = f(); - if (error) - return std::move(error); - } - auto end = std::chrono::steady_clock::now(); - std::chrono::duration elapsed = end - start; - - // If the runtime is longer than 0.5 seconds, it's reliable enough. - if (elapsed.count() > 0.5f) { - BenchmarkResult result; - result.numRuns = itersAtATime; - result.nsPerRun = elapsed.count() * 10e9 / itersAtATime; - return result; - } - } - return make_string_error("too short running to benchmark!"); -} - -static Error doIt(InvocationFunction ptFunc, InvocationFunction refBackendFunc, - bool doBenchmark, int numCorrectnessTests) { - - torch::manual_seed(42); - torch::set_num_threads(1); - - std::vector args; - args.push_back(at::rand({784, 100})); - args.push_back(at::rand({10, 784})); - args.push_back(at::rand({10, 1})); - - // Initial correctness check of the two functions. - for (int correctnessTest = 0; correctnessTest < numCorrectnessTests; - correctnessTest++) { - auto expectedPt = ptFunc(args); - auto expectedRefBackend = refBackendFunc(args); - if (!expectedPt) - return expectedPt.takeError(); - if (!expectedRefBackend) - return expectedRefBackend.takeError(); - auto pt = std::move(*expectedPt); - auto refBackend = std::move(*expectedRefBackend); - if (pt.size() != refBackend.size()) - return make_string_error("mismatch in result arity!"); - for (int i = 0, e = pt.size(); i < e; i++) { - if (!at::allclose(pt[i], refBackend[i])) { - std::cout << "PyTorch:\n" << pt[i] << "\n"; - std::cout << "RefBackend:\n" << refBackend[i] << "\n"; - return make_string_error(Twine("mismatch in result contents ") + - Twine(i) + Twine(" on correctness test #") + - Twine(correctnessTest)); - } - } - } - if (!doBenchmark) - return Error::success(); - - // Benchmark the two against each other. 
- BenchmarkResult ptBenchmarkResult; - BenchmarkResult refBackendBenchmarkResult; - { - auto expectedResult = - benchmark([&]() -> Error { return ptFunc(args).takeError(); }); - if (!expectedResult) - return expectedResult.takeError(); - ptBenchmarkResult = std::move(*expectedResult); - } - - { - auto expectedResult = - benchmark([&]() -> Error { return refBackendFunc(args).takeError(); }); - if (!expectedResult) - return expectedResult.takeError(); - refBackendBenchmarkResult = std::move(*expectedResult); - } - std::cout << "PyTorch: " << ptBenchmarkResult << "\n"; - std::cout << "RefBackend: " << refBackendBenchmarkResult << "\n"; - std::cout << "Ratio (RefBackend / PyTorch): " - << refBackendBenchmarkResult.nsPerRun / ptBenchmarkResult.nsPerRun - << "\n"; - - // TODO: Check for memory leaks? - - return Error::success(); -} - -//===----------------------------------------------------------------------===// -// Main-related init and option parsing. -//===----------------------------------------------------------------------===// - -namespace { -namespace cl = llvm::cl; -struct Options { - cl::opt inputFile{ - cl::Positional, cl::desc("the input .mlir file"), cl::init("-")}; - cl::opt invokeFunction{"invoke", cl::Required, - cl::desc("function to invoke")}; - - cl::list sharedLibs{"shared-libs", cl::ZeroOrMore, - cl::MiscFlags::CommaSeparated, - cl::desc("Libraries to link dynamically")}; - cl::opt optimize{ - "optimize", cl::Optional, - cl::desc("whether the refback pass pipeline should run optimizations"), - cl::init(false)}; - - cl::opt benchmark{"benchmark", cl::Optional, - cl::desc("whether to do a benchmark comparison"), - cl::init(true)}; - - cl::opt numCorrectnessTests{ - "num-correctness-tests", cl::Optional, - cl::desc("how many correctness tests to run (useful for nondeterministic " - "correctness failures"), - cl::init(1)}; -}; -} // namespace - -int main(int argc, char **argv) { - mlir::DialectRegistry registry; - mlir::registerAllDialects(registry); - mlir::registerAllPasses(); - mlir::NPCOMP::registerAllDialects(registry); - mlir::NPCOMP::registerAllPasses(); - - llvm::InitLLVM y(argc, argv); - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - mlir::initializeLLVMPasses(); - - mlir::registerAsmPrinterCLOptions(); - mlir::registerPassManagerCLOptions(); - - Options options; - llvm::cl::ParseCommandLineOptions(argc, argv, "mnist playground utility\n"); - - SmallVector sharedLibs(options.sharedLibs.begin(), - options.sharedLibs.end()); - auto expectedJITModule = createJITModule(options.inputFile, registry, - sharedLibs, options.optimize); - if (Error error = expectedJITModule.takeError()) - llvm::report_fatal_error(llvm::toString(std::move(error)), - /*gen_crash_diag=*/false); - auto jitModule = std::move(*expectedJITModule); - - Error error = doIt( - [](std::vector args) { - auto image = args[0]; - auto weights = args[1]; - auto biases = args[2]; - auto v0 = at::matmul(weights, image); - auto v1 = at::add(v0, biases); - return std::vector{v1}; - }, - [&](std::vector args) { - return invokeJITModuleWithATenTensors(*jitModule, - options.invokeFunction, args); - }, - options.benchmark, options.numCorrectnessTests); - - int exitCode = EXIT_SUCCESS; - llvm::handleAllErrors(std::move(error), - [&exitCode](const llvm::ErrorInfoBase &info) { - llvm::errs() << "Error: "; - info.log(llvm::errs()); - llvm::errs() << '\n'; - exitCode = EXIT_FAILURE; - }); - return exitCode; -}
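As the commit message notes, this lands close to just using PyTorch's extension builder API. A hypothetical `setup.py` sketch of that direction (illustrative only; the package name, source list, and link inputs are assumptions, and this patch deliberately keeps the in-tree CMake build instead):

```python
# Hypothetical sketch of building the extension via PyTorch's builder API;
# not what this patch implements.
from setuptools import setup
from torch.utils import cpp_extension

setup(
    name="npcomp-torch-mlir",  # assumed package name
    ext_modules=[
        cpp_extension.CppExtension(
            "_torch_mlir",
            sources=[
                "frontends/pytorch/csrc/init_python_bindings.cpp",
                # ...plus the builder/*.cpp sources from the CMake target...
            ],
            # Link only the NPCOMP C-API shared library; per the new rule,
            # no direct C++ dependency on LLVM at this level.
            libraries=["NPCOMP"],
        )
    ],
    # BuildExtension supplies Torch include paths and matching ABI flags.
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
```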