diff options
author | renjithravindrankannath <94420380+renjithravindrankannath@users.noreply.github.com> | 2022-08-12 23:17:20 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-13 01:17:20 -0500 |
commit | b32cb5765c598e59f2f91d99807b4331f4eba569 (patch) | |
tree | 354bff13b6ace5fd61fad375d14f755737f27001 | |
parent | 4ec31003aabb80051a225f17591fcab43d82f991 (diff) | |
download | spack-b32cb5765c598e59f2f91d99807b4331f4eba569.tar.gz spack-b32cb5765c598e59f2f91d99807b4331f4eba569.tar.bz2 spack-b32cb5765c598e59f2f91d99807b4331f4eba569.tar.xz spack-b32cb5765c598e59f2f91d99807b4331f4eba569.zip |
Add new dependencies for rocm variant for py-torch recipe (#32100)
* Cmake module path updated for ROCm 5.2
* nccl is already set below for PyTorch 1.6+
* Threadpool is set below for PyTorch 1.6+
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index e9de9d257b..0ba0736b9d 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -195,6 +195,7 @@ class PyTorch(PythonPackage, CudaPackage): depends_on("rocfft") depends_on("rocblas") depends_on("miopen-hip") + depends_on("rocminfo") # https://github.com/pytorch/pytorch/issues/60332 # depends_on('xnnpack@2022-02-16', when='@1.12:+xnnpack') # depends_on('xnnpack@2021-06-21', when='@1.10:1.11+xnnpack') @@ -427,7 +428,6 @@ class PyTorch(PythonPackage, CudaPackage): env.set("ROCFFT_PATH", self.spec["rocfft"].prefix) env.set("HIPFFT_PATH", self.spec["hipfft"].prefix) env.set("HIPSPARSE_PATH", self.spec["hipsparse"].prefix) - env.set("THRUST_PATH", self.spec["rocthrust"].prefix.include) env.set("HIP_PATH", self.spec["hip"].prefix) env.set("HIPRAND_PATH", self.spec["rocrand"].prefix) env.set("ROCRAND_PATH", self.spec["rocrand"].prefix) @@ -437,6 +437,8 @@ class PyTorch(PythonPackage, CudaPackage): env.set("HIPCUB_PATH", self.spec["hipcub"].prefix) env.set("ROCTHRUST_PATH", self.spec["rocthrust"].prefix) env.set("ROCTRACER_PATH", self.spec["roctracer-dev"].prefix) + if self.spec.satisfies("^hip@5.2.0:"): + env.set("CMAKE_MODULE_PATH", self.spec["hip"].prefix.lib.cmake.hip) enable_or_disable("cudnn") if "+cudnn" in self.spec: |