diff options
6 files changed, 31 insertions, 14 deletions
diff --git a/var/spack/repos/builtin/packages/cudnn/package.py b/var/spack/repos/builtin/packages/cudnn/package.py index 5b5c10c44e..38dc8c88a3 100644 --- a/var/spack/repos/builtin/packages/cudnn/package.py +++ b/var/spack/repos/builtin/packages/cudnn/package.py @@ -9,6 +9,14 @@ import platform from spack.package import * _versions = { + # cuDNN 9.2.0 + "9.2.0.82-12": { + "Linux-x86_64": "1362b4d437e37e92c9814c3b4065db5106c2e03268e22275a5869e968cee7aa8", + "Linux-aarch64": "24cc2a0308dfe412c02c7d41d4b07ec12dacb021ebf8c719de38eb77d22f68c1", + }, + "9.2.0.82-11": { + "Linux-x86_64": "99dcb3fa2bf7eed7f35b0f8e58e7d1f04d9a52e01e382efc1de16fed230d3b26" + }, # cuDNN 8.9.7 "8.9.7.29-12": { "Linux-x86_64": "475333625c7e42a7af3ca0b2f7506a106e30c93b1aa0081cd9c13efb6e21e3bb", diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py index 4b6136242d..31e5f8d1a5 100644 --- a/var/spack/repos/builtin/packages/py-jax/package.py +++ b/var/spack/repos/builtin/packages/py-jax/package.py @@ -24,6 +24,7 @@ class PyJax(PythonPackage): license("Apache-2.0") maintainers("adamjstewart", "jonas-eschle") + version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186") version("0.4.28", sha256="dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9") version("0.4.27", sha256="f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c") version("0.4.26", sha256="2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383") @@ -56,6 +57,7 @@ class PyJax(PythonPackage): with default_args(type=("build", "run")): # setup.py depends_on("python@3.9:", when="@0.4.14:") + depends_on("py-ml-dtypes@0.4:", when="@0.4.29:") depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") @@ -71,6 +73,7 @@ class PyJax(PythonPackage): # jax/_src/lib/__init__.py # https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4 for v in [ + "0.4.29", "0.4.28", "0.4.27", "0.4.26", diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index 7c19784bf9..166bbc0474 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -18,8 +18,9 @@ class PyJaxlib(PythonPackage, CudaPackage): buildtmp = "" license("Apache-2.0") - maintainers("adamjstewart") + maintainers("adamjstewart", "jonas-eschle") + version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246") version("0.4.28", sha256="4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b") version("0.4.27", sha256="c2c82cd9ad3b395d5cbc0affa26a2938e52677a69ca8f0b9ef9922a52cac4f0c") version("0.4.26", sha256="ddc14da1eaa34f23430d40ad9b9585088575cac439a2fa1c6833a247e1b221fd") @@ -46,9 +47,10 @@ class PyJaxlib(PythonPackage, CudaPackage): depends_on("cuda@12.1:", when="@0.4.26:") depends_on("cuda@11.8:", when="@0.4.11:") depends_on("cuda@11.4:", when="@0.4.0:0.4.7") - depends_on("cudnn@8.9:8", when="@0.4.26:") - depends_on("cudnn@8.8:", when="@0.4.11:") - depends_on("cudnn@8.2:", when="@0.4:0.4.7") + depends_on("cudnn@9", when="@0.4.29:") + depends_on("cudnn@8.9:8", when="@0.4.26:0.4.28") + depends_on("cudnn@8.8:8", when="@0.4.11:0.4.25") + depends_on("cudnn@8.2:8", when="@0.4:0.4.7") with when("+nccl"): depends_on("nccl@2.18:", when="@0.4.26:") @@ -80,6 +82,7 @@ class PyJaxlib(PythonPackage, CudaPackage): depends_on("py-numpy@1.22:", when="@0.4.14:") depends_on("py-numpy@1.21:", when="@0.4.7:") depends_on("py-numpy@1.20:", when="@0.3:") + depends_on("py-ml-dtypes@0.4:", when="@0.4.29:") depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") diff --git a/var/spack/repos/builtin/packages/py-ml-dtypes/package.py b/var/spack/repos/builtin/packages/py-ml-dtypes/package.py index 5e3f85ab7b..c90978d2b5 100644 --- a/var/spack/repos/builtin/packages/py-ml-dtypes/package.py +++ b/var/spack/repos/builtin/packages/py-ml-dtypes/package.py @@ -17,11 +17,13 @@ class PyMlDtypes(PythonPackage): license("Apache-2.0") + version("0.4.0", tag="v0.4.0", commit="9fc7e6773acb66fa496ed8d476a008a489a4da49") version("0.3.1", tag="v0.3.1", commit="bbeedd470ecac727c42e97648c0f27bfc312af30") version("0.2.0", tag="v0.2.0", commit="5b9fc9ad978757654843f4a8d899715dbea30e88") depends_on("python@3.9:", when="@0.3:", type=("build", "link", "run")) - depends_on("py-numpy@1.21:", type=("build", "link", "run")) + depends_on("py-numpy@1.21:", when="@0.4:", type=("build", "link", "run")) + depends_on("py-numpy@1.21:1", when="@:0.3", type=("build", "link", "run")) # Build dependencies are overconstrained, older versions work just fine - depends_on("py-pybind11", type=("build", "link")) + depends_on("py-pybind11", when="@:0.3.1", type=("build", "link")) depends_on("py-setuptools", type="build") diff --git a/var/spack/repos/builtin/packages/py-tensorflow/package.py b/var/spack/repos/builtin/packages/py-tensorflow/package.py index 627ee26710..028bf1e97b 100644 --- a/var/spack/repos/builtin/packages/py-tensorflow/package.py +++ b/var/spack/repos/builtin/packages/py-tensorflow/package.py @@ -305,12 +305,12 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension): depends_on("cuda@:11.4", when="@2.4:2.7") depends_on("cuda@:10.2", when="@:2.3") - depends_on("cudnn@8.9:", when="@2.15:") - depends_on("cudnn@8.7:", when="@2.14:") - depends_on("cudnn@8.6:", when="@2.12:") - depends_on("cudnn@8.1:", when="@2.5:") - depends_on("cudnn@8.0:", when="@2.4:") - depends_on("cudnn@7.6:", when="@2.1:") + depends_on("cudnn@8.9:8", when="@2.15:") + depends_on("cudnn@8.7:8", when="@2.14:") + depends_on("cudnn@8.6:8", when="@2.12:") + depends_on("cudnn@8.1:8", when="@2.5:") + depends_on("cudnn@8.0:8", when="@2.4:") + depends_on("cudnn@7.6:8", when="@2.1:") depends_on("cudnn@:7", when="@:2.2") # depends_on('tensorrt', when='+tensorrt') diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index 672b8b0201..e2616d2972 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -247,8 +247,9 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): depends_on("cuda@9.2:11.4", when="@1.6:1.9+cuda") depends_on("cuda@9:11.4", when="@:1.5+cuda") # https://github.com/pytorch/pytorch#prerequisites - depends_on("cudnn@8.5:", when="@2.3:+cudnn") - depends_on("cudnn@7:", when="@1.6:+cudnn") + # https://github.com/pytorch/pytorch/issues/119400 + depends_on("cudnn@8.5:9.0", when="@2.3:+cudnn") + depends_on("cudnn@7:8", when="@1.6:2.2+cudnn") depends_on("cudnn@7", when="@:1.5+cudnn") depends_on("magma+cuda", when="+magma+cuda") depends_on("magma+rocm", when="+magma+rocm") |