diff options
-rw-r--r-- | var/spack/repos/builtin/packages/py-jaxlib/package.py | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index fcd624cdd2..09eb522c56 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -7,8 +7,25 @@ import tempfile from spack.package import * - -class PyJaxlib(PythonPackage, CudaPackage): +rocm_dependencies = [ + "hsa-rocr-dev", + "hip", + "rccl", + "rocprim", + "hipcub", + "rocthrust", + "roctracer-dev", + "rocrand", + "hipsparse", + "hipfft", + "rocfft", + "rocblas", + "miopen-hip", + "rocminfo", +] + + +class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): """XLA library for Jax""" homepage = "https://github.com/google/jax" @@ -62,6 +79,12 @@ class PyJaxlib(PythonPackage, CudaPackage): depends_on("nccl@2.16:", when="@0.4.18:") depends_on("nccl") + with when("+rocm"): + for pkg_dep in rocm_dependencies: + depends_on(f"{pkg_dep}@6:", when="@0.4.28:") + depends_on(pkg_dep) + depends_on("py-nanobind") + with default_args(type="build"): # .bazelversion depends_on("bazel@6.5.0", when="@0.4.28:") @@ -161,6 +184,10 @@ build --local_cpu_resources={make_jobs} "--bazel_startup_options=" "--output_user_root={0}".format(self.wrapped_package_object.buildtmp) ) + if "+rocm" in spec: + args.append("--enable_rocm") + args.append("--rocm_path={0}".format(self.spec["hip"].prefix)) + python(*args) with working_dir(self.wrapped_package_object.tmp_path): args = std_pip_args + ["--prefix=" + self.prefix, "."] |