| author | Andrew W Elble <aweits@rit.edu> | 2020-04-14 17:43:30 -0400 |
| --- | --- | --- |
| committer | GitHub <noreply@github.com> | 2020-04-14 16:43:30 -0500 |
| commit | a031bc3166b8a03839d4218296456bc27dd18ce1 (patch) | |
| tree | c7a20d9e301555340344f4f1e0237ea9f8aba5c9 /var | |
| parent | 993491c83c381a1db58dc56684ecc4ba47345af6 (diff) | |
new package: py-torch-nvidia-apex (#16050)
Diffstat (limited to 'var')
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py | 41
1 file changed, 41 insertions, 0 deletions
```diff
diff --git a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py
new file mode 100644
index 0000000000..919a6c41ad
--- /dev/null
+++ b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py
@@ -0,0 +1,41 @@
+# Copyright 2013-2020 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+
+class PyTorchNvidiaApex(PythonPackage, CudaPackage):
+    """A PyTorch Extension: Tools for easy mixed precision and
+    distributed training in Pytorch """
+
+    homepage = "https://github.com/nvidia/apex/"
+    git = "https://github.com/nvidia/apex/"
+
+    phases = ['install']
+
+    version('master', branch='master')
+
+    depends_on('python@3:', type=('build', 'run'))
+    depends_on('py-setuptools', type='build')
+    depends_on('py-torch@0.4:', type=('build', 'run'))
+    depends_on('cuda@9:', when='+cuda')
+
+    variant('cuda', default=True, description='Build with CUDA')
+
+    def setup_build_environment(self, env):
+        if '+cuda' in self.spec:
+            env.set('CUDA_HOME', self.spec['cuda'].prefix)
+            if (self.spec.variants['cuda_arch'].value[0] != 'none'):
+                torch_cuda_arch = ';'.join(
+                    '{0:.1f}'.format(float(i) / 10.0) for i
+                    in
+                    self.spec.variants['cuda_arch'].value)
+                env.set('TORCH_CUDA_ARCH_LIST', torch_cuda_arch)
+
+    def install_args(self, spec, prefix):
+        args = super(PyTorchNvidiaApex, self).install_args(spec, prefix)
+        if spec.satisfies('^py-torch@1.0:'):
+            args.append('--cpp_ext')
+        if '+cuda' in spec:
+            args.append('--cuda_ext')
+        return args
```
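For readers not familiar with Spack's `CudaPackage` machinery, the sketch below is a minimal standalone illustration of the two non-trivial pieces of the new package: the `TORCH_CUDA_ARCH_LIST` string that `setup_build_environment` derives from the `cuda_arch` variant, and the extra flags that `install_args` hands to Apex's `setup.py`. The helper names and sample inputs are invented for illustration and are not part of the commit.

```python
# Standalone sketch (not part of the commit) of the logic in
# setup_build_environment() and install_args() above.
# Helper names and sample inputs are hypothetical.


def torch_cuda_arch_list(cuda_arch_values):
    """Turn Spack cuda_arch values such as ['60', '70'] into the
    'TORCH_CUDA_ARCH_LIST' format PyTorch extensions expect ('6.0;7.0')."""
    return ';'.join('{0:.1f}'.format(float(v) / 10.0)
                    for v in cuda_arch_values)


def apex_setup_py_flags(with_cuda, torch_at_least_1_0):
    """Flags appended to the base 'setup.py install' arguments,
    mirroring install_args() in the new package."""
    flags = []
    if torch_at_least_1_0:
        flags.append('--cpp_ext')   # build Apex's C++ extensions
    if with_cuda:
        flags.append('--cuda_ext')  # build Apex's CUDA extensions
    return flags


if __name__ == '__main__':
    print(torch_cuda_arch_list(['60', '70', '75']))
    # -> 6.0;7.0;7.5
    print(apex_setup_py_flags(with_cuda=True, torch_at_least_1_0=True))
    # -> ['--cpp_ext', '--cuda_ext']
```

In an actual build Spack supplies these values itself; a spec along the lines of `spack install py-torch-nvidia-apex+cuda cuda_arch=70` would end up exporting `TORCH_CUDA_ARCH_LIST=7.0` and passing `--cpp_ext --cuda_ext` to the Apex build.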