blob: 058d60986b162aa4c0e192c6bcf3c4fb5ce0a80a (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
|
diff --git a/setup.py b/setup.py
index 063b42d..7388297 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@ if not torch.cuda.is_available():
'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
'If you wish to cross-compile for a single specific architecture,\n'
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
- if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
+ if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and cpp_extension.CUDA_HOME is not None:
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
|