author     Baptiste Jonglez <30461003+jonglezb@users.noreply.github.com>  2021-01-26 14:58:41 +0100
committer  GitHub <noreply@github.com>  2021-01-26 07:58:41 -0600
commit     79afe20bb0916b377c23fc8fc0dd592c5357e72b (patch)
tree       2213ad413099ab7554dc4607a7a8a29cc6c26866 /var
parent     b45a31aefe065c0cb6761034384af2aa1d2c41f3 (diff)
mxnet: Add optional cuda_arch spec support, enable CUDA by default (#21266)
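This switches the package to inherit from CudaPackage, which provides the multi-valued cuda_arch variant, and turns the cuda variant on by default. A rough usage sketch (the architecture value is only an illustration; the spec syntax is standard Spack variant syntax):

    spack install mxnet                 # CUDA enabled, all architectures built
    spack install mxnet cuda_arch=70    # restrict the build to compute capability 7.0
    spack install mxnet ~cuda           # opt out of CUDA support

When cuda_arch is left at its default value of none, no architecture restriction is passed to the build, preserving the previous behaviour.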
Diffstat (limited to 'var')
-rw-r--r--  var/spack/repos/builtin/packages/mxnet/package.py  |  9
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/var/spack/repos/builtin/packages/mxnet/package.py b/var/spack/repos/builtin/packages/mxnet/package.py
index 6f42d78f50..5cc9977ef3 100644
--- a/var/spack/repos/builtin/packages/mxnet/package.py
+++ b/var/spack/repos/builtin/packages/mxnet/package.py
@@ -6,7 +6,7 @@
 from spack import *
-class Mxnet(MakefilePackage):
+class Mxnet(MakefilePackage, CudaPackage):
     """MXNet is a deep learning framework
     designed for both efficiency and flexibility."""
@@ -18,7 +18,7 @@ class Mxnet(MakefilePackage):
     version('1.6.0', sha256='01eb06069c90f33469c7354946261b0a94824bbaf819fd5d5a7318e8ee596def')
     version('1.3.0', sha256='c00d6fbb2947144ce36c835308e603f002c1eb90a9f4c5a62f4d398154eed4d2')
-    variant('cuda', default=False, description='Enable CUDA support')
+    variant('cuda', default=True, description='Enable CUDA support')
     variant('opencv', default=True, description='Enable OpenCV support')
     variant('openmp', default=False, description='Enable OpenMP support')
     variant('profiler', default=False, description='Enable Profiler (for verification and debug only).')
@@ -111,6 +111,11 @@ class Mxnet(MakefilePackage):
             args.extend(['USE_CUDA_PATH=%s' % spec['cuda'].prefix,
                          'CUDNN_PATH=%s' % spec['cudnn'].prefix,
                          'CUB_INCLUDE=%s' % spec['cub'].prefix.include])
+            # By default, all cuda architectures are built. Restrict only
+            # if a specific list of architectures is specified in cuda_arch.
+            if 'cuda_arch=none' not in spec:
+                cuda_flags = self.cuda_flags(self.spec.variants['cuda_arch'].value)
+                args.append('CUDA_ARCH={0}'.format(' '.join(cuda_flags)))
         make(*args)
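For reference, cuda_flags is inherited from CudaPackage and expands the requested compute capabilities into nvcc code-generation flags, which the MXNet makefile receives through CUDA_ARCH. A minimal sketch of the expansion this hunk relies on; the exact flag text is an assumption here, shown only to illustrate the mechanism:

    # hypothetical stand-in for CudaPackage.cuda_flags
    def cuda_flags(arch_list):
        # one nvcc code-generation flag per requested compute capability
        return ['-gencode arch=compute_{0},code=sm_{0}'.format(arch)
                for arch in arch_list]

    # cuda_arch=70 would then produce roughly:
    #   CUDA_ARCH=-gencode arch=compute_70,code=sm_70

With cuda_arch at its default value of none, the new branch is skipped, no CUDA_ARCH argument is appended, and the makefile keeps building every supported architecture, which is what the comment in the hunk describes.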