summaryrefslogtreecommitdiff
path: root/var/spack/repos/builtin/packages/py-torch/package.py
diff options
context:
space:
mode:
Diffstat (limited to 'var/spack/repos/builtin/packages/py-torch/package.py')
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index e9de9d257b..0ba0736b9d 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -195,6 +195,7 @@ class PyTorch(PythonPackage, CudaPackage):
depends_on("rocfft")
depends_on("rocblas")
depends_on("miopen-hip")
+ depends_on("rocminfo")
# https://github.com/pytorch/pytorch/issues/60332
# depends_on('xnnpack@2022-02-16', when='@1.12:+xnnpack')
# depends_on('xnnpack@2021-06-21', when='@1.10:1.11+xnnpack')
@@ -427,7 +428,6 @@ class PyTorch(PythonPackage, CudaPackage):
env.set("ROCFFT_PATH", self.spec["rocfft"].prefix)
env.set("HIPFFT_PATH", self.spec["hipfft"].prefix)
env.set("HIPSPARSE_PATH", self.spec["hipsparse"].prefix)
- env.set("THRUST_PATH", self.spec["rocthrust"].prefix.include)
env.set("HIP_PATH", self.spec["hip"].prefix)
env.set("HIPRAND_PATH", self.spec["rocrand"].prefix)
env.set("ROCRAND_PATH", self.spec["rocrand"].prefix)
@@ -437,6 +437,8 @@ class PyTorch(PythonPackage, CudaPackage):
env.set("HIPCUB_PATH", self.spec["hipcub"].prefix)
env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix)
env.set("ROCTRACER_PATH", self.spec["roctracer-dev"].prefix)
+ if self.spec.satisfies("^hip@5.2.0:"):
+ env.set("CMAKE_MODULE_PATH", self.spec["hip"].prefix.lib.cmake.hip)
enable_or_disable("cudnn")
if "+cudnn" in self.spec: