summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py13
1 files changed, 8 insertions, 5 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index 0ba0736b9d..6b4b82dacb 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -10,7 +10,7 @@ from spack.operating_systems.mac_os import macos_version
from spack.package import *
-class PyTorch(PythonPackage, CudaPackage):
+class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
"""Tensors and Dynamic neural networks in Python
with strong GPU acceleration."""
@@ -100,6 +100,7 @@ class PyTorch(PythonPackage, CudaPackage):
)
conflicts("+cuda+rocm")
+ conflicts("+tensorpipe", when="+rocm", msg="TensorPipe doesn't yet support ROCm")
conflicts("+breakpad", when="target=ppc64:")
conflicts("+breakpad", when="target=ppc64le:")
@@ -177,14 +178,14 @@ class PyTorch(PythonPackage, CudaPackage):
depends_on("cudnn@7:", when="@1.6:+cudnn")
depends_on("magma+cuda", when="+magma+cuda")
depends_on("magma+rocm", when="+magma+rocm")
- depends_on("nccl", when="+nccl")
+ depends_on("nccl", when="+nccl+cuda")
depends_on("numactl", when="+numa")
depends_on("llvm-openmp", when="%apple-clang +openmp")
depends_on("valgrind", when="+valgrind")
with when("+rocm"):
depends_on("hsa-rocr-dev")
depends_on("hip")
- depends_on("rccl")
+ depends_on("rccl", when="+nccl")
depends_on("rocprim")
depends_on("hipcub")
depends_on("rocthrust")
@@ -423,6 +424,7 @@ class PyTorch(PythonPackage, CudaPackage):
enable_or_disable("rocm")
if "+rocm" in self.spec:
+ env.set("PYTORCH_ROCM_ARCH", ";".join(self.spec.variants["amdgpu_target"].value))
env.set("HSA_PATH", self.spec["hsa-rocr-dev"].prefix)
env.set("ROCBLAS_PATH", self.spec["rocblas"].prefix)
env.set("ROCFFT_PATH", self.spec["rocfft"].prefix)
@@ -432,7 +434,8 @@ class PyTorch(PythonPackage, CudaPackage):
env.set("HIPRAND_PATH", self.spec["rocrand"].prefix)
env.set("ROCRAND_PATH", self.spec["rocrand"].prefix)
env.set("MIOPEN_PATH", self.spec["miopen-hip"].prefix)
- env.set("RCCL_PATH", self.spec["rccl"].prefix)
+ if "+nccl" in self.spec:
+ env.set("RCCL_PATH", self.spec["rccl"].prefix)
env.set("ROCPRIM_PATH", self.spec["rocprim"].prefix)
env.set("HIPCUB_PATH", self.spec["hipcub"].prefix)
env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix)
@@ -454,7 +457,7 @@ class PyTorch(PythonPackage, CudaPackage):
enable_or_disable("breakpad")
enable_or_disable("nccl")
- if "+nccl" in self.spec:
+ if "+cuda+nccl" in self.spec:
env.set("NCCL_LIB_DIR", self.spec["nccl"].libs.directories[0])
env.set("NCCL_INCLUDE_DIR", self.spec["nccl"].prefix.include)