summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py33
1 files changed, 31 insertions, 2 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index dcec15dc68..2a20235bef 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -62,6 +62,7 @@ class PyTorch(PythonPackage, CudaPackage):
version('0.4.0', tag='v0.4.0', submodules=True)
version('0.3.1', tag='v0.3.1', submodules=True)
+ variant('cuda', default=True, description='Build with CUDA')
variant('cudnn', default=True, description='Enables the cuDNN build')
variant('magma', default=False, description='Enables the MAGMA build')
variant('fbgemm', default=False, description='Enables the FBGEMM build')
@@ -100,6 +101,27 @@ class PyTorch(PythonPackage, CudaPackage):
conflicts('+zstd', when='@:1.0')
conflicts('+tbb', when='@:1.1')
+ cuda_arch_conflict = ('This version of Torch/Caffe2 only supports compute '
+ 'capabilities ')
+
+ conflicts('cuda_arch=none', when='+cuda+caffe2',
+ msg='Must specify CUDA compute capabilities of your GPU, see '
+ 'https://developer.nvidia.com/cuda-gpus')
+ conflicts('cuda_arch=52', when='@1.3.0:+cuda+caffe2',
+ msg=cuda_arch_conflict + '>=5.3')
+ conflicts('cuda_arch=50', when='@1.3.0:+cuda+caffe2',
+ msg=cuda_arch_conflict + '>=5.3')
+ conflicts('cuda_arch=35', when='@1.3.0:+cuda+caffe2',
+ msg=cuda_arch_conflict + '>=5.3')
+ conflicts('cuda_arch=32', when='@1.3.0:+cuda+caffe2',
+ msg=cuda_arch_conflict + '>=5.3')
+ conflicts('cuda_arch=30', when='@1.3.0:+cuda+caffe2',
+ msg=cuda_arch_conflict + '>=5.3')
+ conflicts('cuda_arch=30', when='@1.2.0:+cuda+caffe2',
+ msg=cuda_arch_conflict + '>=3.2')
+ conflicts('cuda_arch=20', when='@1.0.0:+cuda+caffe2',
+ msg=cuda_arch_conflict + '>=3.0')
+
# Required dependencies
depends_on('cmake@3.5:', type='build')
# Use Ninja generator to speed up build times
@@ -128,7 +150,10 @@ class PyTorch(PythonPackage, CudaPackage):
# depends_on('fbgemm', when='+fbgemm')
# TODO: add dependency: https://github.com/ROCmSoftwarePlatform/MIOpen
# depends_on('miopen', when='+miopen')
- depends_on('intel-mkl-dnn', when='+mkldnn')
+ # TODO: See if there is a way to use an external mkldnn installation.
+ # Currently, only older versions of py-torch use an external mkldnn
+ # library.
+ depends_on('intel-mkl-dnn', when='@0.4:0.4.1+mkldnn')
# TODO: add dependency: https://github.com/Maratyszcza/NNPACK
# depends_on('nnpack', when='+nnpack')
depends_on('qnnpack', when='+qnnpack')
@@ -197,6 +222,10 @@ class PyTorch(PythonPackage, CudaPackage):
enable_or_disable('cuda')
if '+cuda' in self.spec:
env.set('CUDA_HOME', self.spec['cuda'].prefix)
+ torch_cuda_arch = ';'.join('{0:.1f}'.format(float(i) / 10.0) for i
+ in
+ self.spec.variants['cuda_arch'].value)
+ env.set('TORCH_CUDA_ARCH_LIST', torch_cuda_arch)
enable_or_disable('cudnn')
if '+cudnn' in self.spec:
@@ -213,7 +242,7 @@ class PyTorch(PythonPackage, CudaPackage):
env.set('MIOPEN_LIBRARY', self.spec['miopen'].libs[0])
enable_or_disable('mkldnn')
- if '+mkldnn' in self.spec:
+ if '@0.4:0.4.1+mkldnn' in self.spec:
env.set('MKLDNN_HOME', self.spec['intel-mkl-dnn'].prefix)
enable_or_disable('nnpack')