diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/spack/spack/build_systems/cuda.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/lib/spack/spack/build_systems/cuda.py b/lib/spack/spack/build_systems/cuda.py index 20f7ede139..9320a137a5 100644 --- a/lib/spack/spack/build_systems/cuda.py +++ b/lib/spack/spack/build_systems/cuda.py @@ -3,6 +3,9 @@ # # SPDX-License-Identifier: (Apache-2.0 OR MIT) +import re +from typing import Iterable, List + import spack.variant from spack.directives import conflicts, depends_on, variant from spack.multimethod import when @@ -44,6 +47,7 @@ class CudaPackage(PackageBase): "87", "89", "90", + "90a", ) # FIXME: keep cuda and cuda_arch separate to make usage easier until @@ -70,6 +74,27 @@ class CudaPackage(PackageBase): for s in arch_list ] + @staticmethod + def compute_capabilities(arch_list: Iterable[str]) -> List[str]: + """Adds a decimal place to each CUDA arch. + + >>> compute_capabilities(['90', '90a']) + ['9.0', '9.0a'] + + Args: + arch_list: A list of integer strings, optionally followed by a suffix. + + Returns: + A list of float strings, optionally followed by a suffix + """ + pattern = re.compile(r"(\d+)") + capabilities = [] + for arch in arch_list: + _, number, letter = re.split(pattern, arch) + number = "{0:.1f}".format(float(number) / 10.0) + capabilities.append(number + letter) + return capabilities + depends_on("cuda", when="+cuda") # CUDA version vs Architecture |