summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index 62ad1860cd..9c7abe41e7 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -158,6 +158,20 @@ class PyTorch(PythonPackage, CudaPackage):
depends_on('numactl', when='+numa')
depends_on('llvm-openmp', when='%apple-clang +openmp')
depends_on('valgrind', when='+valgrind')
+ with when("+rocm"):
+ depends_on('hsa-rocr-dev')
+ depends_on('hip')
+ depends_on('rccl')
+ depends_on('rocprim')
+ depends_on('hipcub')
+ depends_on('rocthrust')
+ depends_on('roctracer-dev')
+ depends_on('rocrand')
+ depends_on('hipsparse')
+ depends_on('hipfft')
+ depends_on('rocfft')
+ depends_on('rocblas')
+ depends_on('miopen-hip')
# https://github.com/pytorch/pytorch/issues/60332
# depends_on('xnnpack@2021-02-22', when='@1.8:+xnnpack')
# depends_on('xnnpack@2020-03-23', when='@1.6:1.7+xnnpack')
@@ -332,6 +346,22 @@ class PyTorch(PythonPackage, CudaPackage):
env.set('CMAKE_CUDA_FLAGS', '=-Xcompiler={0}'.format(flag))
enable_or_disable('rocm')
+ if '+rocm' in self.spec:
+ env.set('HSA_PATH', self.spec['hsa-rocr-dev'].prefix)
+ env.set('ROCBLAS_PATH', self.spec['rocblas'].prefix)
+ env.set('ROCFFT_PATH', self.spec['rocfft'].prefix)
+ env.set('HIPFFT_PATH', self.spec['hipfft'].prefix)
+ env.set('HIPSPARSE_PATH', self.spec['hipsparse'].prefix)
+ env.set('THRUST_PATH', self.spec['rocthrust'].prefix.include)
+ env.set('HIP_PATH', self.spec['hip'].prefix)
+ env.set('HIPRAND_PATH', self.spec['rocrand'].prefix)
+ env.set('ROCRAND_PATH', self.spec['rocrand'].prefix)
+ env.set('MIOPEN_PATH', self.spec['miopen-hip'].prefix)
+ env.set('RCCL_PATH', self.spec['rccl'].prefix)
+ env.set('ROCPRIM_PATH', self.spec['rocprim'].prefix)
+ env.set('HIPCUB_PATH', self.spec['hipcub'].prefix)
+ env.set('ROCTHRUST_PATH', self.spec['rocthrust'].prefix)
+ env.set('ROCTRACER_PATH', self.spec['roctracer-dev'].prefix)
enable_or_disable('cudnn')
if '+cudnn' in self.spec: