torch-mlir/cmake/modules/ConfigurePyTorch.cmake

81 lines
3.3 KiB
CMake

# NpcompFindPyTorch
# Calls find_package on Torch and does any needed post-processing.
#
# enable_pytorch is the *value* of the enable flag (not a variable name) and
# can be OFF, ON or OPTIONAL:
#   - ON (or another CMake true constant): Torch is REQUIRED, so configuration
#     fails if find_package(Torch) does not succeed.
#   - OPTIONAL: Torch is looked up without REQUIRED; configuration continues
#     if it is not found.
#   - OFF (or an empty / false value): PyTorch is not configured at all.
#
# This is a macro rather than a function so that variables produced by
# find_package(Torch) and by NpcompConfigurePyTorch() (e.g. TORCH_CXXFLAGS)
# land directly in the caller's scope.
#
# Example:
#   NpcompFindPyTorch(OPTIONAL)
macro(NpcompFindPyTorch enable_pytorch)
  # The macro argument is substituted textually; quoting it keeps the if()
  # well-formed when the value is empty and makes OPTIONAL compare as a
  # literal string instead of being dereferenced as a variable name.
  if("${enable_pytorch}" OR "${enable_pytorch}" STREQUAL "OPTIONAL")
    NpcompProbeForPyTorchInstall()
    if("${enable_pytorch}" STREQUAL "OPTIONAL")
      find_package(Torch 1.8)
    else()
      find_package(Torch 1.8 REQUIRED)
    endif()
    # Test the variable by name: when the Torch package config was never
    # loaded TORCH_FOUND is undefined and ${TORCH_FOUND} would expand to
    # nothing inside if().
    if(TORCH_FOUND)
      NpcompConfigurePyTorch()
    endif()
  else()
    message(STATUS "Not configuring PyTorch (disabled)")
  endif()
endmacro()
# NpcompProbeForPyTorchInstall
# Attempts to find a Torch installation and set the Torch_ROOT cache variable
# based on introspecting the python environment: the interpreter named by
# PYTHON_EXECUTABLE is asked for torch.utils.cmake_prefix_path. This allows a
# subsequent call to find_package(Torch) to work.
#
# This is best effort. An already-set Torch_ROOT is kept as-is, and any
# failure to query Python only prints a status message, leaving it to the
# caller's find_package() to decide whether a missing Torch is fatal.
function(NpcompProbeForPyTorchInstall)
  if(Torch_ROOT)
    message(STATUS "Using cached Torch root = ${Torch_ROOT}")
    return()
  endif()

  # Without an interpreter, execute_process would try to run "-c" as the
  # program; bail out with a clear message instead.
  if(NOT PYTHON_EXECUTABLE)
    message(STATUS "PYTHON_EXECUTABLE is not set; cannot probe for PyTorch (fallback to explicit config)")
    return()
  endif()

  message(STATUS "Checking for PyTorch using ${PYTHON_EXECUTABLE} ...")
  execute_process(
    COMMAND "${PYTHON_EXECUTABLE}"
            -c "import torch;print(torch.utils.cmake_prefix_path, end='')"
    WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
    RESULT_VARIABLE PYTORCH_STATUS
    OUTPUT_VARIABLE PYTORCH_PACKAGE_DIR
    OUTPUT_STRIP_TRAILING_WHITESPACE)
  # RESULT_VARIABLE holds either the exit code or an error string; a
  # non-numeric value compares unequal to 0. An empty path would be useless
  # as Torch_ROOT, so treat it as a failure too.
  if(NOT PYTORCH_STATUS EQUAL 0 OR "${PYTORCH_PACKAGE_DIR}" STREQUAL "")
    message(STATUS "Unable to 'import torch' with ${PYTHON_EXECUTABLE} (fallback to explicit config)")
    return()
  endif()

  message(STATUS "Found PyTorch installation at ${PYTORCH_PACKAGE_DIR}")
  # FORCE can only replace an unset/false Torch_ROOT here (guarded above),
  # so a user-provided value is never overwritten.
  set(Torch_ROOT "${PYTORCH_PACKAGE_DIR}" CACHE PATH
      "Torch configure directory" FORCE)
endfunction()
# NpcompConfigurePyTorch
# Performs configuration of PyTorch flags after CMake has found it to be
# present. Most of this comes down to detecting whether building against a
# source or official binary and adjusting compiler options in the latter case
# (in the former, we assume that it was built with system defaults). We do this
# conservatively and assume non-binary builds by default.
#
# Output: on Linux, when Torch reports itself as an official binary build and
# the C++ compiler is GCC or Clang, TORCH_CXXFLAGS is set in the caller's
# scope to a space-separated flag string. In every other case the caller's
# TORCH_CXXFLAGS is left untouched.
#
# In the future, we may want to switch away from custom building these
# extensions and instead rely on the Torch machinery directly (definitely want
# to do that for official builds).
function(NpcompConfigurePyTorch)
  # Only Linux needs the libstdc++ ABI checking below.
  if(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
    return()
  endif()

  message(STATUS "Checking if Torch is an official binary ...")
  execute_process(
    COMMAND ${PYTHON_EXECUTABLE}
            -c "from torch.utils import cpp_extension as c; import sys; sys.exit(0 if c._is_binary_build() else 1)"
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    RESULT_VARIABLE _is_binary_build)
  # RESULT_VARIABLE is the exit code, or an error string (possibly containing
  # spaces) if the process could not be run. Reference it by name: the
  # unquoted ${...} form would split such a string into several if()
  # arguments and abort configuration. A non-numeric value simply compares
  # unequal to 0, i.e. it is treated as a non-binary build.
  if(NOT _is_binary_build EQUAL 0)
    return()
  endif()

  # Compare the compiler id by name so an empty value cannot yield a
  # malformed if().
  if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
    set(torch_cxxflags "-D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11")
  elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
    set(torch_cxxflags "-D_GLIBCXX_USE_CXX11_ABI=0 -U__GXX_ABI_VERSION -D__GXX_ABI_VERSION=1011 '-DPYBIND11_COMPILER_TYPE=\"_gcc\"'")
  else()
    message(WARNING "Unrecognized compiler. Cannot determine ABI flags.")
    return()
  endif()
  # NOTE(review): _GLIBCXX_USE_CXX11_ABI=0 is hard-coded. An official Torch
  # binary built with the CXX11 ABI would need =1; consider querying
  # torch.compiled_with_cxx11_abi() instead -- TODO confirm for the Torch
  # versions this project supports.

  message(STATUS "Detected Torch official binary build. Setting ABI flags: ${torch_cxxflags}")
  set(TORCH_CXXFLAGS "${torch_cxxflags}" PARENT_SCOPE)
endfunction()