diff options
-rw-r--r-- | var/spack/repos/builtin/packages/py-jaxlib/package.py | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index 951aa4d9d3..fcd624cdd2 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -99,6 +99,12 @@ class PyJaxlib(PythonPackage, CudaPackage): depends_on("py-numpy@:1", when="@:0.4.25") depends_on("py-ml-dtypes@0.4:", when="@0.4.29") + patch( + "https://github.com/google/jax/pull/20101.patch?full_index=1", + sha256="4dfb9f32d4eeb0a0fb3a6f4124c4170e3fe49511f1b768cd634c78d489962275", + when="@:0.4.25", + ) + conflicts( "cuda_arch=none", when="+cuda", |