summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Van Essen <vanessen1@llnl.gov>2022-04-27 08:48:06 -0700
committerGitHub <noreply@github.com>2022-04-27 10:48:06 -0500
commit06e72498505c80cdf792f6cee049f117bc8cf5b6 (patch)
tree45e5ba2abbf128f2afd21e563e8e62506a5b1111
parentc4ad003af2c1502ae3b3e11864411034fbbef28c (diff)
downloadspack-06e72498505c80cdf792f6cee049f117bc8cf5b6.tar.gz
spack-06e72498505c80cdf792f6cee049f117bc8cf5b6.tar.bz2
spack-06e72498505c80cdf792f6cee049f117bc8cf5b6.tar.xz
spack-06e72498505c80cdf792f6cee049f117bc8cf5b6.zip
Allow PyTorch to forward gcc-toolchain cxxcflag to CUDA toolchains (#30318)
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index b7d6fa0c0e..dbab39daf2 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -301,6 +301,10 @@ class PyTorch(PythonPackage, CudaPackage):
in
self.spec.variants['cuda_arch'].value)
env.set('TORCH_CUDA_ARCH_LIST', torch_cuda_arch)
+ if self.spec.satisfies('%clang'):
+ for flag in self.spec.compiler_flags['cxxflags']:
+ if 'gcc-toolchain' in flag:
+ env.set('CMAKE_CUDA_FLAGS', '=-Xcompiler={0}'.format(flag))
enable_or_disable('rocm')