diff options
author | Baptiste Jonglez <30461003+jonglezb@users.noreply.github.com> | 2021-01-26 14:58:41 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-01-26 07:58:41 -0600 |
commit | 79afe20bb0916b377c23fc8fc0dd592c5357e72b (patch) | |
tree | 2213ad413099ab7554dc4607a7a8a29cc6c26866 | |
parent | b45a31aefe065c0cb6761034384af2aa1d2c41f3 (diff) | |
download | spack-79afe20bb0916b377c23fc8fc0dd592c5357e72b.tar.gz spack-79afe20bb0916b377c23fc8fc0dd592c5357e72b.tar.bz2 spack-79afe20bb0916b377c23fc8fc0dd592c5357e72b.tar.xz spack-79afe20bb0916b377c23fc8fc0dd592c5357e72b.zip |
mxnet: Add optional cuda_arch spec support, enable CUDA by default (#21266)
-rw-r--r-- | var/spack/repos/builtin/packages/mxnet/package.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/var/spack/repos/builtin/packages/mxnet/package.py b/var/spack/repos/builtin/packages/mxnet/package.py index 6f42d78f50..5cc9977ef3 100644 --- a/var/spack/repos/builtin/packages/mxnet/package.py +++ b/var/spack/repos/builtin/packages/mxnet/package.py @@ -6,7 +6,7 @@ from spack import * -class Mxnet(MakefilePackage): +class Mxnet(MakefilePackage, CudaPackage): """MXNet is a deep learning framework designed for both efficiency and flexibility.""" @@ -18,7 +18,7 @@ class Mxnet(MakefilePackage): version('1.6.0', sha256='01eb06069c90f33469c7354946261b0a94824bbaf819fd5d5a7318e8ee596def') version('1.3.0', sha256='c00d6fbb2947144ce36c835308e603f002c1eb90a9f4c5a62f4d398154eed4d2') - variant('cuda', default=False, description='Enable CUDA support') + variant('cuda', default=True, description='Enable CUDA support') variant('opencv', default=True, description='Enable OpenCV support') variant('openmp', default=False, description='Enable OpenMP support') variant('profiler', default=False, description='Enable Profiler (for verification and debug only).') @@ -111,6 +111,11 @@ class Mxnet(MakefilePackage): args.extend(['USE_CUDA_PATH=%s' % spec['cuda'].prefix, 'CUDNN_PATH=%s' % spec['cudnn'].prefix, 'CUB_INCLUDE=%s' % spec['cub'].prefix.include]) + # By default, all cuda architectures are built. Restrict only + # if a specific list of architectures is specified in cuda_arch. + if 'cuda_arch=none' not in spec: + cuda_flags = self.cuda_flags(self.spec.variants['cuda_arch'].value) + args.append('CUDA_ARCH={0}'.format(' '.join(cuda_flags))) make(*args) |