From 2c3b3606d058ea0e109907fe47db5992ea7eefe0 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Sat, 30 Jul 2022 05:54:40 -0400 Subject: [PATCH] Resolve remaining LTC CI failures (#1110) * Replace CHECK_EQ with TORCH_CHECK_EQ * Check value of TORCH_MLIR_USE_INSTALLED_PYTORCH during LTC build * Update LTC XFAIL with NewZerosModule ops * Explicitly blacklist _like ops * Automatically blacklist new_/_like ops * Prune away unused Python dependencies from LTC * Add flag to disable LTC * Autogen dummy _REFERENCE_LAZY_BACKEND library when LTC is disabled * Implement compute_shape_var * Removed Var tests from XFAIL Set * XFAIL tests using _local_scalar_dense or index.Tensor * Add StdDim tests to XFAIL set * Autogen aten::cat --- .github/workflows/buildAndTest.yml | 1 + CMakeLists.txt | 2 + build_tools/autogen_ltc_backend.py | 5 +- build_tools/autogen_ltc_backend.yaml | 1 - e2e_testing/torchscript/xfail_sets.py | 25 +++--- python/CMakeLists.txt | 56 +++++++----- .../csrc/base_lazy_backend/CMakeLists.txt | 10 ++- .../mlir_lowering_context.cpp | 2 +- .../mlir_native_functions.cpp | 19 ---- .../base_lazy_backend/mlir_node_lowering.cpp | 4 +- .../csrc/base_lazy_backend/ops/to_copy.h | 2 +- .../base_lazy_backend/shape_inference.cpp | 7 ++ .../reference_lazy_backend/CMakeLists.txt | 87 +++++++++++-------- .../reference_lazy_backend/gen_dummy_lib.py | 23 +++++ .../torch/importer/jit_ir/csrc/CMakeLists.txt | 9 +- 15 files changed, 154 insertions(+), 99 deletions(-) create mode 100755 python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 38943aedb..913014f50 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -96,6 +96,7 @@ jobs: -DMLIR_ENABLE_BINDINGS_PYTHON=OFF \ -DTORCH_MLIR_ENABLE_MHLO=ON \ -DTORCH_MLIR_USE_INSTALLED_PYTORCH=${{ matrix.torch-binary }} \ + -DTORCH_MLIR_ENABLE_LTC=OFF \ -DPython3_EXECUTABLE=$(which python) \ . diff --git a/CMakeLists.txt b/CMakeLists.txt index 10432d1e9..00340141c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,8 @@ if(TORCH_MLIR_ENABLE_MHLO) endif() endif() +option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" ON) + torch_mlir_add_llvm_external_project( torch-mlir-dialects TORCH_MLIR_DIALECTS diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 44c56b645..ff31b95b4 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -86,7 +86,7 @@ class GenMlirLazyIr(torchgen.dest.GenLazyIR): {emplace_arguments_str} {emplace_kwarguments} torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments); - CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)}); + TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)}); return {schema.aten_name}_out; }} @@ -236,6 +236,9 @@ class GenTorchMlirLTC: continue if base in supported or op in supported: continue + # Blacklist new_/_like ops since they are non-differentiable. + if any(o.startswith("new_") or o.endswith("_like") for o in (base, op)): + continue if func.has_composite_implicit_autograd_kernel: composite_implicit.add(op) diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml index ff6c7f5ee..505701eaa 100644 --- a/build_tools/autogen_ltc_backend.yaml +++ b/build_tools/autogen_ltc_backend.yaml @@ -43,7 +43,6 @@ supported: # - bernoulli # - bernoulli_ - _to_copy -- cat - clone - empty.memory_format - empty_strided diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 8856ee0af..7e0e43918 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -227,6 +227,7 @@ LTC_XFAIL_SET = { "FullLikeModuleInt3D_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", + "GeIntModule_basic", "GtFloatIntModule_basic", "GtIntModule_basic", "HBC_basic", @@ -266,6 +267,11 @@ LTC_XFAIL_SET = { "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexTensorModule3dInput_basic", "IndexTensorModule_basic", + "IndexTensorMultiInputContiguousCenter_basic", + "IndexTensorMultiInputNonContiguous_basic", + "IndexTensorMultiInputOneDim_basic", + "IndexTensorMultiInputThreeIndexers_basic", + "IndexTensorMultiInput_basic", "IndexTensorSelectDimModule_basic", "Matmul_dot", "Matmul_matvec", @@ -288,6 +294,12 @@ LTC_XFAIL_SET = { "NewOnesModuleFloat3D_basic", "NewOnesModuleInt2D_basic", "NewOnesModuleInt3D_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", "OnesLikeModule_defaultDtype", "OnesLikeModule_falsePinMemory", "OnesLikeModule_float", @@ -302,6 +314,9 @@ LTC_XFAIL_SET = { "SliceStartEqEndModule_basic", "SqrtIntModule_basic", "StdBiasedModule_basic", + "StdDimBiasedModule_basic", + "StdDimKeepDimFalseModule_basic", + "StdDimKeepDimTrueModule_basic", "StdUnbiasedModule_basic", "SubFloatModule_basic", "SubIntModule_basic", @@ -317,15 +332,5 @@ LTC_XFAIL_SET = { "UniformModule_basic", "UniformStaticModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", - "VarBiasedModule_basic", - "VarDimAllDimReduceModule_basic", - "VarDimBiasedModule_basic", - "VarDimKeepDimFalseModule_basic", - "VarDimModule_basic", - "VarDimMultiDimModule_basic", - "VarDimNegativeModule_basic", - "VarDimSingleDimModule_basic", - "VarDimUnbiasedModule_basic", - "VarUnbiasedModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", } diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 86e222beb..304f4a11c 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -13,6 +13,30 @@ set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir") # We vendor our own MLIR instance in the `torch_mlir` namespace. add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.") +################################################################################ +# PyTorch +################################################################################ + +option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON) + +if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) + # Source builds + set(ENV{PYTORCH_REPO} ${PYTORCH_REPO}) + set(ENV{PYTORCH_BRANCH} ${PYTORCH_BRANCH}) + set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET}) + set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) + set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) + set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) + execute_process( + COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/../build_tools/build_libtorch.sh + RESULT_VARIABLE _result + ) + if(_result) + message(FATAL_ERROR "Failed to run `build_libtorch.sh`") + endif() + set(TORCH_INSTALL_PREFIX "libtorch") +endif() + ################################################################################ # Sources ################################################################################ @@ -60,33 +84,17 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main # Lazy Tensor Core ################################################################################ -add_subdirectory(torch_mlir/csrc/base_lazy_backend) +if(TORCH_MLIR_ENABLE_LTC) + add_subdirectory(torch_mlir/csrc/base_lazy_backend) +endif() +# Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC. add_subdirectory(torch_mlir/csrc/reference_lazy_backend) ################################################################################ # Optionally handle JIT IR importer. ################################################################################ -option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON) - if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) - if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) - # Source builds - set(ENV{PYTORCH_REPO} ${PYTORCH_REPO}) - set(ENV{PYTORCH_BRANCH} ${PYTORCH_BRANCH}) - set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET}) - set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) - set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) - set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) - execute_process( - COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/../build_tools/build_libtorch.sh - RESULT_VARIABLE _result - ) - if(_result) - message(FATAL_ERROR "Failed to run `build_libtorch.sh`") - endif() - set(TORCH_INSTALL_PREFIX "libtorch") - endif() add_subdirectory(torch_mlir/dialects/torch/importer/jit_ir) add_subdirectory(torch_mlir_e2e_test) endif() @@ -154,8 +162,10 @@ endif() # TODO: Add after macOS builds are fixed #add_dependencies(TorchMLIRPythonModules torch_mlir_custom_op_example) -# Add Torch-MLIR LTC backend as dependency -add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend) -add_dependencies(TorchMLIRPythonModules reference_lazy_backend) +if(TORCH_MLIR_ENABLE_LTC) + # Add Torch-MLIR LTC backend as dependency + add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend) + add_dependencies(TorchMLIRPythonModules reference_lazy_backend) +endif() add_subdirectory(test) diff --git a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt index 4b0df6233..db34a8e12 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt +++ b/python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt @@ -5,10 +5,16 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules") include(TorchMLIRPyTorch) + TorchMLIRProbeForPyTorchInstall() +if(TORCH_MLIR_USE_INSTALLED_PYTORCH) + TorchMLIRConfigurePyTorch() +else() + set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libtorch/share/cmake/Torch") +endif() + find_package(Torch 1.11 REQUIRED) -TorchMLIRConfigurePyTorch() set(TORCHGEN_DIR ${Torch_ROOT}/../../../torchgen) include_directories(BEFORE @@ -76,8 +82,6 @@ target_link_libraries(torch_mlir_ltc_backend TorchMLIRAggregateCAPI TorchMLIRJITIRImporter ${TORCH_LIBRARIES} - ${Python3_LIBRARIES} - torch_python ) message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic") diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp index a2d58145a..fdef62719 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -64,7 +64,7 @@ void TorchMlirLoweringContext::Lower(const Node* node) { dynamic_cast(node)) { TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this); CHECK(!ops.empty()) << "Failed to lower: " << *node; - CHECK_EQ(node->num_outputs(), ops.size()); + TORCH_CHECK_EQ(node->num_outputs(), ops.size()); for (size_t i = 0; i < ops.size(); ++i) { AssignOutputOp(torch::lazy::Output(node, i), ops[i]); } diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp index eb793615b..e197af3e5 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -154,25 +154,6 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { // // return self; // } -at::Tensor LazyNativeFunctions::cat(at::TensorList tensors, int64_t dim) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto lazy_tensors = torch::lazy::GetLtcTensors(tensors); - std::vector values; - values.reserve(lazy_tensors.size()); - for (auto& tensor : lazy_tensors) { - values.emplace_back(tensor->GetIrValue()); - } - - auto shapes = torch::lazy::compute_shape_cat(tensors, dim); - UNIMPLEMENTED_FUNCTION_ERROR(); - // auto node = - // torch::lazy::MakeNode(values, dim, std::move(shapes)); - // auto result = torch::lazy::CreateAtenFromLtcTensor( - // torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 0), - // lazy_tensors[0]->GetDevice())); - // return result; -} - // clone is special in LT because we make it a no-op. // This should be safe to do, because every operator in the LT is functional. at::Tensor LazyNativeFunctions::clone( diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp index a18a20e78..e3d4fab86 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp @@ -205,7 +205,7 @@ GenerateClone(torch::jit::Value* val, TorchMlirFunction function) { // Type of cloned value should be identical to the original one. TorchMlirOpVector cloned = LowerBuiltin(at::aten::clone, {val->type()}, function, clone_arguments); - CHECK_EQ(cloned.size(), 1); + TORCH_CHECK_EQ(cloned.size(), 1); return cloned.front(); } @@ -235,7 +235,7 @@ torch::jit::Value* GenerateSlice( c10::ArrayRef( compute_shape_slice(base->type(), dim, start, end, step)), function, arguments); - CHECK_EQ(selected.size(), 1); + TORCH_CHECK_EQ(selected.size(), 1); return selected.front(); } diff --git a/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h b/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h index 311d97f90..c6b75baaf 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h +++ b/python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h @@ -84,7 +84,7 @@ class ToCopy : public torch::lazy::TorchMlirNode { kwarguments.emplace_back("non_blocking", non_blocking); kwarguments.emplace_back("memory_format", memory_format); torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments); - CHECK_EQ(_to_copy_out.size(), 1); + TORCH_CHECK_EQ(_to_copy_out.size(), 1); return _to_copy_out; diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp index 2ad1c962d..48004d9d3 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp @@ -29,5 +29,12 @@ compute_shape_mul(const at::Tensor& self, const at::Scalar& other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector compute_shape_var( + const at::Tensor& self, at::OptionalIntArrayRef dim, + c10::optional correction, bool keepdim) { + // Result of variance is scalar tensor. + return {Shape(self.scalar_type(), {})}; +} + } // namespace lazy } // namespace torch diff --git a/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt b/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt index 8585aaf73..7340890d9 100644 --- a/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt +++ b/python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt @@ -4,10 +4,15 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules") include(TorchMLIRPyTorch) -TorchMLIRProbeForPyTorchInstall() -find_package(Torch 1.11 REQUIRED) -TorchMLIRConfigurePyTorch() +TorchMLIRProbeForPyTorchInstall() +if(TORCH_MLIR_USE_INSTALLED_PYTORCH) + TorchMLIRConfigurePyTorch() +else() + set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libtorch/share/cmake/Torch") +endif() + +find_package(Torch 1.11 REQUIRED) ########################################################################### # Setup Python development @@ -21,39 +26,47 @@ mlir_configure_python_dev_packages() # Library definition ########################################################################### -include_directories(BEFORE - ${TORCH_INCLUDE_DIRS} - ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_BINARY_DIR} - ${Python3_INCLUDE_DIRS} - ${PYTHON_H_DIR} - ${PROJECT_SOURCE_DIR}/python - ) -link_directories("${TORCH_INSTALL_PREFIX}/lib") -link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib) -add_link_options(-Wl,-rpath,$ORIGIN/lib) +set(LIBRARY_OUTPUT_PATH "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/reference_lazy_backend") +set(OUTPUT_NAME "_REFERENCE_LAZY_BACKEND") -set(REFERENCE_LAZY_BACKEND_CSRC - backend_impl.cpp - reference_lazy_backend_pybind.cpp - ) -add_library(reference_lazy_backend SHARED ${REFERENCE_LAZY_BACKEND_CSRC}) -add_dependencies(reference_lazy_backend - torch_mlir_ltc_backend - ) -target_link_libraries(reference_lazy_backend - ${TORCH_LIBRARIES} - ${Python3_LIBRARIES} - torch_python - torch_mlir_ltc_backend - ) +if(TORCH_MLIR_ENABLE_LTC) + include_directories(BEFORE + ${TORCH_INCLUDE_DIRS} + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR} + ${Python3_INCLUDE_DIRS} + ${PYTHON_H_DIR} + ${PROJECT_SOURCE_DIR}/python + ) + link_directories("${TORCH_INSTALL_PREFIX}/lib") + link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib) + add_link_options(-Wl,-rpath,$ORIGIN/lib) -message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic") -set_target_properties(reference_lazy_backend PROPERTIES - LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/reference_lazy_backend" - OUTPUT_NAME _REFERENCE_LAZY_BACKEND - PREFIX "${PYTHON_MODULE_PREFIX}" - SUFFIX "${PYTHON_MODULE_EXTENSION}" - CXX_VISIBILITY_PRESET "hidden" - COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic" - ) + add_library(reference_lazy_backend SHARED + backend_impl.cpp + reference_lazy_backend_pybind.cpp + ) + add_dependencies(reference_lazy_backend + torch_mlir_ltc_backend + ) + target_link_libraries(reference_lazy_backend + ${TORCH_LIBRARIES} + torch_mlir_ltc_backend + ) + + message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic") + set_target_properties(reference_lazy_backend PROPERTIES + LIBRARY_OUTPUT_DIRECTORY ${LIBRARY_OUTPUT_PATH} + OUTPUT_NAME ${OUTPUT_NAME} + PREFIX "${PYTHON_MODULE_PREFIX}" + SUFFIX "${PYTHON_MODULE_EXTENSION}" + CXX_VISIBILITY_PRESET "hidden" + COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic" + ) +else() + # To avoid import errors when LTC is disabled (and a bunch of checks + # associated with that), we will generate a dummy placeholder library. + execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/gen_dummy_lib.py ${LIBRARY_OUTPUT_PATH} ${OUTPUT_NAME} + ) +endif() diff --git a/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py b/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py new file mode 100755 index 000000000..34c9e6190 --- /dev/null +++ b/python/torch_mlir/csrc/reference_lazy_backend/gen_dummy_lib.py @@ -0,0 +1,23 @@ +# When LTC is disabled in Torch-MLIR build, we will generate a dummy module to +# ensure that no import errors occur. + +import sys +import os + +if __name__ == '__main__': + path = sys.argv[1] # dummy script path + file_name = sys.argv[2] # dummy script + + contents = ''' +# This file was automatically generated due to LTC being disabled in build. + +class LazyTensorCoreTestConfig: + def __init__(self): + assert False, "LTC is not enabled. Check the value of `TORCH_MLIR_ENABLE_LTC`" + ''' + + if not os.path.exists(path): + os.makedirs(path) + + with open(os.path.join(path, file_name + '.py'), 'w') as file: + file.write(contents) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt index 3dfaf67be..043975990 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt @@ -10,7 +10,14 @@ include_directories(BEFORE ) link_directories("${TORCH_INSTALL_PREFIX}/lib") -add_library(TorchMLIRJITIRImporter SHARED +# Hack! Currently out-of-tree build fails when this is set to SHARED, so we have this toggle +if(TORCH_MLIR_ENABLE_LTC) + set(LIBRARY_TYPE "SHARED") +else() + set(LIBRARY_TYPE "MODULE") +endif() + +add_library(TorchMLIRJITIRImporter ${LIBRARY_TYPE} class_annotator.cpp class_annotator_pybind.cpp get_registered_ops.cpp