From 16adda3db95fc9eceafbf8ee0969c1af89db419d Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 21 Feb 2023 12:14:27 -0700 Subject: py-jax: add v0.4.3 (#35460) * py-jax: add v0.4.3 * Minimum version is minimum * py-jax no longer has cuda variant * Enable CUDA by default * Link to discussion of upper bound --- .../repos/builtin/packages/py-alphafold/package.py | 6 -- var/spack/repos/builtin/packages/py-jax/package.py | 42 ++++++------ .../repos/builtin/packages/py-jaxlib/package.py | 76 ++++++++++++---------- 3 files changed, 64 insertions(+), 60 deletions(-) (limited to 'var') diff --git a/var/spack/repos/builtin/packages/py-alphafold/package.py b/var/spack/repos/builtin/packages/py-alphafold/package.py index 3356059a4d..ad95a4201b 100644 --- a/var/spack/repos/builtin/packages/py-alphafold/package.py +++ b/var/spack/repos/builtin/packages/py-alphafold/package.py @@ -37,12 +37,6 @@ class PyAlphafold(PythonPackage, CudaPackage): depends_on("py-immutabledict@2.0.0:", type=("build", "run")) depends_on("py-jax@0.2.14:", type=("build", "run"), when="@2.1.1") depends_on("py-jax@0.3.17:", type=("build", "run"), when="@2.2.4") - for arch in CudaPackage.cuda_arch_values: - depends_on( - "py-jax+cuda cuda_arch={0}".format(arch), - type=("build", "run"), - when="cuda_arch={0}".format(arch), - ) depends_on("py-ml-collections@0.1.0:", type=("build", "run")) depends_on("py-numpy@1.19.5:", type=("build", "run"), when="@2.1.1") depends_on("py-numpy@1.21.6:", type=("build", "run"), when="@2.2.4") diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py index 5753f33b2a..279e4a00f7 100644 --- a/var/spack/repos/builtin/packages/py-jax/package.py +++ b/var/spack/repos/builtin/packages/py-jax/package.py @@ -7,7 +7,7 @@ from spack.package import * -class PyJax(PythonPackage, CudaPackage): +class PyJax(PythonPackage): """JAX is Autograd and XLA, brought together for high-performance machine learning research. With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy @@ -21,29 +21,29 @@ class PyJax(PythonPackage, CudaPackage): homepage = "https://github.com/google/jax" pypi = "jax/jax-0.2.25.tar.gz" + version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae") version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd") version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446") - variant("cuda", default=True, description="CUDA support") - - depends_on("python@3.7:", type=("build", "run")) + depends_on("python@3.8:", when="@0.4:", type=("build", "run")) depends_on("py-setuptools", type="build") - depends_on("py-numpy@1.18:", type=("build", "run"), when="@0.2.25") - depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.23") + depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run")) depends_on("py-numpy@1.18:", type=("build", "run")) - depends_on("py-absl-py", type=("build", "run")) depends_on("py-opt-einsum", type=("build", "run")) - depends_on("py-scipy@1.2.1:", type=("build", "run"), when="@0.2.25") - depends_on("py-scipy@1.5:", type=("build", "run"), when="@0.3.23") - depends_on("py-typing-extensions", type=("build", "run")) - depends_on("py-etils+epath", type=("build", "run"), when="@0.3.23") - depends_on("py-jaxlib@0.3.15:", type=("build", "run"), when="@0.3.23~cuda") - depends_on("py-jaxlib@0.3.15:+cuda", type=("build", "run"), when="@0.3.23+cuda") - depends_on("py-jaxlib@0.1.69:", type=("build", "run"), when="@0.2.25~cuda") - depends_on("py-jaxlib@0.1.69:+cuda", type=("build", "run"), when="@0.2.25+cuda") - for arch in CudaPackage.cuda_arch_values: - depends_on( - "py-jaxlib+cuda cuda_arch={0}".format(arch), - type=("build", "run"), - when="cuda_arch={0}".format(arch), - ) + depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run")) + depends_on("py-scipy@1.2.1:", type=("build", "run")) + + # See _minimum_jaxlib_version in jax/version.py + jax_to_jaxlib = { + "0.4.3": "0.4.2", + "0.3.23": "0.3.15", + "0.2.25": "0.1.69", + } + + for jax, jaxlib in jax_to_jaxlib.items(): + depends_on(f"py-jaxlib@{jaxlib}:", when=f"@{jax}", type=("build", "run")) + + # Historical dependencies + depends_on("py-absl-py", when="@:0.3", type=("build", "run")) + depends_on("py-typing-extensions", when="@:0.3", type=("build", "run")) + depends_on("py-etils+epath", when="@0.3", type=("build", "run")) diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index 6e0e57e2a8..f2ebf205ad 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -17,25 +17,55 @@ class PyJaxlib(PythonPackage, CudaPackage): tmp_path = "" buildtmp = "" + version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910") version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d") version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35") - # see jaxlib/setup.py for dependencies - depends_on("python@3.7:", type=("build", "run")) - depends_on("py-setuptools", type="build") + variant("cuda", default=True, description="Build with CUDA") - depends_on("py-numpy@1.18:", type=("build", "run"), when="@0.1.74") - depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.22") + # jaxlib/setup.py + depends_on("python@3.8:", when="@0.4:", type=("build", "run")) + depends_on("py-setuptools", type="build") + depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run")) + depends_on("py-numpy@1.18:", type=("build", "run")) depends_on("py-scipy@1.5:", type=("build", "run")) - depends_on("py-absl-py", type=("build", "run")) - depends_on("py-flatbuffers@1.12:2", type=("build", "run"), when="@0.1.74") - # Bazel 5 not yet supported: https://github.com/google/jax/issues/8440 - depends_on("bazel@4.1.0:4", type=("build"), when="@0.1.74") - # Bazel 5 support starts here - depends_on("bazel@5.1.1:", type=("build"), when="@0.3.22") + + # .bazelversion + depends_on("bazel@5.1.1:", when="@0.3:", type="build") + # https://github.com/google/jax/issues/8440 + depends_on("bazel@4.1:4", when="@0.1", type="build") + + # README.md + depends_on("cuda@11.4:", when="@0.4:+cuda") + depends_on("cuda@11.1:", when="@0.3+cuda") + # https://github.com/google/jax/issues/12614 + depends_on("cuda@11.1:11.7.0", when="@0.1+cuda") + depends_on("cudnn@8.2:", when="@0.4:+cuda") depends_on("cudnn@8.0.5:", when="+cuda") - depends_on("cuda@11.1:11.7.0", when="@0.1.74+cuda") - depends_on("cuda@11.1:", when="@0.3.22+cuda") + + # Historical dependencies + depends_on("py-absl-py", when="@:0.3", type=("build", "run")) + depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run")) + + def patch(self): + self.tmp_path = tempfile.mkdtemp(prefix="spack") + self.buildtmp = tempfile.mkdtemp(prefix="spack") + # triple quotes necessary because of a variety + # of other embedded quote(s) + filter_file( + """f"--output_path={output_path}",""", + """f"--output_path={output_path}",""" + """f"--sources_path=%s",""" + """f"--nohome_rc'",""" + """f"--nosystem_rc'",""" % self.tmp_path, + "build/build.py", + ) + filter_file( + "args = parser.parse_args()", + "args,junk = parser.parse_known_args()", + "build/build_wheel.py", + string=True, + ) def install(self, spec, prefix): args = [] @@ -58,23 +88,3 @@ class PyJaxlib(PythonPackage, CudaPackage): pip(*args) remove_linked_tree(self.wrapped_package_object.tmp_path) remove_linked_tree(self.wrapped_package_object.buildtmp) - - def patch(self): - self.tmp_path = tempfile.mkdtemp(prefix="spack") - self.buildtmp = tempfile.mkdtemp(prefix="spack") - # triple quotes necessary because of a variety - # of other embedded quote(s) - filter_file( - """f"--output_path={output_path}",""", - """f"--output_path={output_path}",""" - """f"--sources_path=%s",""" - """f"--nohome_rc'",""" - """f"--nosystem_rc'",""" % self.tmp_path, - "build/build.py", - ) - filter_file( - "args = parser.parse_args()", - "args,junk = parser.parse_known_args()", - "build/build_wheel.py", - string=True, - ) -- cgit v1.2.3-60-g2f50