diff options
Diffstat (limited to 'var')
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index e9de9d257b..0ba0736b9d 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -195,6 +195,7 @@ class PyTorch(PythonPackage, CudaPackage): depends_on("rocfft") depends_on("rocblas") depends_on("miopen-hip") + depends_on("rocminfo") # https://github.com/pytorch/pytorch/issues/60332 # depends_on('xnnpack@2022-02-16', when='@1.12:+xnnpack') # depends_on('xnnpack@2021-06-21', when='@1.10:1.11+xnnpack') @@ -427,7 +428,6 @@ class PyTorch(PythonPackage, CudaPackage): env.set("ROCFFT_PATH", self.spec["rocfft"].prefix) env.set("HIPFFT_PATH", self.spec["hipfft"].prefix) env.set("HIPSPARSE_PATH", self.spec["hipsparse"].prefix) - env.set("THRUST_PATH", self.spec["rocthrust"].prefix.include) env.set("HIP_PATH", self.spec["hip"].prefix) env.set("HIPRAND_PATH", self.spec["rocrand"].prefix) env.set("ROCRAND_PATH", self.spec["rocrand"].prefix) @@ -437,6 +437,8 @@ class PyTorch(PythonPackage, CudaPackage): env.set("HIPCUB_PATH", self.spec["hipcub"].prefix) env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix) env.set("ROCTRACER_PATH", self.spec["roctracer-dev"].prefix) + if self.spec.satisfies("^hip@5.2.0:"): + env.set("CMAKE_MODULE_PATH", self.spec["hip"].prefix.lib.cmake.hip) enable_or_disable("cudnn") if "+cudnn" in self.spec: |