summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--var/spack/repos/builtin/packages/py-jaxlib/package.py31
1 files changed, 29 insertions, 2 deletions
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py
index fcd624cdd2..09eb522c56 100644
--- a/var/spack/repos/builtin/packages/py-jaxlib/package.py
+++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py
@@ -7,8 +7,25 @@ import tempfile
from spack.package import *
-
-class PyJaxlib(PythonPackage, CudaPackage):
+rocm_dependencies = [
+ "hsa-rocr-dev",
+ "hip",
+ "rccl",
+ "rocprim",
+ "hipcub",
+ "rocthrust",
+ "roctracer-dev",
+ "rocrand",
+ "hipsparse",
+ "hipfft",
+ "rocfft",
+ "rocblas",
+ "miopen-hip",
+ "rocminfo",
+]
+
+
+class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
"""XLA library for Jax"""
homepage = "https://github.com/google/jax"
@@ -62,6 +79,12 @@ class PyJaxlib(PythonPackage, CudaPackage):
depends_on("nccl@2.16:", when="@0.4.18:")
depends_on("nccl")
+ with when("+rocm"):
+ for pkg_dep in rocm_dependencies:
+ depends_on(f"{pkg_dep}@6:", when="@0.4.28:")
+ depends_on(pkg_dep)
+ depends_on("py-nanobind")
+
with default_args(type="build"):
# .bazelversion
depends_on("bazel@6.5.0", when="@0.4.28:")
@@ -161,6 +184,10 @@ build --local_cpu_resources={make_jobs}
"--bazel_startup_options="
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
)
+ if "+rocm" in spec:
+ args.append("--enable_rocm")
+ args.append("--rocm_path={0}".format(self.spec["hip"].prefix))
+
python(*args)
with working_dir(self.wrapped_package_object.tmp_path):
args = std_pip_args + ["--prefix=" + self.prefix, "."]