summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
authorJonas Eschle <mayou36@jonas.eschle.com>2024-04-07 06:04:23 -0400
committerGitHub <noreply@github.com>2024-04-07 12:04:23 +0200
commit93e6f5fa4e47c6cc9183eae1b531c261a6664598 (patch)
tree69b92843516668d0cb8bbb0d72af932e555125bd /var
parent54acda3f1182b20994bc6b4592d91995c863d41f (diff)
downloadspack-93e6f5fa4e47c6cc9183eae1b531c261a6664598.tar.gz
spack-93e6f5fa4e47c6cc9183eae1b531c261a6664598.tar.bz2
spack-93e6f5fa4e47c6cc9183eae1b531c261a6664598.tar.xz
spack-93e6f5fa4e47c6cc9183eae1b531c261a6664598.zip
Update jax & jaxlib versions (#42863)
* upgrade new versions * style fix * update jaxlib deps (not cuda and bazel yet) * update jaxlib cuda versions * update jaxlib cuda versions * update jaxlib cuda versions * chore: style fix * Update package.py * Update package.py * fix: typo * docs: add source for cuda version * py-jaxlib 0.4.14 also doesn't build on ppc64le * Add 0.4.26 --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-jax/package.py50
-rw-r--r--var/spack/repos/builtin/packages/py-jaxlib/package.py29
2 files changed, 71 insertions, 8 deletions
diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py
index d3c0ed93f6..2be3c700f0 100644
--- a/var/spack/repos/builtin/packages/py-jax/package.py
+++ b/var/spack/repos/builtin/packages/py-jax/package.py
@@ -22,11 +22,31 @@ class PyJax(PythonPackage):
pypi = "jax/jax-0.2.25.tar.gz"
license("Apache-2.0")
- maintainers("adamjstewart")
+ maintainers("adamjstewart", "jonas-eschle")
+ version("0.4.26", sha256="2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383")
version("0.4.25", sha256="a8ee189c782de2b7b2ffb64a8916da380b882a617e2769aa429b71d79747b982")
+ version("0.4.24", sha256="4a6b6fd026ddd22653c7fa2fac1904c3de2dbe845b61ede08af9a5cc709662ae")
version("0.4.23", sha256="2a229a5a758d1b803891b2eaed329723f6b15b4258b14dc0ccb1498c84963685")
+ version("0.4.22", sha256="801434dda6e14f82a45fff753969a33281ab22fb2a50fe801b651390321057ba")
+ version("0.4.21", sha256="c97fd0d2751d6e1eb15aa2052ff7cfdc129f8fafc2c14cd779720658926a587b")
+ version("0.4.20", sha256="ea96a763a8b1a9374639d1159ab4de163461d01cd022f67c34c09581b71ed2ac")
+ version("0.4.19", sha256="29f87f9a50964d3ca5eeb2973de3462f0e8b4eca6d46027894a0e9a903420601")
+ version("0.4.18", sha256="776cf33890100803e98f45f9af10aa727271c6993d4e766c069118733c928132")
+ version("0.4.17", sha256="d7508a69e87835f534cb07a2f21d79cc1cb8c4cfdcf7fb010927267ef7355f1d")
version("0.4.16", sha256="e2ca82c9bf973c2c1c01f5340a583692b31f277aa3abd0544229c1fe5fa44b02")
+ version("0.4.15", sha256="2aa123ccef591e355dea94a6e714b6559f8e1d6368a576a223f97d031ece0d15")
+ version("0.4.14", sha256="18fed3881f26e8b13c8cb46eeeea3dba9eb4d48e3714d8e8f2304dd6e237083d")
+ version("0.4.13", sha256="03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa")
+ version("0.4.12", sha256="d2de9a2388ffe002f16506d3ad1cc6e34d7536b98948e49c7e05bbcfe8e57998")
+ version("0.4.11", sha256="8b1cd443b698339df8d8807578ee141e5b67e36125b3945b146f600177d60d79")
+ version("0.4.10", sha256="1bf0f2720f778f2937301a16a4d5cd3497f13a4d6c970c24a88918a81816a888")
+ version("0.4.9", sha256="1ed135cd08f48e4baf10f6eafdb4a4cdae781f9052b5838c09c91a9f4fa75f09")
+ version("0.4.8", sha256="08116481f7336db16c24812bfb5e6f9786915f4c2f6ff4028331fa69e7535202")
+ version("0.4.7", sha256="5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8")
+ version("0.4.6", sha256="d06ea8fba4ed315ec55110396058cb48c8edb2ab0b412f28c8a123beee9e58ab")
+ version("0.4.5", sha256="1633e56d34b18ddfa7d2a216ce214fa6fa712d36552532aaa71da416aede7268")
+ version("0.4.4", sha256="39b07e07343ed7c74492ee5e75db77456d3afdd038a322671f09fc748f6392cb")
version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae")
version(
"0.3.23",
@@ -58,7 +78,33 @@ class PyJax(PythonPackage):
# See jax/_src/lib/__init__.py
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
- for v in ["0.4.25", "0.4.23", "0.4.16", "0.4.3", "0.3.23"]:
+ for v in [
+ "0.4.26",
+ "0.4.25",
+ "0.4.24",
+ "0.4.23",
+ "0.4.22",
+ "0.4.21",
+ "0.4.20",
+ "0.4.19",
+ "0.4.18",
+ "0.4.17",
+ "0.4.16",
+ "0.4.15",
+ "0.4.14",
+ "0.4.13",
+ "0.4.12",
+ "0.4.11",
+ "0.4.10",
+ "0.4.9",
+ "0.4.8",
+ "0.4.7",
+ "0.4.6",
+ "0.4.5",
+ "0.4.4",
+ "0.4.3",
+ "0.3.23",
+ ]:
depends_on(f"py-jaxlib@:{v}", when=f"@{v}", type=("build", "run"))
# See _minimum_jaxlib_version in jax/version.py
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py
index 4a9c9ec352..06864e48c8 100644
--- a/var/spack/repos/builtin/packages/py-jaxlib/package.py
+++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py
@@ -20,9 +20,19 @@ class PyJaxlib(PythonPackage, CudaPackage):
license("Apache-2.0")
maintainers("adamjstewart")
+ version("0.4.26", sha256="ddc14da1eaa34f23430d40ad9b9585088575cac439a2fa1c6833a247e1b221fd")
version("0.4.25", sha256="fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8")
version("0.4.24", sha256="c4e6963c2c36f634a9a1765e476a1ed4e6c4a7954465ebf72e29f344c28ddc28")
+ version("0.4.23", sha256="e4c06d62ba54becffd91abc862627b8b11b79c5a77366af8843b819665b6d568")
+ version("0.4.21", sha256="8d57f66d00b9c0b824b1eff84adda5b765a412b3f316ef7c773632d1edbf9477")
+ version("0.4.20", sha256="058410d2bc12f7562c7b01e0c8cd587cb68059c12f78bc945055e5ddc445f5fd")
+ version("0.4.19", sha256="51242b217a1f82474e42d24f09ed5dedff951eeb4579c6e49e706d1adfd6949d")
version("0.4.16", sha256="85c8bc050abe0a2cf62e8cfc7edb4904dd3807924b5714ec6277f291c576b5ca")
+ version("0.4.14", sha256="9f309476a8f6337717b059b8d10b5859b4134c30cf8f1220bb70379b5e2744a4")
+ version("0.4.11", sha256="bdfc45f33970beba5caf28d061668a4863f05994deea26791db50ea605fc2e36")
+ version("0.4.7", sha256="0578d5dd5035b5225cadb6a62ca5f93dd76b70292268502fc01a0fd9ca7001d0")
+ version("0.4.6", sha256="2c9bf8962815bc54ef524e33dc8eda9d165d379fe87e0df210f316adead27787")
+ version("0.4.4", sha256="881f402c7983b56b185e182d5315dd64c9f5320be96213d0415996ece1826806")
version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
version(
"0.3.22",
@@ -40,9 +50,12 @@ class PyJaxlib(PythonPackage, CudaPackage):
# build/build.py
depends_on("py-build", when="@0.4.14:", type="build")
+ # Based on PyPI wheels
+ depends_on("python@3.9:3.12", when="@0.4.17:", type=("build", "run"))
+ depends_on("python@3.9:3.11", when="@0.4.14:0.4.16", type=("build", "run"))
+ depends_on("python@3.8:3.11", when="@0.4.6:0.4.13", type=("build", "run"))
+
# jaxlib/setup.py
- depends_on("python@3.9:", when="@0.4.14:", type=("build", "run"))
- depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
depends_on("py-setuptools", type="build")
depends_on("py-scipy@1.9:", when="@0.4.19:", type=("build", "run"))
depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run"))
@@ -63,12 +76,16 @@ class PyJaxlib(PythonPackage, CudaPackage):
depends_on("bazel@4.2.1", when="@0.1.75:0.1.76", type="build")
depends_on("bazel@4.1.0", when="@0.1.70:0.1.74", type="build")
- # README.md
- depends_on("cuda@11.4:", when="@0.4:+cuda")
+ # jaxlib/setup.py
+ depends_on("cuda@12.1.105:", when="@0.4.26:+cuda")
+ depends_on("cuda@11.8:", when="@0.4.11:+cuda")
+ depends_on("cuda@11.4:", when="@0.4.0:0.4.7+cuda")
depends_on("cuda@11.1:", when="@0.3+cuda")
# https://github.com/google/jax/issues/12614
depends_on("cuda@11.1:11.7.0", when="@0.1+cuda")
- depends_on("cudnn@8.2:", when="@0.4:+cuda")
+
+ depends_on("cudnn@8.8:", when="@0.4.11:+cuda")
+ depends_on("cudnn@8.2:", when="@0.4:0.4.7+cuda")
depends_on("cudnn@8.0.5:", when="+cuda")
# Historical dependencies
@@ -83,7 +100,7 @@ class PyJaxlib(PythonPackage, CudaPackage):
)
# https://github.com/google/jax/issues/19992
- conflicts("@0.4.16:0.4.25", when="target=ppc64le:")
+ conflicts("@0.4.4:", when="target=ppc64le:")
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")