summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-jaxlib/package.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py
index 951aa4d9d3..fcd624cdd2 100644
--- a/var/spack/repos/builtin/packages/py-jaxlib/package.py
+++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py
@@ -99,6 +99,12 @@ class PyJaxlib(PythonPackage, CudaPackage):
depends_on("py-numpy@:1", when="@:0.4.25")
depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
+ patch(
+ "https://github.com/google/jax/pull/20101.patch?full_index=1",
+ sha256="4dfb9f32d4eeb0a0fb3a6f4124c4170e3fe49511f1b768cd634c78d489962275",
+ when="@:0.4.25",
+ )
+
conflicts(
"cuda_arch=none",
when="+cuda",