summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
authorWouter Deconinck <wdconinc@gmail.com>2024-01-03 07:42:56 -0600
committerGitHub <noreply@github.com>2024-01-03 06:42:56 -0700
commit375bc6fc9443082fd28e4bbcad44d1f627a452eb (patch)
treeb83624da7c905c30cabd864162d3fffd380733c6 /var
parent4b2baa7e91b009fdf2104e1f874dbb60a03f71a7 (diff)
downloadspack-375bc6fc9443082fd28e4bbcad44d1f627a452eb.tar.gz
spack-375bc6fc9443082fd28e4bbcad44d1f627a452eb.tar.bz2
spack-375bc6fc9443082fd28e4bbcad44d1f627a452eb.tar.xz
spack-375bc6fc9443082fd28e4bbcad44d1f627a452eb.zip
py-torch: set env OpenBLAS_HOME (#41745)
* py-torch: set env OpenBLAS_HOME Because [`FindOpenBLAS.cmake`](https://github.com/pytorch/pytorch/blob/main/cmake/Modules/FindOpenBLAS.cmake) uses a hardcoded list of search paths for includes and libraries, we have to pass the `OpenBLAS_HOME` environment variable. * py-torch: patch for ${OpenBLAS_HOME}/include/openblas The context of this patch is unchanged since v0.4.0. * py-torch: move patch before def patch * py-torch: also set Atlas_ROOT_DIR and BLIS_HOME * py-torch: fix openblas patch range to @:2.1 Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index bfe66074e9..6126b33111 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -419,6 +419,13 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
when="@2.0.0:2.0.1",
)
+ # Use correct OpenBLAS include path under prefix
+ patch(
+ "https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/110063.patch?full_index=1",
+ sha256="23fb4009f7337051fc5303927ff977186a5af960245e7212895406477d8b2f66",
+ when="@:2.1",
+ )
+
@when("@1.5.0:")
def patch(self):
# https://github.com/pytorch/pytorch/issues/52208
@@ -560,9 +567,11 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
if self.spec["blas"].name == "atlas":
env.set("BLAS", "ATLAS")
env.set("WITH_BLAS", "atlas")
+ env.set("Atlas_ROOT_DIR", self.spec["atlas"].prefix)
elif self.spec["blas"].name in ["blis", "amdblis"]:
env.set("BLAS", "BLIS")
env.set("WITH_BLAS", "blis")
+ env.set("BLIS_HOME", self.spec["blas"].prefix)
elif self.spec["blas"].name == "eigen":
env.set("BLAS", "Eigen")
elif self.spec["lapack"].name in ["libflame", "amdlibflame"]:
@@ -579,6 +588,7 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
elif self.spec["blas"].name == "openblas":
env.set("BLAS", "OpenBLAS")
env.set("WITH_BLAS", "open")
+ env.set("OpenBLAS_HOME", self.spec["openblas"].prefix)
elif self.spec["blas"].name == "veclibfort":
env.set("BLAS", "vecLib")
env.set("WITH_BLAS", "veclib")