diff options
author | Adam J. Stewart <ajstewart426@gmail.com> | 2024-08-03 11:16:42 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-03 11:16:42 +0200 |
commit | 705d58005d04370d1482d3a6f71389d860767902 (patch) | |
tree | 367b1dbdcdfb8255184253dc23ed10fd97f911f5 /var | |
parent | cee266046b42644a385c8fa8d73ecfe08e5cd4ea (diff) | |
download | spack-705d58005d04370d1482d3a6f71389d860767902.tar.gz spack-705d58005d04370d1482d3a6f71389d860767902.tar.bz2 spack-705d58005d04370d1482d3a6f71389d860767902.tar.xz spack-705d58005d04370d1482d3a6f71389d860767902.zip |
py-jax / JAX: add v0.4.31 (#45519)
Diffstat (limited to 'var')
-rw-r--r-- | var/spack/repos/builtin/packages/py-jax/package.py | 12 | ||||
-rw-r--r-- | var/spack/repos/builtin/packages/py-jaxlib/package.py | 23 |
2 files changed, 25 insertions, 10 deletions
diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py index e97c0966ba..499b906573 100644 --- a/var/spack/repos/builtin/packages/py-jax/package.py +++ b/var/spack/repos/builtin/packages/py-jax/package.py @@ -24,6 +24,7 @@ class PyJax(PythonPackage): license("Apache-2.0") maintainers("adamjstewart", "jonas-eschle") + version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287") version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577") version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186") version("0.4.28", sha256="dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9") @@ -57,25 +58,27 @@ class PyJax(PythonPackage): with default_args(type=("build", "run")): # setup.py + depends_on("python@3.10:", when="@0.4.31:") depends_on("python@3.9:", when="@0.4.14:") - depends_on("py-ml-dtypes@0.4:", when="@0.4.29") depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") + depends_on("py-numpy@1.24:", when="@0.4.31:") depends_on("py-numpy@1.22:", when="@0.4.14:") depends_on("py-numpy@1.21:", when="@0.4.7:") depends_on("py-numpy@1.20:", when="@0.3:") # https://github.com/google/jax/issues/19246 depends_on("py-numpy@:1", when="@:0.4.25") depends_on("py-opt-einsum") + depends_on("py-scipy@1.10:", when="@0.4.31:") depends_on("py-scipy@1.9:", when="@0.4.19:") depends_on("py-scipy@1.7:", when="@0.4.7:") depends_on("py-scipy@1.5:", when="@0.3:") - depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9") # jax/_src/lib/__init__.py # https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4 for v in [ + "0.4.31", "0.4.30", "0.4.29", "0.4.28", @@ -108,6 +111,7 @@ class PyJax(PythonPackage): depends_on(f"py-jaxlib@:{v}", when=f"@{v}") # See _minimum_jaxlib_version in jax/version.py + depends_on("py-jaxlib@0.4.30:", when="@0.4.31:") depends_on("py-jaxlib@0.4.27:", when="@0.4.28:") depends_on("py-jaxlib@0.4.23:", when="@0.4.27:") depends_on("py-jaxlib@0.4.20:", when="@0.4.25:") @@ -119,3 +123,7 @@ class PyJax(PythonPackage): depends_on("py-jaxlib@0.4.4:", when="@0.4.5:") depends_on("py-jaxlib@0.4.2:", when="@0.4.3:") depends_on("py-jaxlib@0.4.1:", when="@0.4.2:") + + # Historical dependencies + depends_on("py-ml-dtypes@0.4:", when="@0.4.29") + depends_on("py-importlib-metadata@4.6:", when="@0.4.11:0.4.30 ^python@:3.9") diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index aac511a872..ff67dd4604 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -20,6 +20,7 @@ class PyJaxlib(PythonPackage, CudaPackage): license("Apache-2.0") maintainers("adamjstewart", "jonas-eschle") + version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94") version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6") version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246") version("0.4.28", sha256="4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b") @@ -39,18 +40,19 @@ class PyJaxlib(PythonPackage, CudaPackage): version("0.4.4", sha256="881f402c7983b56b185e182d5315dd64c9f5320be96213d0415996ece1826806") version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910") - depends_on("c", type="build") # generated - depends_on("cxx", type="build") # generated + depends_on("c", type="build") + depends_on("cxx", type="build") variant("cuda", default=True, description="Build with CUDA enabled") variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda") - # docs/installation.md + # docs/installation.md (Compatible with) with when("+cuda"): depends_on("cuda@12.1:", when="@0.4.26:") depends_on("cuda@11.8:", when="@0.4.11:") depends_on("cuda@11.4:", when="@0.4.0:0.4.7") - depends_on("cudnn@9", when="@0.4.29:") + depends_on("cudnn@9.1:9", when="@0.4.31:") + depends_on("cudnn@9", when="@0.4.29:0.4.30") depends_on("cudnn@8.9:8", when="@0.4.26:0.4.28") depends_on("cudnn@8.8:8", when="@0.4.11:0.4.25") depends_on("cudnn@8.2:8", when="@0.4:0.4.7") @@ -74,24 +76,29 @@ class PyJaxlib(PythonPackage, CudaPackage): with default_args(type=("build", "run")): # Based on PyPI wheels - depends_on("python@3.9:3.12", when="@0.4.17:") + depends_on("python@3.10:3.12", when="@0.4.31:") + depends_on("python@3.9:3.12", when="@0.4.17:0.4.30") depends_on("python@3.9:3.11", when="@0.4.14:0.4.16") depends_on("python@3.8:3.11", when="@0.4.6:0.4.13") # jaxlib/setup.py + depends_on("py-scipy@1.10:", when="@0.4.31:") depends_on("py-scipy@1.9:", when="@0.4.19:") depends_on("py-scipy@1.7:", when="@0.4.7:") depends_on("py-scipy@1.5:") + depends_on("py-numpy@1.24:", when="@0.4.31:") depends_on("py-numpy@1.22:", when="@0.4.14:") depends_on("py-numpy@1.21:", when="@0.4.7:") depends_on("py-numpy@1.20:", when="@0.3:") - # https://github.com/google/jax/issues/19246 - depends_on("py-numpy@:1", when="@:0.4.25") - depends_on("py-ml-dtypes@0.4:", when="@0.4.29") depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") + # Historical dependencies + # https://github.com/google/jax/issues/19246 + depends_on("py-numpy@:1", when="@:0.4.25") + depends_on("py-ml-dtypes@0.4:", when="@0.4.29") + conflicts( "cuda_arch=none", when="+cuda", |