summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--var/spack/repos/builtin/packages/py-tensorflow/package.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/var/spack/repos/builtin/packages/py-tensorflow/package.py b/var/spack/repos/builtin/packages/py-tensorflow/package.py
index d6380bee01..5c6b3dc1ac 100644
--- a/var/spack/repos/builtin/packages/py-tensorflow/package.py
+++ b/var/spack/repos/builtin/packages/py-tensorflow/package.py
@@ -291,7 +291,7 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage):
depends_on("rocblas")
depends_on("rocfft")
depends_on("hipfft")
- depends_on("rccl")
+ depends_on("rccl", when="+nccl")
depends_on("hipsparse")
depends_on("hipcub")
depends_on("rocsolver")
@@ -348,7 +348,7 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage):
depends_on("cudnn@:6", when="@0.5:0.6 +cuda")
depends_on("cudnn@:7", when="@0.7:2.2 +cuda")
# depends_on('tensorrt', when='+tensorrt')
- depends_on("nccl", when="+nccl")
+ depends_on("nccl", when="+nccl+cuda")
depends_on("mpi", when="+mpi")
# depends_on('android-ndk@10:18', when='+android')
# depends_on('android-sdk', when='+android')
@@ -418,7 +418,7 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage):
msg="Currently TensorRT is only supported on Linux platform",
)
conflicts("+nccl", when="@:1.7")
- conflicts("+nccl", when="~cuda")
+ conflicts("+nccl", when="~cuda~rocm")
conflicts(
"+nccl", when="platform=darwin", msg="Currently NCCL is only supported on Linux platform"
)