mirror of https://github.com/llvm/torch-mlir
81 lines
3.3 KiB
CMake
81 lines
3.3 KiB
CMake
# NpcompFindPyTorch
# Calls find_package on Torch and does any needed post-processing.
#
# Arguments:
#   enable_pytorch - the *value* of the enable flag (not a variable name).
#     ON (or any other CMake true constant): Torch 1.8 is required;
#       find_package(... REQUIRED) fails the configure if it is missing.
#     OPTIONAL: Torch 1.8 is looked for, but its absence is not an error.
#     OFF / empty / anything else: PyTorch is skipped with a status message.
#
# Example:
#   NpcompFindPyTorch("${NPCOMP_ENABLE_PYTORCH}")
#
# This is a macro rather than a function so that everything find_package()
# defines, and the TORCH_CXXFLAGS that NpcompConfigurePyTorch() exports with
# PARENT_SCOPE, end up in the scope of the caller.
macro(NpcompFindPyTorch enable_pytorch)
  # Macro arguments are substituted textually, so every expansion is quoted:
  # an empty value would otherwise make the condition syntactically invalid,
  # and a bare ON/OPTIONAL token could be dereferenced as a variable name.
  if("${enable_pytorch}" OR "${enable_pytorch}" STREQUAL "OPTIONAL")
    # Seed Torch_ROOT from the python environment before searching.
    NpcompProbeForPyTorchInstall()
    if("${enable_pytorch}" STREQUAL "OPTIONAL")
      find_package(Torch 1.8)
    else()
      find_package(Torch 1.8 REQUIRED)
    endif()

    # Bare variable name: evaluates to false when TORCH_FOUND was never
    # defined, without relying on how an empty expansion is parsed.
    if(TORCH_FOUND)
      NpcompConfigurePyTorch()
    endif()
  else()
    message(STATUS "Not configuring PyTorch (disabled)")
  endif()
endmacro()
|
|
|
|
# NpcompProbeForPyTorchInstall
# Attempts to find a Torch installation and set the Torch_ROOT variable
# based on introspecting the python environment. This allows a subsequent
# call to find_package(Torch) to work.
#
# Takes no arguments. Reads:
#   Torch_ROOT        - if already set to a true value, it is used as-is and
#                       nothing is probed.
#   PYTHON_EXECUTABLE - interpreter asked for torch.utils.cmake_prefix_path.
# Writes:
#   Torch_ROOT (cache, PATH) - only when the interpreter exits with 0 and
#                              prints a non-empty path.
# A failed probe is not an error: the function logs a status message and
# returns, leaving Torch_ROOT untouched.
function(NpcompProbeForPyTorchInstall)
  if(Torch_ROOT)
    message(STATUS "Using cached Torch root = ${Torch_ROOT}")
    return()
  endif()

  message(STATUS "Checking for PyTorch using ${PYTHON_EXECUTABLE} ...")
  execute_process(
    COMMAND ${PYTHON_EXECUTABLE}
      -c "import os;import torch;print(torch.utils.cmake_prefix_path, end='')"
    WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
    RESULT_VARIABLE PYTORCH_STATUS
    OUTPUT_VARIABLE PYTORCH_PACKAGE_DIR
    # Guard against stray trailing newlines ending up in the cached path.
    OUTPUT_STRIP_TRAILING_WHITESPACE)

  # RESULT_VARIABLE is either the exit code or an error string (for example
  # when the interpreter cannot be launched); anything but 0 is a failure.
  if(NOT PYTORCH_STATUS EQUAL "0")
    message(STATUS "Unable to 'import torch' with ${PYTHON_EXECUTABLE} (fallback to explicit config)")
    return()
  endif()

  # Never cache an empty root: it would be force-written to the cache and
  # give find_package(Torch) nothing to work with.
  if("${PYTORCH_PACKAGE_DIR}" STREQUAL "")
    message(STATUS "PyTorch reported an empty cmake prefix path (fallback to explicit config)")
    return()
  endif()

  message(STATUS "Found PyTorch installation at ${PYTORCH_PACKAGE_DIR}")

  # FORCE is safe here: this line is only reached when Torch_ROOT is unset or
  # evaluates to false, so a user-provided value is never overwritten.
  set(Torch_ROOT "${PYTORCH_PACKAGE_DIR}" CACHE PATH
      "Torch configure directory" FORCE)
endfunction()
|
|
|
|
# NpcompConfigurePyTorch
# Performs configuration of PyTorch flags after CMake has found it to be
# present. Most of this comes down to detecting whether building against a
# source or official binary and adjusting compiler options in the latter case
# (in the former, we assume that it was built with system defaults). We do this
# conservatively and assume non-binary builds by default.
#
# In the future, we may want to switch away from custom building these
# extensions and instead rely on the Torch machinery directly (definitely want
# to do that for official builds).
#
# Takes no arguments. Reads CMAKE_SYSTEM_NAME, CMAKE_CXX_COMPILER_ID and
# PYTHON_EXECUTABLE. Sets TORCH_CXXFLAGS in the caller's scope only when all
# of the following hold: the system is Linux, the python check exits with 0,
# and the compiler id is GNU or Clang. In every other case TORCH_CXXFLAGS is
# left untouched.
function(NpcompConfigurePyTorch)
  # Bare variable names are used in every if(): expanding them with ${...}
  # breaks on empty values and on values that contain spaces.
  if(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
    return()
  endif()

  # Linux specific libstdcpp ABI checking.
  message(STATUS "Checking if Torch is an official binary ...")
  execute_process(
    COMMAND ${PYTHON_EXECUTABLE}
      -c "from torch.utils import cpp_extension as c; import sys; sys.exit(0 if c._is_binary_build() else 1)"
    WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
    RESULT_VARIABLE _is_binary_build)

  # Exit code 0 means "binary build". Any other exit code, or an error string
  # when the interpreter could not be launched, falls back to the
  # conservative non-binary assumption.
  if(NOT _is_binary_build EQUAL 0)
    return()
  endif()

  if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
    set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11")
  elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
    set(TORCH_CXXFLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=1011 '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
  else()
    message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
    return()
  endif()

  message(STATUS "Detected Torch official binary build. Setting ABI flags: ${TORCH_CXXFLAGS}")
  # Hand the flags to the caller; the local copy dies with this function.
  set(TORCH_CXXFLAGS "${TORCH_CXXFLAGS}" PARENT_SCOPE)
endfunction()
|