diff options
author | Adam J. Stewart <ajstewart426@gmail.com> | 2024-05-08 20:36:24 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-08 11:36:24 -0700 |
commit | 314893982e5e4b3f11da994b77fbb610afa444ec (patch) | |
tree | fa4138382d3b742b55712a6eb4feb65bdee78ead | |
parent | 9ab6c30a3dfac2a5d5c3f1f72a420f061e9f3f0f (diff) | |
download | spack-314893982e5e4b3f11da994b77fbb610afa444ec.tar.gz spack-314893982e5e4b3f11da994b77fbb610afa444ec.tar.bz2 spack-314893982e5e4b3f11da994b77fbb610afa444ec.tar.xz spack-314893982e5e4b3f11da994b77fbb610afa444ec.zip |
JAX: add v0.4.27, NCCL variant (#44071)
-rw-r--r-- | var/spack/repos/builtin/packages/py-jax/package.py | 142 | ||||
-rw-r--r-- | var/spack/repos/builtin/packages/py-jaxlib/package.py | 102 |
2 files changed, 131 insertions, 113 deletions
diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py index 2be3c700f0..2394c0b9ae 100644 --- a/var/spack/repos/builtin/packages/py-jax/package.py +++ b/var/spack/repos/builtin/packages/py-jax/package.py @@ -19,11 +19,12 @@ class PyJax(PythonPackage): arbitrarily to any order.""" homepage = "https://github.com/google/jax" - pypi = "jax/jax-0.2.25.tar.gz" + pypi = "jax/jax-0.4.27.tar.gz" license("Apache-2.0") maintainers("adamjstewart", "jonas-eschle") + version("0.4.27", sha256="f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c") version("0.4.26", sha256="2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383") version("0.4.25", sha256="a8ee189c782de2b7b2ffb64a8916da380b882a617e2769aa429b71d79747b982") version("0.4.24", sha256="4a6b6fd026ddd22653c7fa2fac1904c3de2dbe845b61ede08af9a5cc709662ae") @@ -59,74 +60,79 @@ class PyJax(PythonPackage): deprecated=True, ) - depends_on("python@3.9:", when="@0.4.14:", type=("build", "run")) - depends_on("python@3.8:", when="@0.4:", type=("build", "run")) depends_on("py-setuptools", type="build") - depends_on("py-ml-dtypes@0.2:", when="@0.4.14:", type=("build", "run")) - depends_on("py-ml-dtypes@0.1:", when="@0.4.9:", type=("build", "run")) - depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run")) - depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run")) - depends_on("py-numpy@1.21:", when="@0.4.7:", type=("build", "run")) - depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run")) - depends_on("py-numpy@1.18:", type=("build", "run")) - depends_on("py-opt-einsum", type=("build", "run")) - depends_on("py-scipy@1.9:", when="@0.4.19:", type=("build", "run")) - depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run")) - depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run")) - depends_on("py-scipy@1.2.1:", type=("build", "run")) - depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9", type=("build", "run")) - # See jax/_src/lib/__init__.py - # https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4 - for v in [ - "0.4.26", - "0.4.25", - "0.4.24", - "0.4.23", - "0.4.22", - "0.4.21", - "0.4.20", - "0.4.19", - "0.4.18", - "0.4.17", - "0.4.16", - "0.4.15", - "0.4.14", - "0.4.13", - "0.4.12", - "0.4.11", - "0.4.10", - "0.4.9", - "0.4.8", - "0.4.7", - "0.4.6", - "0.4.5", - "0.4.4", - "0.4.3", - "0.3.23", - ]: - depends_on(f"py-jaxlib@:{v}", when=f"@{v}", type=("build", "run")) + with default_args(type=("build", "run")): + # setup.py + depends_on("python@3.9:", when="@0.4.14:") + depends_on("python@3.8:", when="@0.4:") + depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") + depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") + depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") + depends_on("py-numpy@1.22:", when="@0.4.14:") + depends_on("py-numpy@1.21:", when="@0.4.7:") + depends_on("py-numpy@1.20:", when="@0.3:") + depends_on("py-numpy@1.18:") + depends_on("py-opt-einsum") + depends_on("py-scipy@1.9:", when="@0.4.19:") + depends_on("py-scipy@1.7:", when="@0.4.7:") + depends_on("py-scipy@1.5:", when="@0.3:") + depends_on("py-scipy@1.2.1:") + depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9") + + # jax/_src/lib/__init__.py + # https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4 + for v in [ + "0.4.27", + "0.4.26", + "0.4.25", + "0.4.24", + "0.4.23", + "0.4.22", + "0.4.21", + "0.4.20", + "0.4.19", + "0.4.18", + "0.4.17", + "0.4.16", + "0.4.15", + "0.4.14", + "0.4.13", + "0.4.12", + "0.4.11", + "0.4.10", + "0.4.9", + "0.4.8", + "0.4.7", + "0.4.6", + "0.4.5", + "0.4.4", + "0.4.3", + "0.3.23", + ]: + depends_on(f"py-jaxlib@:{v}", when=f"@{v}") - # See _minimum_jaxlib_version in jax/version.py - depends_on("py-jaxlib@0.4.20:", when="@0.4.25:", type=("build", "run")) - depends_on("py-jaxlib@0.4.19:", when="@0.4.21:", type=("build", "run")) - depends_on("py-jaxlib@0.4.14:", when="@0.4.15:", type=("build", "run")) - depends_on("py-jaxlib@0.4.11:", when="@0.4.12:", type=("build", "run")) - depends_on("py-jaxlib@0.4.7:", when="@0.4.8:", type=("build", "run")) - depends_on("py-jaxlib@0.4.6:", when="@0.4.7:", type=("build", "run")) - depends_on("py-jaxlib@0.4.4:", when="@0.4.5:", type=("build", "run")) - depends_on("py-jaxlib@0.4.2:", when="@0.4.3:", type=("build", "run")) - depends_on("py-jaxlib@0.4.1:", when="@0.4.2:", type=("build", "run")) - depends_on("py-jaxlib@0.3.22:", when="@0.3.24:", type=("build", "run")) - depends_on("py-jaxlib@0.3.15:", when="@0.3.18:", type=("build", "run")) - depends_on("py-jaxlib@0.3.14:", when="@0.3.15:", type=("build", "run")) - depends_on("py-jaxlib@0.3.7:", when="@0.3.8:", type=("build", "run")) - depends_on("py-jaxlib@0.3.2:", when="@0.3.7:", type=("build", "run")) - depends_on("py-jaxlib@0.3.0:", when="@0.3.2:", type=("build", "run")) - depends_on("py-jaxlib@0.1.74:", when="@0.2.26:", type=("build", "run")) - depends_on("py-jaxlib@0.1.69:", when="@0.2.18:", type=("build", "run")) + # See _minimum_jaxlib_version in jax/version.py + depends_on("py-jaxlib@0.4.23:", when="@0.4.27:") + depends_on("py-jaxlib@0.4.20:", when="@0.4.25:") + depends_on("py-jaxlib@0.4.19:", when="@0.4.21:") + depends_on("py-jaxlib@0.4.14:", when="@0.4.15:") + depends_on("py-jaxlib@0.4.11:", when="@0.4.12:") + depends_on("py-jaxlib@0.4.7:", when="@0.4.8:") + depends_on("py-jaxlib@0.4.6:", when="@0.4.7:") + depends_on("py-jaxlib@0.4.4:", when="@0.4.5:") + depends_on("py-jaxlib@0.4.2:", when="@0.4.3:") + depends_on("py-jaxlib@0.4.1:", when="@0.4.2:") + depends_on("py-jaxlib@0.3.22:", when="@0.3.24:") + depends_on("py-jaxlib@0.3.15:", when="@0.3.18:") + depends_on("py-jaxlib@0.3.14:", when="@0.3.15:") + depends_on("py-jaxlib@0.3.7:", when="@0.3.8:") + depends_on("py-jaxlib@0.3.2:", when="@0.3.7:") + depends_on("py-jaxlib@0.3.0:", when="@0.3.2:") + depends_on("py-jaxlib@0.1.74:", when="@0.2.26:") + depends_on("py-jaxlib@0.1.69:", when="@0.2.18:") - # Historical dependencies - depends_on("py-absl-py", when="@:0.3", type=("build", "run")) - depends_on("py-typing-extensions", when="@:0.3", type=("build", "run")) - depends_on("py-etils+epath", when="@0.3", type=("build", "run")) + # Historical dependencies + depends_on("py-absl-py", when="@:0.3") + depends_on("py-typing-extensions", when="@:0.3") + depends_on("py-etils+epath", when="@0.3") diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index 06864e48c8..fcca93ad04 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -12,7 +12,7 @@ class PyJaxlib(PythonPackage, CudaPackage): """XLA library for Jax""" homepage = "https://github.com/google/jax" - url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz" + url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.27.tar.gz" tmp_path = "" buildtmp = "" @@ -20,6 +20,7 @@ class PyJaxlib(PythonPackage, CudaPackage): license("Apache-2.0") maintainers("adamjstewart") + version("0.4.27", sha256="c2c82cd9ad3b395d5cbc0affa26a2938e52677a69ca8f0b9ef9922a52cac4f0c") version("0.4.26", sha256="ddc14da1eaa34f23430d40ad9b9585088575cac439a2fa1c6833a247e1b221fd") version("0.4.25", sha256="fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8") version("0.4.24", sha256="c4e6963c2c36f634a9a1765e476a1ed4e6c4a7954465ebf72e29f344c28ddc28") @@ -45,52 +46,63 @@ class PyJaxlib(PythonPackage, CudaPackage): deprecated=True, ) - variant("cuda", default=True, description="Build with CUDA") - - # build/build.py - depends_on("py-build", when="@0.4.14:", type="build") - - # Based on PyPI wheels - depends_on("python@3.9:3.12", when="@0.4.17:", type=("build", "run")) - depends_on("python@3.9:3.11", when="@0.4.14:0.4.16", type=("build", "run")) - depends_on("python@3.8:3.11", when="@0.4.6:0.4.13", type=("build", "run")) - - # jaxlib/setup.py - depends_on("py-setuptools", type="build") - depends_on("py-scipy@1.9:", when="@0.4.19:", type=("build", "run")) - depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run")) - depends_on("py-scipy@1.5:", type=("build", "run")) - depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run")) - depends_on("py-numpy@1.21:", when="@0.4.7:", type=("build", "run")) - depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run")) - depends_on("py-numpy@1.18:", type=("build", "run")) - depends_on("py-ml-dtypes@0.2:", when="@0.4.14:", type=("build", "run")) - depends_on("py-ml-dtypes@0.1:", when="@0.4.9:", type=("build", "run")) - depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run")) - - # .bazelversion - depends_on("bazel@6.1.2", when="@0.4.11:", type="build") - depends_on("bazel@5.1.1", when="@0.3.7:0.4.10", type="build") - depends_on("bazel@5.1.0", when="@0.3.5", type="build") - depends_on("bazel@5.0.0", when="@0.3.0:0.3.2", type="build") - depends_on("bazel@4.2.1", when="@0.1.75:0.1.76", type="build") - depends_on("bazel@4.1.0", when="@0.1.70:0.1.74", type="build") + variant("cuda", default=True, description="Build with CUDA enabled") + variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda") + # docs/installation.md # jaxlib/setup.py - depends_on("cuda@12.1.105:", when="@0.4.26:+cuda") - depends_on("cuda@11.8:", when="@0.4.11:+cuda") - depends_on("cuda@11.4:", when="@0.4.0:0.4.7+cuda") - depends_on("cuda@11.1:", when="@0.3+cuda") - # https://github.com/google/jax/issues/12614 - depends_on("cuda@11.1:11.7.0", when="@0.1+cuda") - - depends_on("cudnn@8.8:", when="@0.4.11:+cuda") - depends_on("cudnn@8.2:", when="@0.4:0.4.7+cuda") - depends_on("cudnn@8.0.5:", when="+cuda") - - # Historical dependencies - depends_on("py-absl-py", when="@:0.3", type=("build", "run")) - depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run")) + with when("+cuda"): + depends_on("cuda@12.1:", when="@0.4.26:") + depends_on("cuda@11.8:", when="@0.4.11:") + depends_on("cuda@11.4:", when="@0.4.0:0.4.7") + depends_on("cuda@11.1:", when="@0.3") + depends_on("cuda@11.1:11.7.0", when="@0.1") + depends_on("cudnn@8.9:8", when="@0.4.26:") + depends_on("cudnn@8.8:", when="@0.4.11:") + depends_on("cudnn@8.2:", when="@0.4:0.4.7") + depends_on("cudnn@8.0.5:") + + with when("+nccl"): + depends_on("nccl@2.18:", when="@0.4.26:") + depends_on("nccl@2.16:", when="@0.4.18:") + depends_on("nccl") + + with default_args(type="build"): + # .bazelversion + depends_on("bazel@6.1.2", when="@0.4.11:") + depends_on("bazel@5.1.1", when="@0.3.7:0.4.10") + depends_on("bazel@5.1.0", when="@0.3.5") + depends_on("bazel@5.0.0", when="@0.3.0:0.3.2") + depends_on("bazel@4.2.1", when="@0.1.75:0.1.76") + depends_on("bazel@4.1.0", when="@0.1.70:0.1.74") + + # jaxlib/setup.py + depends_on("py-setuptools") + + # build/build.py + depends_on("py-build", when="@0.4.14:") + + with default_args(type=("build", "run")): + # Based on PyPI wheels + depends_on("python@3.9:3.12", when="@0.4.17:") + depends_on("python@3.9:3.11", when="@0.4.14:0.4.16") + depends_on("python@3.8:3.11", when="@0.4.6:0.4.13") + + # jaxlib/setup.py + depends_on("py-scipy@1.9:", when="@0.4.19:") + depends_on("py-scipy@1.7:", when="@0.4.7:") + depends_on("py-scipy@1.5:") + depends_on("py-numpy@1.22:", when="@0.4.14:") + depends_on("py-numpy@1.21:", when="@0.4.7:") + depends_on("py-numpy@1.20:", when="@0.3:") + depends_on("py-numpy@1.18:") + depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") + depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") + depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") + + # Historical dependencies + depends_on("py-absl-py", when="@:0.3") + depends_on("py-flatbuffers@1.12:2", when="@0.1") conflicts( "cuda_arch=none", |