diff --git a/build_tools/build_python_wheels.sh b/build_tools/build_python_wheels.sh index 3e2ff4ca9..c766323e6 100755 --- a/build_tools/build_python_wheels.sh +++ b/build_tools/build_python_wheels.sh @@ -23,12 +23,13 @@ echo "---- CREATING VENV ----" python -m venv "$package_test_venv" VENV_PYTHON="$package_test_venv/bin/python" -echo "---- INSTALLING torch ----" -$VENV_PYTHON -m pip install -r "${repo_root}/requirements.txt" +# Install the Torch-MLIR package. +# Note that we also need to pass in the `-r requirements.txt` here to pick up +# the right --find-links flag for the nightly PyTorch wheel registry. +echo "---- INSTALLING torch-mlir and dependencies ----" +$VENV_PYTHON -m pip install -f "$wheelhouse" --force-reinstall torch_mlir -r "${repo_root}/requirements.txt" echo "---- INSTALLING other deps for smoke test ----" $VENV_PYTHON -m pip install requests pillow -echo "---- INSTALLING torch-mlir ----" -$VENV_PYTHON -m pip install -f "$wheelhouse" --force-reinstall torch_mlir echo "---- RUNNING SMOKE TEST ----" $VENV_PYTHON "$repo_root/examples/torchscript_resnet18.py" diff --git a/setup.py b/setup.py index be26f714b..ab34ead8c 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,8 @@ from setuptools import setup, Extension from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py +import torch + PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1" # Build phase discovery is unreliable. Just tell it what phases to run. @@ -127,5 +129,10 @@ setup( ext_modules=[ CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"), ], + install_requires=[ + # To avoid issues with drift for each nightly build, we pin to the + # exact version we built against. + f"torch=={torch.__version__}", + ], zip_safe=False, )