summaryrefslogtreecommitdiff
path: root/var/spack/repos/builtin/packages/py-torch-nvidia-apex/1499.patch
blob: 058d60986b162aa4c0e192c6bcf3c4fb5ce0a80a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
diff --git a/setup.py b/setup.py
index 063b42d..7388297 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@ if not torch.cuda.is_available():
           'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
           'If you wish to cross-compile for a single specific architecture,\n'
           'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and cpp_extension.CUDA_HOME is not None:
         _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
         if int(bare_metal_major) == 11:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"