summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--share/spack/gitlab/cloud_pipelines/stacks/ml-cpu/spack.yaml5
-rw-r--r--share/spack/gitlab/cloud_pipelines/stacks/ml-cuda/spack.yaml5
-rw-r--r--share/spack/gitlab/cloud_pipelines/stacks/ml-rocm/spack.yaml5
-rw-r--r--var/spack/repos/builtin/packages/openmm/package.py6
-rw-r--r--var/spack/repos/builtin/packages/py-alphafold/package.py22
-rw-r--r--var/spack/repos/builtin/packages/py-dm-haiku/package.py4
-rw-r--r--var/spack/repos/builtin/packages/py-etils/package.py26
-rw-r--r--var/spack/repos/builtin/packages/py-jax/package.py15
-rw-r--r--var/spack/repos/builtin/packages/py-jaxlib/package.py34
-rw-r--r--var/spack/repos/builtin/packages/py-tensorflow/package.py2
10 files changed, 92 insertions, 32 deletions
diff --git a/share/spack/gitlab/cloud_pipelines/stacks/ml-cpu/spack.yaml b/share/spack/gitlab/cloud_pipelines/stacks/ml-cpu/spack.yaml
index d4e86fd501..b5d989c904 100644
--- a/share/spack/gitlab/cloud_pipelines/stacks/ml-cpu/spack.yaml
+++ b/share/spack/gitlab/cloud_pipelines/stacks/ml-cpu/spack.yaml
@@ -28,9 +28,8 @@ spack:
- py-transformers
# JAX
- # https://github.com/google/jax/issues/12614
- # - py-jax
- # - py-jaxlib
+ - py-jax
+ - py-jaxlib
# Keras
- py-keras
diff --git a/share/spack/gitlab/cloud_pipelines/stacks/ml-cuda/spack.yaml b/share/spack/gitlab/cloud_pipelines/stacks/ml-cuda/spack.yaml
index f04b96ce93..1ea78372cb 100644
--- a/share/spack/gitlab/cloud_pipelines/stacks/ml-cuda/spack.yaml
+++ b/share/spack/gitlab/cloud_pipelines/stacks/ml-cuda/spack.yaml
@@ -31,9 +31,8 @@ spack:
- py-transformers
# JAX
- # https://github.com/google/jax/issues/12614
- # - py-jax
- # - py-jaxlib
+ - py-jax
+ - py-jaxlib
# Keras
- py-keras
diff --git a/share/spack/gitlab/cloud_pipelines/stacks/ml-rocm/spack.yaml b/share/spack/gitlab/cloud_pipelines/stacks/ml-rocm/spack.yaml
index 4c44e881d7..2d728b501e 100644
--- a/share/spack/gitlab/cloud_pipelines/stacks/ml-rocm/spack.yaml
+++ b/share/spack/gitlab/cloud_pipelines/stacks/ml-rocm/spack.yaml
@@ -33,9 +33,8 @@ spack:
- py-transformers
# JAX
- # https://github.com/google/jax/issues/12614
- # - py-jax
- # - py-jaxlib
+ - py-jax
+ - py-jaxlib
# Keras
- py-keras
diff --git a/var/spack/repos/builtin/packages/openmm/package.py b/var/spack/repos/builtin/packages/openmm/package.py
index 7181c61b41..97a87e79b4 100644
--- a/var/spack/repos/builtin/packages/openmm/package.py
+++ b/var/spack/repos/builtin/packages/openmm/package.py
@@ -19,6 +19,7 @@ class Openmm(CMakePackage, CudaPackage):
homepage = "https://openmm.org/"
url = "https://github.com/openmm/openmm/archive/7.4.1.tar.gz"
+ version("7.7.0", sha256="51970779b8dc639ea192e9c61c67f70189aa294575acb915e14be1670a586c25")
version("7.6.0", sha256="5a99c491ded9ba83ecc3fb1d8d22fca550f45da92e14f64f25378fda0048a89d")
version("7.5.1", sha256="c88d6946468a2bde2619acb834f57b859b5e114a93093cf562165612e10f4ff7")
version("7.5.0", sha256="516748b4f1ae936c4d70cc6401174fc9384244c65cd3aef27bc2c53eac6d6de5")
@@ -27,8 +28,11 @@ class Openmm(CMakePackage, CudaPackage):
install_targets = ["install", "PythonInstall"]
depends_on("python@2.7:", type=("build", "run"))
+ depends_on("cmake@3.17:", type="build", when="@7.6.0:")
depends_on("cmake@3.1:", type="build")
- depends_on("doxygen", type="build")
+ # https://github.com/openmm/openmm/issues/3317
+ depends_on("doxygen@:1.9.1", type="build", when="@:7.6.0")
+ depends_on("doxygen", type="build", when="@7.7:")
depends_on("swig", type="build")
depends_on("fftw")
depends_on("py-cython", type="build")
diff --git a/var/spack/repos/builtin/packages/py-alphafold/package.py b/var/spack/repos/builtin/packages/py-alphafold/package.py
index b5a38f8b46..598debac18 100644
--- a/var/spack/repos/builtin/packages/py-alphafold/package.py
+++ b/var/spack/repos/builtin/packages/py-alphafold/package.py
@@ -17,31 +17,41 @@ class PyAlphafold(PythonPackage, CudaPackage):
url = "https://github.com/deepmind/alphafold/archive/refs/tags/v2.1.1.tar.gz"
maintainers = ["aweits"]
+ version("2.2.4", sha256="8d756e16f6dc7897331d834aade8493820d0ff6a03bf60ce511bac4756c1b1e8")
version("2.1.1", sha256="1adb6e213ba9ac321fc1acb1c563ba9b4fc054c1cebe1191bc0e2aaa671dadf7")
conflicts("platform=darwin", msg="alphafold is only supported on Linux")
+ # lots of hints on versions and patching taken from docker/Dockerfile
+ # and requirements.txt
depends_on("python@3.7:", type=("build", "run"))
depends_on("py-setuptools", type="build")
- depends_on("py-absl-py@0.13.0:", type=("build", "run"))
+ depends_on("py-absl-py@0.13.0:", type=("build", "run"), when="@2.1.1")
+ depends_on("py-absl-py@1.0.0:", type=("build", "run"), when="@2.2.4")
depends_on("py-biopython@1.79:", type=("build", "run"))
depends_on("py-chex@0.0.7:", type=("build", "run"))
- depends_on("py-dm-haiku@0.0.4:", type=("build", "run"))
+ depends_on("py-dm-haiku@0.0.4:", type=("build", "run"), when="@2.1.1")
+ depends_on("py-dm-haiku@0.0.7:", type=("build", "run"), when="@2.2.4")
depends_on("py-dm-tree@0.1.6:", type=("build", "run"))
+ depends_on("py-docker", type=("build", "run"))
depends_on("py-immutabledict@2.0.0:", type=("build", "run"))
- depends_on("py-jax@0.2.14:", type=("build", "run"))
+ depends_on("py-jax@0.2.14:", type=("build", "run"), when="@2.1.1")
+ depends_on("py-jax@0.3.17:", type=("build", "run"), when="@2.2.4")
for arch in CudaPackage.cuda_arch_values:
depends_on(
- "py-jax@0.2.14:+cuda cuda_arch={0}".format(arch),
+ "py-jax+cuda cuda_arch={0}".format(arch),
type=("build", "run"),
when="cuda_arch={0}".format(arch),
)
depends_on("py-ml-collections@0.1.0:", type=("build", "run"))
- depends_on("py-numpy@1.19.5:", type=("build", "run"))
+ depends_on("py-numpy@1.19.5:", type=("build", "run"), when="@2.1.1")
+ depends_on("py-numpy@1.21.6:", type=("build", "run"), when="@2.2.4")
depends_on("py-pandas@1.3.4:", type=("build", "run"))
+ depends_on("py-protobuf@3.19:", type=("build", "run"), when="@2.2.4")
depends_on("py-scipy@1.7.0:", type=("build", "run"))
depends_on("py-pdbfixer@1.7", type=("build", "run"))
- depends_on("py-tensorflow@2.5:", type=("build", "run"))
+ depends_on("py-tensorflow@2.5:", type=("build", "run"), when="@2.1.1")
+ depends_on("py-tensorflow@2.9:", type=("build", "run"), when="@2.2.4")
depends_on(
"openmm@7.5.1+cuda",
type="run",
diff --git a/var/spack/repos/builtin/packages/py-dm-haiku/package.py b/var/spack/repos/builtin/packages/py-dm-haiku/package.py
index 478b2301de..d8ca5ed112 100644
--- a/var/spack/repos/builtin/packages/py-dm-haiku/package.py
+++ b/var/spack/repos/builtin/packages/py-dm-haiku/package.py
@@ -13,6 +13,7 @@ class PyDmHaiku(PythonPackage):
homepage = "https://github.com/deepmind/dm-haiku"
pypi = "dm-haiku/dm-haiku-0.0.5.tar.gz"
+ version("0.0.7", sha256="86c34af6952a305a4bbfda6b9925998577acc4aa2ad9333da3d6047f4f8ed7c1")
version("0.0.5", sha256="e986237e1f840aa3bd26217ecad84b611bf1456e2139f0f79ea71f9c6222d231")
depends_on("python@3.7:", type=("build", "run"))
depends_on("py-setuptools", type="build")
@@ -21,4 +22,7 @@ class PyDmHaiku(PythonPackage):
depends_on("py-numpy@1.18.0:", type=("build", "run"))
depends_on("py-tabulate@0.8.9:", type=("build", "run"))
depends_on("py-typing-extensions", when="^python@:3.7", type=("build", "run"))
+ # from README.md:
+ # Because JAX installation is different depending on your CUDA version, Haiku does
+ # not list JAX as a dependency in `requirements.txt`.
depends_on("py-jax", type=("build", "run"))
diff --git a/var/spack/repos/builtin/packages/py-etils/package.py b/var/spack/repos/builtin/packages/py-etils/package.py
new file mode 100644
index 0000000000..19907873e1
--- /dev/null
+++ b/var/spack/repos/builtin/packages/py-etils/package.py
@@ -0,0 +1,26 @@
+# Copyright 2013-2022 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+from spack.package import *
+
+
+class PyEtils(PythonPackage):
+ """etils (eclectic utils) is an open-source collection of utils
+ for python."""
+
+ homepage = "https://github.com/google/etils"
+ pypi = "etils/etils-0.9.0.tar.gz"
+
+ version("0.9.0", sha256="489103e9e499a566765c60458ee15d185cf0065f2060a4d16a68f8f46962ed0d")
+
+ variant("epath", default=False, description="with epath module")
+
+ depends_on("python@3.7:", type=("build", "run"))
+
+ depends_on("py-importlib-resources", type=("build", "run"), when="+epath")
+ depends_on("py-typing-extensions", type=("build", "run"), when="+epath")
+ depends_on("py-zipp", type=("build", "run"), when="+epath")
+
+ depends_on("py-flit-core@3.5:3", type="build")
diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py
index 7511f79701..eaefda3d7d 100644
--- a/var/spack/repos/builtin/packages/py-jax/package.py
+++ b/var/spack/repos/builtin/packages/py-jax/package.py
@@ -21,22 +21,29 @@ class PyJax(PythonPackage, CudaPackage):
homepage = "https://github.com/google/jax"
pypi = "jax/jax-0.2.25.tar.gz"
+ version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd")
version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446")
variant("cuda", default=True, description="CUDA support")
depends_on("python@3.7:", type=("build", "run"))
depends_on("py-setuptools", type="build")
+ depends_on("py-numpy@1.18:", type=("build", "run"), when="@0.2.25")
+ depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.23")
depends_on("py-numpy@1.18:", type=("build", "run"))
depends_on("py-absl-py", type=("build", "run"))
depends_on("py-opt-einsum", type=("build", "run"))
- depends_on("py-scipy@1.2.1:", type=("build", "run"))
+ depends_on("py-scipy@1.2.1:", type=("build", "run"), when="@0.2.25")
+ depends_on("py-scipy@1.5:", type=("build", "run"), when="@0.3.23")
depends_on("py-typing-extensions", type=("build", "run"))
- depends_on("py-jaxlib@0.1.69:", type=("build", "run"), when="~cuda")
- depends_on("py-jaxlib@0.1.69:+cuda", type=("build", "run"), when="+cuda")
+ depends_on("py-etils+epath", type=("build", "run"), when="@0.3.23")
+ depends_on("py-jaxlib@0.3.15:", type=("build", "run"), when="@0.3.23~cuda")
+ depends_on("py-jaxlib@0.3.15:+cuda", type=("build", "run"), when="@0.3.23+cuda")
+ depends_on("py-jaxlib@0.1.69:", type=("build", "run"), when="@0.2.25~cuda")
+ depends_on("py-jaxlib@0.1.69:+cuda", type=("build", "run"), when="@0.2.25+cuda")
for arch in CudaPackage.cuda_arch_values:
depends_on(
- "py-jaxlib@0.1.69:+cuda cuda_arch={0}".format(arch),
+ "py-jaxlib+cuda cuda_arch={0}".format(arch),
type=("build", "run"),
when="cuda_arch={0}".format(arch),
)
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py
index 87da9fb763..297ddb064c 100644
--- a/var/spack/repos/builtin/packages/py-jaxlib/package.py
+++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py
@@ -5,8 +5,6 @@
import tempfile
-import llnl.util.tty as tty
-
from spack.package import *
@@ -17,19 +15,27 @@ class PyJaxlib(PythonPackage, CudaPackage):
url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz"
tmp_path = ""
+ buildtmp = ""
+ version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d")
version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35")
+ # see jaxlib/setup.py for dependencies
depends_on("python@3.7:", type=("build", "run"))
depends_on("py-setuptools", type="build")
- depends_on("py-numpy@1.18:", type=("build", "run"))
- depends_on("py-scipy", type=("build", "run"))
+
+ depends_on("py-numpy@1.18:", type=("build", "run"), when="@0.1.74")
+ depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.22")
+ depends_on("py-scipy@1.5:", type=("build", "run"))
depends_on("py-absl-py", type=("build", "run"))
- depends_on("py-flatbuffers@1.12:2", type=("build", "run"))
+ depends_on("py-flatbuffers@1.12:2", type=("build", "run"), when="@0.1.74")
# Bazel 5 not yet supported: https://github.com/google/jax/issues/8440
- depends_on("bazel@4.1.0:4", type=("build"))
+ depends_on("bazel@4.1.0:4", type=("build"), when="@0.1.74")
+ # Bazel 5 support starts here
+ depends_on("bazel@5.1.1:", type=("build"), when="@0.3.22")
depends_on("cudnn@8.0.5:", when="+cuda")
- depends_on("cuda@11.1:", when="+cuda")
+ depends_on("cuda@11.1:11.7.0", when="@0.1.74+cuda")
+ depends_on("cuda@11.1:", when="@0.3.22+cuda")
def install(self, spec, prefix):
args = []
@@ -42,18 +48,22 @@ class PyJaxlib(PythonPackage, CudaPackage):
"{0:.1f}".format(float(i) / 10.0) for i in spec.variants["cuda_arch"].value
)
args.append("--cuda_compute_capabilities={0}".format(capabilities))
- args.append("--bazel_startup_options=" "--output_user_root={0}".format(self.buildtmp))
+ args.append(
+ "--bazel_startup_options="
+ "--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
+ )
python(*args)
- with working_dir(self.tmp_path):
- tty.warn("in dir " + self.tmp_path)
+ with working_dir(self.wrapped_package_object.tmp_path):
args = std_pip_args + ["--prefix=" + self.prefix, "."]
pip(*args)
- remove_linked_tree(self.tmp_path)
- remove_linked_tree(self.buildtmp)
+ remove_linked_tree(self.wrapped_package_object.tmp_path)
+ remove_linked_tree(self.wrapped_package_object.buildtmp)
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")
self.buildtmp = tempfile.mkdtemp(prefix="spack")
+ # triple quotes necessary because of a variety
+ # of other embedded quote(s)
filter_file(
"""f"--output_path={output_path}",""",
"""f"--output_path={output_path}","""
diff --git a/var/spack/repos/builtin/packages/py-tensorflow/package.py b/var/spack/repos/builtin/packages/py-tensorflow/package.py
index f9f8aa04e4..1598f8c67b 100644
--- a/var/spack/repos/builtin/packages/py-tensorflow/package.py
+++ b/var/spack/repos/builtin/packages/py-tensorflow/package.py
@@ -321,6 +321,8 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage):
# depends_on('trisycl', when='+opencl~computepp')
depends_on("cuda@:10.2", when="+cuda @:2.3")
depends_on("cuda@:11.4", when="+cuda @2.4:2.7")
+ # avoid problem fixed by commit a76f797b9cd4b9b15bec4c503b16236a804f676f
+ depends_on("cuda@:11.7.0", when="+cuda @:2.9")
depends_on("cudnn", when="+cuda")
depends_on("cudnn@:7", when="@:2.2 +cuda")
# depends_on('tensorrt', when='+tensorrt')