diff options
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 21 |
1 files changed, 16 insertions, 5 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index fbd628637a..2b609047b5 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -6,6 +6,7 @@ from spack import * +# TODO: try switching to CMakePackage for more control over build class PyTorch(PythonPackage): """Tensors and Dynamic neural networks in Python with strong GPU acceleration.""" @@ -102,12 +103,16 @@ class PyTorch(PythonPackage): depends_on('cmake@3.5:', type='build') depends_on('python@2.7:2.8,3.5:', type=('build', 'run')) depends_on('py-setuptools', type='build') - depends_on('py-numpy', type=('run', 'build')) + depends_on('py-numpy', type=('build', 'run')) depends_on('py-future', when='@1.1: ^python@:2', type='build') - depends_on('py-pyyaml', type=('run', 'build')) - depends_on('py-typing', when='@0.4: ^python@:3.4', type=('run', 'build')) + depends_on('py-pyyaml', type=('build', 'run')) + depends_on('py-typing', when='@0.4: ^python@:3.4', type=('build', 'run')) + depends_on('py-pybind11', when='@0.4:', type=('build', 'run')) depends_on('blas') depends_on('lapack') + depends_on('protobuf', when='@0.4:') + depends_on('eigen', when='@0.4:') + # TODO: replace all third_party packages with Spack packages # Optional dependencies depends_on('cuda@7.5:', when='+cuda', type=('build', 'link', 'run')) @@ -175,6 +180,14 @@ class PyTorch(PythonPackage): env.set('MAX_JOBS', make_jobs) + # Don't use vendored third-party libraries + env.set('BUILD_CUSTOM_PROTOBUF', 'OFF') + env.set('USE_PYTORCH_QNNPACK', 'OFF') + env.set('USE_SYSTEM_EIGEN_INSTALL', 'ON') + env.set('pybind11_DIR', self.spec['py-pybind11'].prefix) + env.set('pybind11_INCLUDE_DIR', + self.spec['py-pybind11'].prefix.include) + enable_or_disable('cuda') if '+cuda' in self.spec: env.set('CUDA_HOME', self.spec['cuda'].prefix) @@ -200,8 +213,6 @@ class PyTorch(PythonPackage): enable_or_disable('nnpack') enable_or_disable('qnnpack') - # Never use vendored copy of QNNPACK - env.set('USE_PYTORCH_QNNPACK=OFF') enable_or_disable('distributed') enable_or_disable('nccl') |