diff options
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index 0ba0736b9d..6b4b82dacb 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -10,7 +10,7 @@ from spack.operating_systems.mac_os import macos_version from spack.package import * -class PyTorch(PythonPackage, CudaPackage): +class PyTorch(PythonPackage, CudaPackage, ROCmPackage): """Tensors and Dynamic neural networks in Python with strong GPU acceleration.""" @@ -100,6 +100,7 @@ class PyTorch(PythonPackage, CudaPackage): ) conflicts("+cuda+rocm") + conflicts("+tensorpipe", when="+rocm", msg="TensorPipe doesn't yet support ROCm") conflicts("+breakpad", when="target=ppc64:") conflicts("+breakpad", when="target=ppc64le:") @@ -177,14 +178,14 @@ class PyTorch(PythonPackage, CudaPackage): depends_on("cudnn@7:", when="@1.6:+cudnn") depends_on("magma+cuda", when="+magma+cuda") depends_on("magma+rocm", when="+magma+rocm") - depends_on("nccl", when="+nccl") + depends_on("nccl", when="+nccl+cuda") depends_on("numactl", when="+numa") depends_on("llvm-openmp", when="%apple-clang +openmp") depends_on("valgrind", when="+valgrind") with when("+rocm"): depends_on("hsa-rocr-dev") depends_on("hip") - depends_on("rccl") + depends_on("rccl", when="+nccl") depends_on("rocprim") depends_on("hipcub") depends_on("rocthrust") @@ -423,6 +424,7 @@ class PyTorch(PythonPackage, CudaPackage): enable_or_disable("rocm") if "+rocm" in self.spec: + env.set("PYTORCH_ROCM_ARCH", ";".join(self.spec.variants["amdgpu_target"].value)) env.set("HSA_PATH", self.spec["hsa-rocr-dev"].prefix) env.set("ROCBLAS_PATH", self.spec["rocblas"].prefix) env.set("ROCFFT_PATH", self.spec["rocfft"].prefix) @@ -432,7 +434,8 @@ class PyTorch(PythonPackage, CudaPackage): env.set("HIPRAND_PATH", self.spec["rocrand"].prefix) env.set("ROCRAND_PATH", self.spec["rocrand"].prefix) env.set("MIOPEN_PATH", self.spec["miopen-hip"].prefix) - env.set("RCCL_PATH", self.spec["rccl"].prefix) + if "+nccl" in self.spec: + env.set("RCCL_PATH", self.spec["rccl"].prefix) env.set("ROCPRIM_PATH", self.spec["rocprim"].prefix) env.set("HIPCUB_PATH", self.spec["hipcub"].prefix) env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix) @@ -454,7 +457,7 @@ class PyTorch(PythonPackage, CudaPackage): enable_or_disable("breakpad") enable_or_disable("nccl") - if "+nccl" in self.spec: + if "+cuda+nccl" in self.spec: env.set("NCCL_LIB_DIR", self.spec["nccl"].libs.directories[0]) env.set("NCCL_INCLUDE_DIR", self.spec["nccl"].prefix.include) |