From 8e5a04098505d58d8b2997487a1893e7370926e2 Mon Sep 17 00:00:00 2001 From: afzpatel <122491982+afzpatel@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:43:53 -0500 Subject: ucc: add ROCm and rccl support (#46580) --- var/spack/repos/builtin/packages/ucc/package.py | 30 ++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/var/spack/repos/builtin/packages/ucc/package.py b/var/spack/repos/builtin/packages/ucc/package.py index fa612b3722..c1427a62a3 100644 --- a/var/spack/repos/builtin/packages/ucc/package.py +++ b/var/spack/repos/builtin/packages/ucc/package.py @@ -5,7 +5,7 @@ from spack.package import * -class Ucc(AutotoolsPackage, CudaPackage): +class Ucc(AutotoolsPackage, CudaPackage, ROCmPackage): """UCC is a collective communication operations API and library that is flexible, complete, and feature-rich for current and emerging programming models and runtimes.""" @@ -23,8 +23,7 @@ class Ucc(AutotoolsPackage, CudaPackage): variant("cuda", default=False, description="Enable CUDA TL") variant("nccl", default=False, description="Enable NCCL TL", when="+cuda") - # RCCL build not tested - # variant("rccl", default=False, description="Enable RCCL TL") + variant("rccl", default=False, description="Enable RCCL TL", when="+rocm") # https://github.com/openucx/ucc/pull/847 patch( @@ -40,7 +39,7 @@ class Ucc(AutotoolsPackage, CudaPackage): depends_on("ucx") depends_on("nccl", when="+nccl") - # depends_on("rccl", when="+rccl") + depends_on("rccl", when="+rccl") with when("+nccl"): for arch in CudaPackage.cuda_arch_values: @@ -55,5 +54,26 @@ class Ucc(AutotoolsPackage, CudaPackage): args = [] args.extend(self.with_or_without("cuda", activation_value="prefix")) args.extend(self.with_or_without("nccl", activation_value="prefix")) - # args.extend(self.with_or_without("rccl", activation_value="prefix")) + if self.spec.satisfies("+rocm"): + cppflags = " ".join( + "-I" + include_dir + for include_dir in ( + self.spec["hip"].prefix.include, + self.spec["hip"].prefix.include.hip, + self.spec["hsa-rocr-dev"].prefix.include.hsa, + ) + ) + ldflags = " ".join( + "-L" + library_dir + for library_dir in ( + self.spec["hip"].prefix.lib, + self.spec["hsa-rocr-dev"].prefix.lib, + ) + ) + args.extend(["CPPFLAGS=" + cppflags, "LDFLAGS=" + ldflags]) + args.append("--with-rocm=" + self.spec["hip"].prefix) + args.append("--with-ucx=" + self.spec["ucx"].prefix) + args.extend(self.with_or_without("rccl", activation_value="prefix")) + else: + args.append("--without-rocm") return args -- cgit v1.2.3-70-g09d2