diff options
author | Adam J. Stewart <ajstewart426@gmail.com> | 2022-09-29 04:01:32 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-29 11:01:32 +0200 |
commit | bc039524dab28e35e2fe5fbe442afd05efaa7cb1 (patch) | |
tree | 05848e7b550fc6dff336ca13b354d6b299051737 /var | |
parent | 77afad229c5b8a3f3a6847a2b8ca06551e67cf53 (diff) | |
download | spack-bc039524dab28e35e2fe5fbe442afd05efaa7cb1.tar.gz spack-bc039524dab28e35e2fe5fbe442afd05efaa7cb1.tar.bz2 spack-bc039524dab28e35e2fe5fbe442afd05efaa7cb1.tar.xz spack-bc039524dab28e35e2fe5fbe442afd05efaa7cb1.zip |
py-torch: fix +rocm+nccl build (#32771)
Diffstat (limited to 'var')
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index 0ba0736b9d..6b4b82dacb 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -10,7 +10,7 @@ from spack.operating_systems.mac_os import macos_version from spack.package import * -class PyTorch(PythonPackage, CudaPackage): +class PyTorch(PythonPackage, CudaPackage, ROCmPackage): """Tensors and Dynamic neural networks in Python with strong GPU acceleration.""" @@ -100,6 +100,7 @@ class PyTorch(PythonPackage, CudaPackage): ) conflicts("+cuda+rocm") + conflicts("+tensorpipe", when="+rocm", msg="TensorPipe doesn't yet support ROCm") conflicts("+breakpad", when="target=ppc64:") conflicts("+breakpad", when="target=ppc64le:") @@ -177,14 +178,14 @@ class PyTorch(PythonPackage, CudaPackage): depends_on("cudnn@7:", when="@1.6:+cudnn") depends_on("magma+cuda", when="+magma+cuda") depends_on("magma+rocm", when="+magma+rocm") - depends_on("nccl", when="+nccl") + depends_on("nccl", when="+nccl+cuda") depends_on("numactl", when="+numa") depends_on("llvm-openmp", when="%apple-clang +openmp") depends_on("valgrind", when="+valgrind") with when("+rocm"): depends_on("hsa-rocr-dev") depends_on("hip") - depends_on("rccl") + depends_on("rccl", when="+nccl") depends_on("rocprim") depends_on("hipcub") depends_on("rocthrust") @@ -423,6 +424,7 @@ class PyTorch(PythonPackage, CudaPackage): enable_or_disable("rocm") if "+rocm" in self.spec: + env.set("PYTORCH_ROCM_ARCH", ";".join(self.spec.variants["amdgpu_target"].value)) env.set("HSA_PATH", self.spec["hsa-rocr-dev"].prefix) env.set("ROCBLAS_PATH", self.spec["rocblas"].prefix) env.set("ROCFFT_PATH", self.spec["rocfft"].prefix) @@ -432,7 +434,8 @@ class PyTorch(PythonPackage, CudaPackage): env.set("HIPRAND_PATH", self.spec["rocrand"].prefix) env.set("ROCRAND_PATH", self.spec["rocrand"].prefix) env.set("MIOPEN_PATH", self.spec["miopen-hip"].prefix) - env.set("RCCL_PATH", self.spec["rccl"].prefix) + if "+nccl" in self.spec: + env.set("RCCL_PATH", self.spec["rccl"].prefix) env.set("ROCPRIM_PATH", self.spec["rocprim"].prefix) env.set("HIPCUB_PATH", self.spec["hipcub"].prefix) env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix) @@ -454,7 +457,7 @@ class PyTorch(PythonPackage, CudaPackage): enable_or_disable("breakpad") enable_or_disable("nccl") - if "+nccl" in self.spec: + if "+cuda+nccl" in self.spec: env.set("NCCL_LIB_DIR", self.spec["nccl"].libs.directories[0]) env.set("NCCL_INCLUDE_DIR", self.spec["nccl"].prefix.include) |