summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
authorAdam J. Stewart <ajstewart426@gmail.com>2024-08-03 11:16:42 +0200
committerGitHub <noreply@github.com>2024-08-03 11:16:42 +0200
commit705d58005d04370d1482d3a6f71389d860767902 (patch)
tree367b1dbdcdfb8255184253dc23ed10fd97f911f5 /var
parentcee266046b42644a385c8fa8d73ecfe08e5cd4ea (diff)
downloadspack-705d58005d04370d1482d3a6f71389d860767902.tar.gz
spack-705d58005d04370d1482d3a6f71389d860767902.tar.bz2
spack-705d58005d04370d1482d3a6f71389d860767902.tar.xz
spack-705d58005d04370d1482d3a6f71389d860767902.zip
py-jax / JAX: add v0.4.31 (#45519)
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-jax/package.py12
-rw-r--r--var/spack/repos/builtin/packages/py-jaxlib/package.py23
2 files changed, 25 insertions, 10 deletions
diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py
index e97c0966ba..499b906573 100644
--- a/var/spack/repos/builtin/packages/py-jax/package.py
+++ b/var/spack/repos/builtin/packages/py-jax/package.py
@@ -24,6 +24,7 @@ class PyJax(PythonPackage):
license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle")
+ version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287")
version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
version("0.4.28", sha256="dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9")
@@ -57,25 +58,27 @@ class PyJax(PythonPackage):
with default_args(type=("build", "run")):
# setup.py
+ depends_on("python@3.10:", when="@0.4.31:")
depends_on("python@3.9:", when="@0.4.14:")
- depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
+ depends_on("py-numpy@1.24:", when="@0.4.31:")
depends_on("py-numpy@1.22:", when="@0.4.14:")
depends_on("py-numpy@1.21:", when="@0.4.7:")
depends_on("py-numpy@1.20:", when="@0.3:")
# https://github.com/google/jax/issues/19246
depends_on("py-numpy@:1", when="@:0.4.25")
depends_on("py-opt-einsum")
+ depends_on("py-scipy@1.10:", when="@0.4.31:")
depends_on("py-scipy@1.9:", when="@0.4.19:")
depends_on("py-scipy@1.7:", when="@0.4.7:")
depends_on("py-scipy@1.5:", when="@0.3:")
- depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9")
# jax/_src/lib/__init__.py
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
for v in [
+ "0.4.31",
"0.4.30",
"0.4.29",
"0.4.28",
@@ -108,6 +111,7 @@ class PyJax(PythonPackage):
depends_on(f"py-jaxlib@:{v}", when=f"@{v}")
# See _minimum_jaxlib_version in jax/version.py
+ depends_on("py-jaxlib@0.4.30:", when="@0.4.31:")
depends_on("py-jaxlib@0.4.27:", when="@0.4.28:")
depends_on("py-jaxlib@0.4.23:", when="@0.4.27:")
depends_on("py-jaxlib@0.4.20:", when="@0.4.25:")
@@ -119,3 +123,7 @@ class PyJax(PythonPackage):
depends_on("py-jaxlib@0.4.4:", when="@0.4.5:")
depends_on("py-jaxlib@0.4.2:", when="@0.4.3:")
depends_on("py-jaxlib@0.4.1:", when="@0.4.2:")
+
+ # Historical dependencies
+ depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
+ depends_on("py-importlib-metadata@4.6:", when="@0.4.11:0.4.30 ^python@:3.9")
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py
index aac511a872..ff67dd4604 100644
--- a/var/spack/repos/builtin/packages/py-jaxlib/package.py
+++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py
@@ -20,6 +20,7 @@ class PyJaxlib(PythonPackage, CudaPackage):
license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle")
+ version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94")
version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
version("0.4.28", sha256="4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b")
@@ -39,18 +40,19 @@ class PyJaxlib(PythonPackage, CudaPackage):
version("0.4.4", sha256="881f402c7983b56b185e182d5315dd64c9f5320be96213d0415996ece1826806")
version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
- depends_on("c", type="build") # generated
- depends_on("cxx", type="build") # generated
+ depends_on("c", type="build")
+ depends_on("cxx", type="build")
variant("cuda", default=True, description="Build with CUDA enabled")
variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda")
- # docs/installation.md
+ # docs/installation.md (Compatible with)
with when("+cuda"):
depends_on("cuda@12.1:", when="@0.4.26:")
depends_on("cuda@11.8:", when="@0.4.11:")
depends_on("cuda@11.4:", when="@0.4.0:0.4.7")
- depends_on("cudnn@9", when="@0.4.29:")
+ depends_on("cudnn@9.1:9", when="@0.4.31:")
+ depends_on("cudnn@9", when="@0.4.29:0.4.30")
depends_on("cudnn@8.9:8", when="@0.4.26:0.4.28")
depends_on("cudnn@8.8:8", when="@0.4.11:0.4.25")
depends_on("cudnn@8.2:8", when="@0.4:0.4.7")
@@ -74,24 +76,29 @@ class PyJaxlib(PythonPackage, CudaPackage):
with default_args(type=("build", "run")):
# Based on PyPI wheels
- depends_on("python@3.9:3.12", when="@0.4.17:")
+ depends_on("python@3.10:3.12", when="@0.4.31:")
+ depends_on("python@3.9:3.12", when="@0.4.17:0.4.30")
depends_on("python@3.9:3.11", when="@0.4.14:0.4.16")
depends_on("python@3.8:3.11", when="@0.4.6:0.4.13")
# jaxlib/setup.py
+ depends_on("py-scipy@1.10:", when="@0.4.31:")
depends_on("py-scipy@1.9:", when="@0.4.19:")
depends_on("py-scipy@1.7:", when="@0.4.7:")
depends_on("py-scipy@1.5:")
+ depends_on("py-numpy@1.24:", when="@0.4.31:")
depends_on("py-numpy@1.22:", when="@0.4.14:")
depends_on("py-numpy@1.21:", when="@0.4.7:")
depends_on("py-numpy@1.20:", when="@0.3:")
- # https://github.com/google/jax/issues/19246
- depends_on("py-numpy@:1", when="@:0.4.25")
- depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
+ # Historical dependencies
+ # https://github.com/google/jax/issues/19246
+ depends_on("py-numpy@:1", when="@:0.4.25")
+ depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
+
conflicts(
"cuda_arch=none",
when="+cuda",