summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRobert Underwood <robertu94@users.noreply.github.com>2023-04-21 10:52:47 -0400
committerGitHub <noreply@github.com>2023-04-21 09:52:47 -0500
commitff689be250e6dbb55657b4660e14e8f492057f83 (patch)
tree9ea904e227456092b68b68be79d5de5f9a198d50
parentaaac8b0545657aabe316d2d9036538931fd55ce7 (diff)
downloadspack-ff689be250e6dbb55657b4660e14e8f492057f83.tar.gz
spack-ff689be250e6dbb55657b4660e14e8f492057f83.tar.bz2
spack-ff689be250e6dbb55657b4660e14e8f492057f83.tar.xz
spack-ff689be250e6dbb55657b4660e14e8f492057f83.zip
py-cupy allow customizing architecture and threads (#37072)
* py-cupy allow customizing architecture and threads * Update var/spack/repos/builtin/packages/py-cupy/package.py Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com> * add missing self --------- Co-authored-by: Robert Underwood <runderwood@anl.gov> Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com>
-rw-r--r--var/spack/repos/builtin/packages/py-cupy/package.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/var/spack/repos/builtin/packages/py-cupy/package.py b/var/spack/repos/builtin/packages/py-cupy/package.py
index e1cf3a4a8f..f0f40c53b6 100644
--- a/var/spack/repos/builtin/packages/py-cupy/package.py
+++ b/var/spack/repos/builtin/packages/py-cupy/package.py
@@ -6,7 +6,7 @@
from spack.package import *
-class PyCupy(PythonPackage):
+class PyCupy(PythonPackage, CudaPackage):
"""CuPy is an open-source array library accelerated with
NVIDIA CUDA. CuPy provides GPU accelerated computing with
Python. CuPy uses CUDA-related libraries including cuBLAS,
@@ -32,3 +32,12 @@ class PyCupy(PythonPackage):
depends_on("nccl")
depends_on("cudnn")
depends_on("cutensor")
+
+ conflicts("~cuda")
+
+ def setup_build_environment(self, env):
+ env.set("CUPY_NUM_BUILD_JOBS", make_jobs)
+ if not self.spec.satisfies("cuda_arch=none"):
+ cuda_arch = self.spec.variants["cuda_arch"].value
+ arch_str = ";".join("arch=compute_{0},code=sm_{0}".format(i) for i in cuda_arch)
+ env.set("CUPY_NVCC_GENERATE_CODE", arch_str)