mirror of https://github.com/llvm/torch-mlir
use fixed torch and torchvision
parent
37e89828a1
commit
3507121764
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -x
|
||||
set -e
|
||||
|
||||
cmake -GNinja -Bbuild \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_C_COMPILER=clang \
|
||||
-DCMAKE_CXX_COMPILER=clang++ \
|
||||
-DCMAKE_LINKER=lld \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DLLVM_ENABLE_PROJECTS=mlir \
|
||||
-DLLVM_EXTERNAL_PROJECTS="torch-mlir" \
|
||||
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$(pwd)" \
|
||||
-DLLVM_TARGETS_TO_BUILD=host \
|
||||
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
|
||||
-DTORCH_MLIR_USE_INSTALLED_PYTORCH=ON \
|
||||
-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON \
|
||||
-DTORCH_MLIR_ENABLE_JIT_IR_IMPORTER=ON \
|
||||
-DTORCH_MLIR_ENABLE_LTC=OFF \
|
||||
$(pwd)/externals/llvm-project/llvm
|
||||
|
||||
cmake --build build --target TorchMLIRPythonModules TorchMLIRJITIRImporterPybind check-torch-mlir-pt1 check-torch-mlir
|
||||
|
|
@ -41,7 +41,7 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
|
|||
ext_module="${TORCH_MLIR_EXT_MODULES} "
|
||||
fi
|
||||
|
||||
PYTHONPATH="${pypath}" python \
|
||||
PYTHONPATH="${pypath}" python3 \
|
||||
-m torch_mlir.jit_ir_importer.build_tools.abstract_interp_lib_gen \
|
||||
--pytorch_op_extensions=${ext_module:-""} \
|
||||
--torch_transforms_cpp_dir="${torch_transforms_cpp_dir}"
|
||||
|
|
|
@ -42,7 +42,7 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
|
|||
fi
|
||||
|
||||
set +u
|
||||
PYTHONPATH="${PYTHONPATH}:${pypath}" python \
|
||||
PYTHONPATH="${pypath}" python3 \
|
||||
-m torch_mlir.jit_ir_importer.build_tools.torch_ods_gen \
|
||||
--torch_ir_include_dir="${torch_ir_include_dir}" \
|
||||
--pytorch_op_extensions="${ext_module}" \
|
||||
|
|
|
@ -11480,7 +11480,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" } else {\n"
|
||||
" %11 = torch.aten.eq.int %2#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %12:2 = torch.prim.If %11 -> (!torch.bool, !torch.tuple<int, int>) {\n"
|
||||
" %13 = torch.prim.TupleConstruct %1#1, %int6 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" %13 = torch.prim.TupleConstruct %1#1, %int15 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" torch.prim.If.yield %true, %13 : !torch.bool, !torch.tuple<int, int>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.tuple<int, int>\n"
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
--pre
|
||||
torch==2.4.0+cpu
|
||||
torchvision==0.19.0+cpu
|
|
@ -0,0 +1,4 @@
|
|||
-f https://download.pytorch.org/whl/cpu
|
||||
--pre
|
||||
torch==2.4.0
|
||||
torchvision==0.19.0
|
|
@ -2710,7 +2710,7 @@ def aten〇_weight_norm_interface〡dtype(v_rank_dtype: Tuple[int, int], g_rank_
|
|||
elif g_dtype == torch.complex64:
|
||||
return v_dtype, torch.float32
|
||||
elif g_dtype == torch.bfloat16:
|
||||
return v_dtype, torch.float32
|
||||
return v_dtype, torch.bfloat16
|
||||
return v_dtype, g_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
|
|
Loading…
Reference in New Issue