diff options
author | afzpatel <122491982+afzpatel@users.noreply.github.com> | 2024-09-24 18:38:58 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-24 16:38:58 -0600 |
commit | a474034023b177abcc9519dc14fd8da19dc09f39 (patch) | |
tree | a6bbe5ff6db7a13cb9706694d5346294bd47dcd7 /var | |
parent | 022eca1cfe5156c1552e62d08207c61b75924630 (diff) | |
download | spack-a474034023b177abcc9519dc14fd8da19dc09f39.tar.gz spack-a474034023b177abcc9519dc14fd8da19dc09f39.tar.bz2 spack-a474034023b177abcc9519dc14fd8da19dc09f39.tar.xz spack-a474034023b177abcc9519dc14fd8da19dc09f39.zip |
py-jaxlib: add external ROCm support (#46467)
* add external ROCm support for py-jaxlib
* fix style
* remove fork releases
Diffstat (limited to 'var')
-rw-r--r-- | var/spack/repos/builtin/packages/py-jaxlib/package.py | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index fcd624cdd2..09eb522c56 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -7,8 +7,25 @@ import tempfile from spack.package import * - -class PyJaxlib(PythonPackage, CudaPackage): +rocm_dependencies = [ + "hsa-rocr-dev", + "hip", + "rccl", + "rocprim", + "hipcub", + "rocthrust", + "roctracer-dev", + "rocrand", + "hipsparse", + "hipfft", + "rocfft", + "rocblas", + "miopen-hip", + "rocminfo", +] + + +class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): """XLA library for Jax""" homepage = "https://github.com/google/jax" @@ -62,6 +79,12 @@ class PyJaxlib(PythonPackage, CudaPackage): depends_on("nccl@2.16:", when="@0.4.18:") depends_on("nccl") + with when("+rocm"): + for pkg_dep in rocm_dependencies: + depends_on(f"{pkg_dep}@6:", when="@0.4.28:") + depends_on(pkg_dep) + depends_on("py-nanobind") + with default_args(type="build"): # .bazelversion depends_on("bazel@6.5.0", when="@0.4.28:") @@ -161,6 +184,10 @@ build --local_cpu_resources={make_jobs} "--bazel_startup_options=" "--output_user_root={0}".format(self.wrapped_package_object.buildtmp) ) + if "+rocm" in spec: + args.append("--enable_rocm") + args.append("--rocm_path={0}".format(self.spec["hip"].prefix)) + python(*args) with working_dir(self.wrapped_package_object.tmp_path): args = std_pip_args + ["--prefix=" + self.prefix, "."] |