From 16adda3db95fc9eceafbf8ee0969c1af89db419d Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart" <ajstewart426@gmail.com>
Date: Tue, 21 Feb 2023 12:14:27 -0700
Subject: py-jax: add v0.4.3 (#35460)

* py-jax: add v0.4.3

* Minimum version is minimum

* py-jax no longer has cuda variant

* Enable CUDA by default

* Link to discussion of upper bound
---
 .../repos/builtin/packages/py-alphafold/package.py |  6 --
 var/spack/repos/builtin/packages/py-jax/package.py | 42 ++++++------
 .../repos/builtin/packages/py-jaxlib/package.py    | 76 ++++++++++++----------
 3 files changed, 64 insertions(+), 60 deletions(-)

(limited to 'var')

diff --git a/var/spack/repos/builtin/packages/py-alphafold/package.py b/var/spack/repos/builtin/packages/py-alphafold/package.py
index 3356059a4d..ad95a4201b 100644
--- a/var/spack/repos/builtin/packages/py-alphafold/package.py
+++ b/var/spack/repos/builtin/packages/py-alphafold/package.py
@@ -37,12 +37,6 @@ class PyAlphafold(PythonPackage, CudaPackage):
     depends_on("py-immutabledict@2.0.0:", type=("build", "run"))
     depends_on("py-jax@0.2.14:", type=("build", "run"), when="@2.1.1")
     depends_on("py-jax@0.3.17:", type=("build", "run"), when="@2.2.4")
-    for arch in CudaPackage.cuda_arch_values:
-        depends_on(
-            "py-jax+cuda cuda_arch={0}".format(arch),
-            type=("build", "run"),
-            when="cuda_arch={0}".format(arch),
-        )
     depends_on("py-ml-collections@0.1.0:", type=("build", "run"))
     depends_on("py-numpy@1.19.5:", type=("build", "run"), when="@2.1.1")
     depends_on("py-numpy@1.21.6:", type=("build", "run"), when="@2.2.4")
diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py
index 5753f33b2a..279e4a00f7 100644
--- a/var/spack/repos/builtin/packages/py-jax/package.py
+++ b/var/spack/repos/builtin/packages/py-jax/package.py
@@ -7,7 +7,7 @@
 from spack.package import *
 
 
-class PyJax(PythonPackage, CudaPackage):
+class PyJax(PythonPackage):
     """JAX is Autograd and XLA, brought together for high-performance
     machine learning research. With its updated version of Autograd,
     JAX can automatically differentiate native Python and NumPy
@@ -21,29 +21,29 @@ class PyJax(PythonPackage, CudaPackage):
     homepage = "https://github.com/google/jax"
     pypi = "jax/jax-0.2.25.tar.gz"
 
+    version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae")
     version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd")
     version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446")
 
-    variant("cuda", default=True, description="CUDA support")
-
-    depends_on("python@3.7:", type=("build", "run"))
+    depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
     depends_on("py-setuptools", type="build")
-    depends_on("py-numpy@1.18:", type=("build", "run"), when="@0.2.25")
-    depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.23")
+    depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
     depends_on("py-numpy@1.18:", type=("build", "run"))
-    depends_on("py-absl-py", type=("build", "run"))
     depends_on("py-opt-einsum", type=("build", "run"))
-    depends_on("py-scipy@1.2.1:", type=("build", "run"), when="@0.2.25")
-    depends_on("py-scipy@1.5:", type=("build", "run"), when="@0.3.23")
-    depends_on("py-typing-extensions", type=("build", "run"))
-    depends_on("py-etils+epath", type=("build", "run"), when="@0.3.23")
-    depends_on("py-jaxlib@0.3.15:", type=("build", "run"), when="@0.3.23~cuda")
-    depends_on("py-jaxlib@0.3.15:+cuda", type=("build", "run"), when="@0.3.23+cuda")
-    depends_on("py-jaxlib@0.1.69:", type=("build", "run"), when="@0.2.25~cuda")
-    depends_on("py-jaxlib@0.1.69:+cuda", type=("build", "run"), when="@0.2.25+cuda")
-    for arch in CudaPackage.cuda_arch_values:
-        depends_on(
-            "py-jaxlib+cuda cuda_arch={0}".format(arch),
-            type=("build", "run"),
-            when="cuda_arch={0}".format(arch),
-        )
+    depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run"))
+    depends_on("py-scipy@1.2.1:", type=("build", "run"))
+
+    # See _minimum_jaxlib_version in jax/version.py
+    jax_to_jaxlib = {
+        "0.4.3": "0.4.2",
+        "0.3.23": "0.3.15",
+        "0.2.25": "0.1.69",
+    }
+
+    for jax, jaxlib in jax_to_jaxlib.items():
+        depends_on(f"py-jaxlib@{jaxlib}:", when=f"@{jax}", type=("build", "run"))
+
+    # Historical dependencies
+    depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
+    depends_on("py-typing-extensions", when="@:0.3", type=("build", "run"))
+    depends_on("py-etils+epath", when="@0.3", type=("build", "run"))
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py
index 6e0e57e2a8..f2ebf205ad 100644
--- a/var/spack/repos/builtin/packages/py-jaxlib/package.py
+++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py
@@ -17,25 +17,55 @@ class PyJaxlib(PythonPackage, CudaPackage):
     tmp_path = ""
     buildtmp = ""
 
+    version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
     version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d")
     version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35")
 
-    # see jaxlib/setup.py for dependencies
-    depends_on("python@3.7:", type=("build", "run"))
-    depends_on("py-setuptools", type="build")
+    variant("cuda", default=True, description="Build with CUDA")
 
-    depends_on("py-numpy@1.18:", type=("build", "run"), when="@0.1.74")
-    depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.22")
+    # jaxlib/setup.py
+    depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
+    depends_on("py-setuptools", type="build")
+    depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
+    depends_on("py-numpy@1.18:", type=("build", "run"))
     depends_on("py-scipy@1.5:", type=("build", "run"))
-    depends_on("py-absl-py", type=("build", "run"))
-    depends_on("py-flatbuffers@1.12:2", type=("build", "run"), when="@0.1.74")
-    # Bazel 5 not yet supported: https://github.com/google/jax/issues/8440
-    depends_on("bazel@4.1.0:4", type=("build"), when="@0.1.74")
-    # Bazel 5 support starts here
-    depends_on("bazel@5.1.1:", type=("build"), when="@0.3.22")
+
+    # .bazelversion
+    depends_on("bazel@5.1.1:", when="@0.3:", type="build")
+    # https://github.com/google/jax/issues/8440
+    depends_on("bazel@4.1:4", when="@0.1", type="build")
+
+    # README.md
+    depends_on("cuda@11.4:", when="@0.4:+cuda")
+    depends_on("cuda@11.1:", when="@0.3+cuda")
+    # https://github.com/google/jax/issues/12614
+    depends_on("cuda@11.1:11.7.0", when="@0.1+cuda")
+    depends_on("cudnn@8.2:", when="@0.4:+cuda")
     depends_on("cudnn@8.0.5:", when="+cuda")
-    depends_on("cuda@11.1:11.7.0", when="@0.1.74+cuda")
-    depends_on("cuda@11.1:", when="@0.3.22+cuda")
+
+    # Historical dependencies
+    depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
+    depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run"))
+
+    def patch(self):
+        self.tmp_path = tempfile.mkdtemp(prefix="spack")
+        self.buildtmp = tempfile.mkdtemp(prefix="spack")
+        # triple quotes necessary because of a variety
+        # of other embedded quote(s)
+        filter_file(
+            """f"--output_path={output_path}",""",
+            """f"--output_path={output_path}","""
+            """f"--sources_path=%s","""
+            """f"--nohome_rc'","""
+            """f"--nosystem_rc'",""" % self.tmp_path,
+            "build/build.py",
+        )
+        filter_file(
+            "args = parser.parse_args()",
+            "args,junk = parser.parse_known_args()",
+            "build/build_wheel.py",
+            string=True,
+        )
 
     def install(self, spec, prefix):
         args = []
@@ -58,23 +88,3 @@ class PyJaxlib(PythonPackage, CudaPackage):
             pip(*args)
         remove_linked_tree(self.wrapped_package_object.tmp_path)
         remove_linked_tree(self.wrapped_package_object.buildtmp)
-
-    def patch(self):
-        self.tmp_path = tempfile.mkdtemp(prefix="spack")
-        self.buildtmp = tempfile.mkdtemp(prefix="spack")
-        # triple quotes necessary because of a variety
-        # of other embedded quote(s)
-        filter_file(
-            """f"--output_path={output_path}",""",
-            """f"--output_path={output_path}","""
-            """f"--sources_path=%s","""
-            """f"--nohome_rc'","""
-            """f"--nosystem_rc'",""" % self.tmp_path,
-            "build/build.py",
-        )
-        filter_file(
-            "args = parser.parse_args()",
-            "args,junk = parser.parse_known_args()",
-            "build/build_wheel.py",
-            string=True,
-        )
-- 
cgit v1.2.3-70-g09d2