summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorrenjithravindrankannath <94420380+renjithravindrankannath@users.noreply.github.com>2022-08-12 23:17:20 -0700
committerGitHub <noreply@github.com>2022-08-13 01:17:20 -0500
commitb32cb5765c598e59f2f91d99807b4331f4eba569 (patch)
tree354bff13b6ace5fd61fad375d14f755737f27001
parent4ec31003aabb80051a225f17591fcab43d82f991 (diff)
downloadspack-b32cb5765c598e59f2f91d99807b4331f4eba569.tar.gz
spack-b32cb5765c598e59f2f91d99807b4331f4eba569.tar.bz2
spack-b32cb5765c598e59f2f91d99807b4331f4eba569.tar.xz
spack-b32cb5765c598e59f2f91d99807b4331f4eba569.zip
Add new dependencies for rocm variant for py-torch recipe (#32100)
* Cmake module path updated for ROCm 5.2 * nccl is already set below for PyTorch 1.6+ * Threadpool is set below for PyTorch 1.6+
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index e9de9d257b..0ba0736b9d 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -195,6 +195,7 @@ class PyTorch(PythonPackage, CudaPackage):
depends_on("rocfft")
depends_on("rocblas")
depends_on("miopen-hip")
+ depends_on("rocminfo")
# https://github.com/pytorch/pytorch/issues/60332
# depends_on('xnnpack@2022-02-16', when='@1.12:+xnnpack')
# depends_on('xnnpack@2021-06-21', when='@1.10:1.11+xnnpack')
@@ -427,7 +428,6 @@ class PyTorch(PythonPackage, CudaPackage):
env.set("ROCFFT_PATH", self.spec["rocfft"].prefix)
env.set("HIPFFT_PATH", self.spec["hipfft"].prefix)
env.set("HIPSPARSE_PATH", self.spec["hipsparse"].prefix)
- env.set("THRUST_PATH", self.spec["rocthrust"].prefix.include)
env.set("HIP_PATH", self.spec["hip"].prefix)
env.set("HIPRAND_PATH", self.spec["rocrand"].prefix)
env.set("ROCRAND_PATH", self.spec["rocrand"].prefix)
@@ -437,6 +437,8 @@ class PyTorch(PythonPackage, CudaPackage):
env.set("HIPCUB_PATH", self.spec["hipcub"].prefix)
env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix)
env.set("ROCTRACER_PATH", self.spec["roctracer-dev"].prefix)
+ if self.spec.satisfies("^hip@5.2.0:"):
+ env.set("CMAKE_MODULE_PATH", self.spec["hip"].prefix.lib.cmake.hip)
enable_or_disable("cudnn")
if "+cudnn" in self.spec: