summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--var/spack/repos/builtin/packages/cudnn/package.py8
-rw-r--r--var/spack/repos/builtin/packages/py-jax/package.py3
-rw-r--r--var/spack/repos/builtin/packages/py-jaxlib/package.py11
-rw-r--r--var/spack/repos/builtin/packages/py-ml-dtypes/package.py6
-rw-r--r--var/spack/repos/builtin/packages/py-tensorflow/package.py12
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py5
6 files changed, 31 insertions, 14 deletions
diff --git a/var/spack/repos/builtin/packages/cudnn/package.py b/var/spack/repos/builtin/packages/cudnn/package.py
index 5b5c10c44e..38dc8c88a3 100644
--- a/var/spack/repos/builtin/packages/cudnn/package.py
+++ b/var/spack/repos/builtin/packages/cudnn/package.py
@@ -9,6 +9,14 @@ import platform
from spack.package import *
_versions = {
+ # cuDNN 9.2.0
+ "9.2.0.82-12": {
+ "Linux-x86_64": "1362b4d437e37e92c9814c3b4065db5106c2e03268e22275a5869e968cee7aa8",
+ "Linux-aarch64": "24cc2a0308dfe412c02c7d41d4b07ec12dacb021ebf8c719de38eb77d22f68c1",
+ },
+ "9.2.0.82-11": {
+ "Linux-x86_64": "99dcb3fa2bf7eed7f35b0f8e58e7d1f04d9a52e01e382efc1de16fed230d3b26"
+ },
# cuDNN 8.9.7
"8.9.7.29-12": {
"Linux-x86_64": "475333625c7e42a7af3ca0b2f7506a106e30c93b1aa0081cd9c13efb6e21e3bb",
diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py
index 4b6136242d..31e5f8d1a5 100644
--- a/var/spack/repos/builtin/packages/py-jax/package.py
+++ b/var/spack/repos/builtin/packages/py-jax/package.py
@@ -24,6 +24,7 @@ class PyJax(PythonPackage):
license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle")
+ version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
version("0.4.28", sha256="dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9")
version("0.4.27", sha256="f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c")
version("0.4.26", sha256="2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383")
@@ -56,6 +57,7 @@ class PyJax(PythonPackage):
with default_args(type=("build", "run")):
# setup.py
depends_on("python@3.9:", when="@0.4.14:")
+ depends_on("py-ml-dtypes@0.4:", when="@0.4.29:")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
@@ -71,6 +73,7 @@ class PyJax(PythonPackage):
# jax/_src/lib/__init__.py
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
for v in [
+ "0.4.29",
"0.4.28",
"0.4.27",
"0.4.26",
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py
index 7c19784bf9..166bbc0474 100644
--- a/var/spack/repos/builtin/packages/py-jaxlib/package.py
+++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py
@@ -18,8 +18,9 @@ class PyJaxlib(PythonPackage, CudaPackage):
buildtmp = ""
license("Apache-2.0")
- maintainers("adamjstewart")
+ maintainers("adamjstewart", "jonas-eschle")
+ version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
version("0.4.28", sha256="4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b")
version("0.4.27", sha256="c2c82cd9ad3b395d5cbc0affa26a2938e52677a69ca8f0b9ef9922a52cac4f0c")
version("0.4.26", sha256="ddc14da1eaa34f23430d40ad9b9585088575cac439a2fa1c6833a247e1b221fd")
@@ -46,9 +47,10 @@ class PyJaxlib(PythonPackage, CudaPackage):
depends_on("cuda@12.1:", when="@0.4.26:")
depends_on("cuda@11.8:", when="@0.4.11:")
depends_on("cuda@11.4:", when="@0.4.0:0.4.7")
- depends_on("cudnn@8.9:8", when="@0.4.26:")
- depends_on("cudnn@8.8:", when="@0.4.11:")
- depends_on("cudnn@8.2:", when="@0.4:0.4.7")
+ depends_on("cudnn@9", when="@0.4.29:")
+ depends_on("cudnn@8.9:8", when="@0.4.26:0.4.28")
+ depends_on("cudnn@8.8:8", when="@0.4.11:0.4.25")
+ depends_on("cudnn@8.2:8", when="@0.4:0.4.7")
with when("+nccl"):
depends_on("nccl@2.18:", when="@0.4.26:")
@@ -80,6 +82,7 @@ class PyJaxlib(PythonPackage, CudaPackage):
depends_on("py-numpy@1.22:", when="@0.4.14:")
depends_on("py-numpy@1.21:", when="@0.4.7:")
depends_on("py-numpy@1.20:", when="@0.3:")
+ depends_on("py-ml-dtypes@0.4:", when="@0.4.29:")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
diff --git a/var/spack/repos/builtin/packages/py-ml-dtypes/package.py b/var/spack/repos/builtin/packages/py-ml-dtypes/package.py
index 5e3f85ab7b..c90978d2b5 100644
--- a/var/spack/repos/builtin/packages/py-ml-dtypes/package.py
+++ b/var/spack/repos/builtin/packages/py-ml-dtypes/package.py
@@ -17,11 +17,13 @@ class PyMlDtypes(PythonPackage):
license("Apache-2.0")
+ version("0.4.0", tag="v0.4.0", commit="9fc7e6773acb66fa496ed8d476a008a489a4da49")
version("0.3.1", tag="v0.3.1", commit="bbeedd470ecac727c42e97648c0f27bfc312af30")
version("0.2.0", tag="v0.2.0", commit="5b9fc9ad978757654843f4a8d899715dbea30e88")
depends_on("python@3.9:", when="@0.3:", type=("build", "link", "run"))
- depends_on("py-numpy@1.21:", type=("build", "link", "run"))
+ depends_on("py-numpy@1.21:", when="@0.4:", type=("build", "link", "run"))
+ depends_on("py-numpy@1.21:1", when="@:0.3", type=("build", "link", "run"))
# Build dependencies are overconstrained, older versions work just fine
- depends_on("py-pybind11", type=("build", "link"))
+ depends_on("py-pybind11", when="@:0.3.1", type=("build", "link"))
depends_on("py-setuptools", type="build")
diff --git a/var/spack/repos/builtin/packages/py-tensorflow/package.py b/var/spack/repos/builtin/packages/py-tensorflow/package.py
index 627ee26710..028bf1e97b 100644
--- a/var/spack/repos/builtin/packages/py-tensorflow/package.py
+++ b/var/spack/repos/builtin/packages/py-tensorflow/package.py
@@ -305,12 +305,12 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
depends_on("cuda@:11.4", when="@2.4:2.7")
depends_on("cuda@:10.2", when="@:2.3")
- depends_on("cudnn@8.9:", when="@2.15:")
- depends_on("cudnn@8.7:", when="@2.14:")
- depends_on("cudnn@8.6:", when="@2.12:")
- depends_on("cudnn@8.1:", when="@2.5:")
- depends_on("cudnn@8.0:", when="@2.4:")
- depends_on("cudnn@7.6:", when="@2.1:")
+ depends_on("cudnn@8.9:8", when="@2.15:")
+ depends_on("cudnn@8.7:8", when="@2.14:")
+ depends_on("cudnn@8.6:8", when="@2.12:")
+ depends_on("cudnn@8.1:8", when="@2.5:")
+ depends_on("cudnn@8.0:8", when="@2.4:")
+ depends_on("cudnn@7.6:8", when="@2.1:")
depends_on("cudnn@:7", when="@:2.2")
# depends_on('tensorrt', when='+tensorrt')
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index 672b8b0201..e2616d2972 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -247,8 +247,9 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
depends_on("cuda@9.2:11.4", when="@1.6:1.9+cuda")
depends_on("cuda@9:11.4", when="@:1.5+cuda")
# https://github.com/pytorch/pytorch#prerequisites
- depends_on("cudnn@8.5:", when="@2.3:+cudnn")
- depends_on("cudnn@7:", when="@1.6:+cudnn")
+ # https://github.com/pytorch/pytorch/issues/119400
+ depends_on("cudnn@8.5:9.0", when="@2.3:+cudnn")
+ depends_on("cudnn@7:8", when="@1.6:2.2+cudnn")
depends_on("cudnn@7", when="@:1.5+cudnn")
depends_on("magma+cuda", when="+magma+cuda")
depends_on("magma+rocm", when="+magma+rocm")