From c7afc0eb5ffa7564e2f6f9bdc6f178c391d264c6 Mon Sep 17 00:00:00 2001 From: Jonas Eschle Date: Wed, 28 Feb 2024 13:29:23 -0500 Subject: Upgrade TensorFlow Probability with newer versions (#42673) * enh: add newer versions * enh: add newer versions * format * fix typo * Update package.py * make jax and TF optional dependencies * style fix * remove dependency * remove old TFP version * fix: style --- .../packages/py-tensorflow-probability/package.py | 49 ++++++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/var/spack/repos/builtin/packages/py-tensorflow-probability/package.py b/var/spack/repos/builtin/packages/py-tensorflow-probability/package.py index e73f4cc4c2..053405e09f 100644 --- a/var/spack/repos/builtin/packages/py-tensorflow-probability/package.py +++ b/var/spack/repos/builtin/packages/py-tensorflow-probability/package.py @@ -19,19 +19,28 @@ class PyTensorflowProbability(Package): homepage = "https://www.tensorflow.org/probability" url = "https://github.com/tensorflow/probability/archive/v0.12.1.tar.gz" - maintainers("aweits") + maintainers("aweits", "jonas-eschle") license("Apache-2.0") + # TODO: reactivate once TF 2.15 is ready https://github.com/spack/spack/pull/41069 + # version("0.23.0", sha256="a00769550da9284acbd69e32a005507153ad39b0c190feca2bbbf6373366cc14") + version("0.22.1", sha256="9c1203b454aaeb48ac67dea862a411dba6b04f67c1e874e0e83bd1d7f13829a3") + version("0.22.0", sha256="f9ce55b00c8069246d701c04eaafccde413355f6e76ccf9e549772ecfa0349a4") + version("0.21.0", sha256="69b7510b38b2e48bcfb9ff570ef598d489e4f1bcbe13276f5dd91c878b8d56d1") + version("0.20.0", sha256="f0fb9a1f88a36a8f57d4d9cce4f9bf8dfacb6fc7778751729fe3c3067e5a1363") + version("0.19.0", sha256="b32d2ae211ec727df9791b501839619f5389134bd6d4fe951570f500b0e75f55") version("0.18.0", sha256="f4852c0fea9117333ccb868f7a2ca75aecf5dd765dc39fd4ee5f8ab6fe87e909") - version("0.12.1", sha256="1fe89e85fd053bf36e8645a5a1a53b729bc254cf1516bc224fcbd1e4ff50083a") version( - "0.8.0", - sha256="f6049549f6d6b82962523a6bf61c40c1d0c7ac685f209c0084a6da81dd78181d", - url="https://github.com/tensorflow/probability/archive/0.8.0.tar.gz", + "0.12.1", + sha256="1fe89e85fd053bf36e8645a5a1a53b729bc254cf1516bc224fcbd1e4ff50083a", + deprecated=True, ) - extends("python") + extends("python@3.9:", when="@0.22:") + extends("python@3.8:", when="@0.20:0.21") + extends("python@3.7:", when="@0.13:0.19") + extends("python@3.6:", when="@0.8:0.12") depends_on("py-pip", type="build") depends_on("py-wheel", type="build") depends_on("py-setuptools", type="build") @@ -48,9 +57,31 @@ class PyTensorflowProbability(Package): depends_on("py-dm-tree", when="@0.12:", type=("build", "run")) # tensorflow_probability/python/__init__.py - depends_on("py-tensorflow@2.10:", when="@0.18:", type=("build", "run")) - depends_on("py-tensorflow@2.4:", when="@0.12:", type=("build", "run")) - depends_on("py-tensorflow@1.14:", when="@0.8:", type=("build", "run")) + # TODO: reactivate the JAX versions once the JAX package is available with newer versions + # also add jaxlib as a dependency + # TODO: reactivate once TF 2.15 is ready https://github.com/spack/spack/pull/41069 + + variant("py-tensorflow", default=False, description="Build with TensorFlow support") + with when("+py-tensorflow"): + # depends_on("py-tensorflow@2.15", when="@0.23", type=("build", "run")) + depends_on("py-tensorflow@2.14:2", when="@0.22", type=("build", "run")) + depends_on("py-tensorflow@2.13:2", when="@0.21", type=("build", "run")) + depends_on("py-tensorflow@2.12:2", when="@0.20", type=("build", "run")) + depends_on("py-tensorflow@2.11:2", when="@0.19", type=("build", "run")) + + # jaxlib is not required, as it's already a dependency of py-jax + variant("py-jax", default=False, description="Build with JAX support") + with when("+py-jax"): # TODO: reactivate once the JAX package is available with newer versions + # depends_on("py-jax@0.4.20:0.4", when="@0.23", type=("build", "run")) + # depends_on("py-jax@0.4.16:0.4", when="@0.22", type=("build", "run")) + # depends_on("py-jax@0.4.14:0.4", when="@0.21", type=("build", "run")) + # depends_on("py-jax@0.4.8:0.4", when="@0.20", type=("build", "run")) + depends_on("py-jax@0.3.25:3", when="@0.19", type=("build", "run")) + + depends_on( + "py-tensorflow@2.10:2", when="@0.18", type=("build", "run") + ) # keep here for backwards compatibility + depends_on("py-tensorflow@2.4:", when="@0.12:0.17", type=("build", "run")) depends_on("bazel@3.2:", type="build") -- cgit v1.2.3-70-g09d2