summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
authorAdam J. Stewart <ajstewart426@gmail.com>2022-09-29 04:01:32 -0500
committerGitHub <noreply@github.com>2022-09-29 11:01:32 +0200
commitbc039524dab28e35e2fe5fbe442afd05efaa7cb1 (patch)
tree05848e7b550fc6dff336ca13b354d6b299051737 /var
parent77afad229c5b8a3f3a6847a2b8ca06551e67cf53 (diff)
downloadspack-bc039524dab28e35e2fe5fbe442afd05efaa7cb1.tar.gz
spack-bc039524dab28e35e2fe5fbe442afd05efaa7cb1.tar.bz2
spack-bc039524dab28e35e2fe5fbe442afd05efaa7cb1.tar.xz
spack-bc039524dab28e35e2fe5fbe442afd05efaa7cb1.zip
py-torch: fix +rocm+nccl build (#32771)
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index 0ba0736b9d..6b4b82dacb 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -10,7 +10,7 @@ from spack.operating_systems.mac_os import macos_version
from spack.package import *
-class PyTorch(PythonPackage, CudaPackage):
+class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
"""Tensors and Dynamic neural networks in Python
with strong GPU acceleration."""
@@ -100,6 +100,7 @@ class PyTorch(PythonPackage, CudaPackage):
)
conflicts("+cuda+rocm")
+ conflicts("+tensorpipe", when="+rocm", msg="TensorPipe doesn't yet support ROCm")
conflicts("+breakpad", when="target=ppc64:")
conflicts("+breakpad", when="target=ppc64le:")
@@ -177,14 +178,14 @@ class PyTorch(PythonPackage, CudaPackage):
depends_on("cudnn@7:", when="@1.6:+cudnn")
depends_on("magma+cuda", when="+magma+cuda")
depends_on("magma+rocm", when="+magma+rocm")
- depends_on("nccl", when="+nccl")
+ depends_on("nccl", when="+nccl+cuda")
depends_on("numactl", when="+numa")
depends_on("llvm-openmp", when="%apple-clang +openmp")
depends_on("valgrind", when="+valgrind")
with when("+rocm"):
depends_on("hsa-rocr-dev")
depends_on("hip")
- depends_on("rccl")
+ depends_on("rccl", when="+nccl")
depends_on("rocprim")
depends_on("hipcub")
depends_on("rocthrust")
@@ -423,6 +424,7 @@ class PyTorch(PythonPackage, CudaPackage):
enable_or_disable("rocm")
if "+rocm" in self.spec:
+ env.set("PYTORCH_ROCM_ARCH", ";".join(self.spec.variants["amdgpu_target"].value))
env.set("HSA_PATH", self.spec["hsa-rocr-dev"].prefix)
env.set("ROCBLAS_PATH", self.spec["rocblas"].prefix)
env.set("ROCFFT_PATH", self.spec["rocfft"].prefix)
@@ -432,7 +434,8 @@ class PyTorch(PythonPackage, CudaPackage):
env.set("HIPRAND_PATH", self.spec["rocrand"].prefix)
env.set("ROCRAND_PATH", self.spec["rocrand"].prefix)
env.set("MIOPEN_PATH", self.spec["miopen-hip"].prefix)
- env.set("RCCL_PATH", self.spec["rccl"].prefix)
+ if "+nccl" in self.spec:
+ env.set("RCCL_PATH", self.spec["rccl"].prefix)
env.set("ROCPRIM_PATH", self.spec["rocprim"].prefix)
env.set("HIPCUB_PATH", self.spec["hipcub"].prefix)
env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix)
@@ -454,7 +457,7 @@ class PyTorch(PythonPackage, CudaPackage):
enable_or_disable("breakpad")
enable_or_disable("nccl")
- if "+nccl" in self.spec:
+ if "+cuda+nccl" in self.spec:
env.set("NCCL_LIB_DIR", self.spec["nccl"].libs.directories[0])
env.set("NCCL_INCLUDE_DIR", self.spec["nccl"].prefix.include)