diff options
Diffstat (limited to 'var')
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index 62ad1860cd..9c7abe41e7 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -158,6 +158,20 @@ class PyTorch(PythonPackage, CudaPackage): depends_on('numactl', when='+numa') depends_on('llvm-openmp', when='%apple-clang +openmp') depends_on('valgrind', when='+valgrind') + with when("+rocm"): + depends_on('hsa-rocr-dev') + depends_on('hip') + depends_on('rccl') + depends_on('rocprim') + depends_on('hipcub') + depends_on('rocthrust') + depends_on('roctracer-dev') + depends_on('rocrand') + depends_on('hipsparse') + depends_on('hipfft') + depends_on('rocfft') + depends_on('rocblas') + depends_on('miopen-hip') # https://github.com/pytorch/pytorch/issues/60332 # depends_on('xnnpack@2021-02-22', when='@1.8:+xnnpack') # depends_on('xnnpack@2020-03-23', when='@1.6:1.7+xnnpack') @@ -332,6 +346,22 @@ class PyTorch(PythonPackage, CudaPackage): env.set('CMAKE_CUDA_FLAGS', '=-Xcompiler={0}'.format(flag)) enable_or_disable('rocm') + if '+rocm' in self.spec: + env.set('HSA_PATH', self.spec['hsa-rocr-dev'].prefix) + env.set('ROCBLAS_PATH', self.spec['rocblas'].prefix) + env.set('ROCFFT_PATH', self.spec['rocfft'].prefix) + env.set('HIPFFT_PATH', self.spec['hipfft'].prefix) + env.set('HIPSPARSE_PATH', self.spec['hipsparse'].prefix) + env.set('THRUST_PATH', self.spec['rocthrust'].prefix.include) + env.set('HIP_PATH', self.spec['hip'].prefix) + env.set('HIPRAND_PATH', self.spec['rocrand'].prefix) + env.set('ROCRAND_PATH', self.spec['rocrand'].prefix) + env.set('MIOPEN_PATH', self.spec['miopen-hip'].prefix) + env.set('RCCL_PATH', self.spec['rccl'].prefix) + env.set('ROCPRIM_PATH', self.spec['rocprim'].prefix) + env.set('HIPCUB_PATH', self.spec['hipcub'].prefix) + env.set('ROCTHRUST_PATH', self.spec['rocthrust'].prefix) + env.set('ROCTRACER_PATH', self.spec['roctracer-dev'].prefix) enable_or_disable('cudnn') if '+cudnn' in self.spec: |