summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
authorAdam J. Stewart <ajstewart426@gmail.com>2022-10-10 09:04:35 -0500
committerGitHub <noreply@github.com>2022-10-10 16:04:35 +0200
commit8d9a035d122b87c362336fafd1a31986878310c8 (patch)
tree2188f6530b92cc1d71254dd01e02c09c56489102 /var
parentcbc867a24c51b7137768c1c20685c8c9e6cef1fa (diff)
downloadspack-8d9a035d122b87c362336fafd1a31986878310c8.tar.gz
spack-8d9a035d122b87c362336fafd1a31986878310c8.tar.bz2
spack-8d9a035d122b87c362336fafd1a31986878310c8.tar.xz
spack-8d9a035d122b87c362336fafd1a31986878310c8.zip
py-torch-nvidia-apex: fix +cuda build (#33070)
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-torch-nvidia-apex/1499.patch13
-rw-r--r--var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py5
2 files changed, 17 insertions, 1 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/1499.patch b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/1499.patch
new file mode 100644
index 0000000000..058d60986b
--- /dev/null
+++ b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/1499.patch
@@ -0,0 +1,13 @@
+diff --git a/setup.py b/setup.py
+index 063b42d..7388297 100644
+--- a/setup.py
++++ b/setup.py
+@@ -31,7 +31,7 @@ if not torch.cuda.is_available():
+ 'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
+ 'If you wish to cross-compile for a single specific architecture,\n'
+ 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
+- if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
++ if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and cpp_extension.CUDA_HOME is not None:
+ _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+ if int(bare_metal_major) == 11:
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
diff --git a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py
index a0ac3fe60f..b65038b33d 100644
--- a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py
+++ b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py
@@ -26,7 +26,8 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage):
variant("cuda", default=True, description="Build with CUDA")
# https://github.com/NVIDIA/apex/issues/1498
- conflicts("~cuda")
+ # https://github.com/NVIDIA/apex/pull/1499
+ patch("1499.patch", when="@2020-10-19")
def setup_build_environment(self, env):
if "+cuda" in self.spec:
@@ -37,6 +38,8 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage):
for i in self.spec.variants["cuda_arch"].value
)
env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch)
+ else:
+ env.unset("CUDA_HOME")
def global_options(self, spec, prefix):
args = []