diff options
-rw-r--r-- | var/spack/repos/builtin/packages/py-tensorflow/package.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/var/spack/repos/builtin/packages/py-tensorflow/package.py b/var/spack/repos/builtin/packages/py-tensorflow/package.py index d6380bee01..5c6b3dc1ac 100644 --- a/var/spack/repos/builtin/packages/py-tensorflow/package.py +++ b/var/spack/repos/builtin/packages/py-tensorflow/package.py @@ -291,7 +291,7 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage): depends_on("rocblas") depends_on("rocfft") depends_on("hipfft") - depends_on("rccl") + depends_on("rccl", when="+nccl") depends_on("hipsparse") depends_on("hipcub") depends_on("rocsolver") @@ -348,7 +348,7 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage): depends_on("cudnn@:6", when="@0.5:0.6 +cuda") depends_on("cudnn@:7", when="@0.7:2.2 +cuda") # depends_on('tensorrt', when='+tensorrt') - depends_on("nccl", when="+nccl") + depends_on("nccl", when="+nccl+cuda") depends_on("mpi", when="+mpi") # depends_on('android-ndk@10:18', when='+android') # depends_on('android-sdk', when='+android') @@ -418,7 +418,7 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage): msg="Currently TensorRT is only supported on Linux platform", ) conflicts("+nccl", when="@:1.7") - conflicts("+nccl", when="~cuda") + conflicts("+nccl", when="~cuda~rocm") conflicts( "+nccl", when="platform=darwin", msg="Currently NCCL is only supported on Linux platform" ) |