From 8d9a035d122b87c362336fafd1a31986878310c8 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 10 Oct 2022 09:04:35 -0500 Subject: py-torch-nvidia-apex: fix +cuda build (#33070) --- .../repos/builtin/packages/py-torch-nvidia-apex/1499.patch | 13 +++++++++++++ .../repos/builtin/packages/py-torch-nvidia-apex/package.py | 5 ++++- 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 var/spack/repos/builtin/packages/py-torch-nvidia-apex/1499.patch diff --git a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/1499.patch b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/1499.patch new file mode 100644 index 0000000000..058d60986b --- /dev/null +++ b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/1499.patch @@ -0,0 +1,13 @@ +diff --git a/setup.py b/setup.py +index 063b42d..7388297 100644 +--- a/setup.py ++++ b/setup.py +@@ -31,7 +31,7 @@ if not torch.cuda.is_available(): + 'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n' + 'If you wish to cross-compile for a single specific architecture,\n' + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') +- if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: ++ if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and cpp_extension.CUDA_HOME is not None: + _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) + if int(bare_metal_major) == 11: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" diff --git a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py index a0ac3fe60f..b65038b33d 100644 --- a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py +++ b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py @@ -26,7 +26,8 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage): variant("cuda", default=True, description="Build with CUDA") # https://github.com/NVIDIA/apex/issues/1498 - conflicts("~cuda") + # https://github.com/NVIDIA/apex/pull/1499 + patch("1499.patch", when="@2020-10-19") def setup_build_environment(self, env): if "+cuda" in self.spec: @@ -37,6 +38,8 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage): for i in self.spec.variants["cuda_arch"].value ) env.set("TORCH_CUDA_ARCH_LIST", torch_cuda_arch) + else: + env.unset("CUDA_HOME") def global_options(self, spec, prefix): args = [] -- cgit v1.2.3-70-g09d2