summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam J. Stewart <ajstewart426@gmail.com>2024-07-01 09:02:48 +0200
committerGitHub <noreply@github.com>2024-07-01 09:02:48 +0200
commitb57f88cb89b782b7c7e0dfb7f6350470115eb664 (patch)
treea7667f846eeaacf3a747275c78d66612b259854a
parent03afc2a1e6c6f832a37fae2fecfbbbd2e53b17b2 (diff)
downloadspack-b57f88cb89b782b7c7e0dfb7f6350470115eb664.tar.gz
spack-b57f88cb89b782b7c7e0dfb7f6350470115eb664.tar.bz2
spack-b57f88cb89b782b7c7e0dfb7f6350470115eb664.tar.xz
spack-b57f88cb89b782b7c7e0dfb7f6350470115eb664.zip
JAX: add v0.4.30 (#44964)
-rw-r--r--var/spack/repos/builtin/packages/py-jax/package.py4
-rw-r--r--var/spack/repos/builtin/packages/py-jaxlib/package.py4
2 files changed, 5 insertions, 3 deletions
diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py
index 31e5f8d1a5..38f9b85529 100644
--- a/var/spack/repos/builtin/packages/py-jax/package.py
+++ b/var/spack/repos/builtin/packages/py-jax/package.py
@@ -24,6 +24,7 @@ class PyJax(PythonPackage):
license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle")
+ version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
version("0.4.28", sha256="dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9")
version("0.4.27", sha256="f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c")
@@ -57,7 +58,7 @@ class PyJax(PythonPackage):
with default_args(type=("build", "run")):
# setup.py
depends_on("python@3.9:", when="@0.4.14:")
- depends_on("py-ml-dtypes@0.4:", when="@0.4.29:")
+ depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
@@ -73,6 +74,7 @@ class PyJax(PythonPackage):
# jax/_src/lib/__init__.py
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
for v in [
+ "0.4.30",
"0.4.29",
"0.4.28",
"0.4.27",
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py
index 166bbc0474..04e0d50819 100644
--- a/var/spack/repos/builtin/packages/py-jaxlib/package.py
+++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py
@@ -20,6 +20,7 @@ class PyJaxlib(PythonPackage, CudaPackage):
license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle")
+ version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
version("0.4.28", sha256="4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b")
version("0.4.27", sha256="c2c82cd9ad3b395d5cbc0affa26a2938e52677a69ca8f0b9ef9922a52cac4f0c")
@@ -42,7 +43,6 @@ class PyJaxlib(PythonPackage, CudaPackage):
variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda")
# docs/installation.md
- # jaxlib/setup.py
with when("+cuda"):
depends_on("cuda@12.1:", when="@0.4.26:")
depends_on("cuda@11.8:", when="@0.4.11:")
@@ -82,7 +82,7 @@ class PyJaxlib(PythonPackage, CudaPackage):
depends_on("py-numpy@1.22:", when="@0.4.14:")
depends_on("py-numpy@1.21:", when="@0.4.7:")
depends_on("py-numpy@1.20:", when="@0.3:")
- depends_on("py-ml-dtypes@0.4:", when="@0.4.29:")
+ depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")