diff options
author | Wouter Deconinck <wdconinc@gmail.com> | 2024-01-03 07:42:56 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-03 06:42:56 -0700 |
commit | 375bc6fc9443082fd28e4bbcad44d1f627a452eb (patch) | |
tree | b83624da7c905c30cabd864162d3fffd380733c6 /var | |
parent | 4b2baa7e91b009fdf2104e1f874dbb60a03f71a7 (diff) | |
download | spack-375bc6fc9443082fd28e4bbcad44d1f627a452eb.tar.gz spack-375bc6fc9443082fd28e4bbcad44d1f627a452eb.tar.bz2 spack-375bc6fc9443082fd28e4bbcad44d1f627a452eb.tar.xz spack-375bc6fc9443082fd28e4bbcad44d1f627a452eb.zip |
py-torch: set env OpenBLAS_HOME (#41745)
* py-torch: set env OpenBLAS_HOME
Because [`FindOpenBLAS.cmake`](https://github.com/pytorch/pytorch/blob/main/cmake/Modules/FindOpenBLAS.cmake) uses a hardcoded list of search paths for includes and libraries, we have to pass the `OpenBLAS_HOME` environment variable.
* py-torch: patch for ${OpenBLAS_HOME}/include/openblas
The context of this patch is unchanged since v0.4.0.
* py-torch: move patch before def patch
* py-torch: also set Atlas_ROOT_DIR and BLIS_HOME
* py-torch: fix openblas patch range to @:2.1
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
---------
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Diffstat (limited to 'var')
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index bfe66074e9..6126b33111 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -419,6 +419,13 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): when="@2.0.0:2.0.1", ) + # Use correct OpenBLAS include path under prefix + patch( + "https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/110063.patch?full_index=1", + sha256="23fb4009f7337051fc5303927ff977186a5af960245e7212895406477d8b2f66", + when="@:2.1", + ) + @when("@1.5.0:") def patch(self): # https://github.com/pytorch/pytorch/issues/52208 @@ -560,9 +567,11 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): if self.spec["blas"].name == "atlas": env.set("BLAS", "ATLAS") env.set("WITH_BLAS", "atlas") + env.set("Atlas_ROOT_DIR", self.spec["atlas"].prefix) elif self.spec["blas"].name in ["blis", "amdblis"]: env.set("BLAS", "BLIS") env.set("WITH_BLAS", "blis") + env.set("BLIS_HOME", self.spec["blas"].prefix) elif self.spec["blas"].name == "eigen": env.set("BLAS", "Eigen") elif self.spec["lapack"].name in ["libflame", "amdlibflame"]: @@ -579,6 +588,7 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): elif self.spec["blas"].name == "openblas": env.set("BLAS", "OpenBLAS") env.set("WITH_BLAS", "open") + env.set("OpenBLAS_HOME", self.spec["openblas"].prefix) elif self.spec["blas"].name == "veclibfort": env.set("BLAS", "vecLib") env.set("WITH_BLAS", "veclib") |