From 375bc6fc9443082fd28e4bbcad44d1f627a452eb Mon Sep 17 00:00:00 2001 From: Wouter Deconinck Date: Wed, 3 Jan 2024 07:42:56 -0600 Subject: py-torch: set env OpenBLAS_HOME (#41745) * py-torch: set env OpenBLAS_HOME Because [`FindOpenBLAS.cmake`](https://github.com/pytorch/pytorch/blob/main/cmake/Modules/FindOpenBLAS.cmake) uses a hardcoded list of search paths for includes and libraries, we have to pass the `OpenBLAS_HOME` environment variable. * py-torch: patch for ${OpenBLAS_HOME}/include/openblas The context of this patch is unchanged since v0.4.0. * py-torch: move patch before def patch * py-torch: also set Atlas_ROOT_DIR and BLIS_HOME * py-torch: fix openblas patch range to @:2.1 Co-authored-by: Adam J. Stewart --------- Co-authored-by: Adam J. Stewart --- var/spack/repos/builtin/packages/py-torch/package.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'var') diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index bfe66074e9..6126b33111 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -419,6 +419,13 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): when="@2.0.0:2.0.1", ) + # Use correct OpenBLAS include path under prefix + patch( + "https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/110063.patch?full_index=1", + sha256="23fb4009f7337051fc5303927ff977186a5af960245e7212895406477d8b2f66", + when="@:2.1", + ) + @when("@1.5.0:") def patch(self): # https://github.com/pytorch/pytorch/issues/52208 @@ -560,9 +567,11 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): if self.spec["blas"].name == "atlas": env.set("BLAS", "ATLAS") env.set("WITH_BLAS", "atlas") + env.set("Atlas_ROOT_DIR", self.spec["atlas"].prefix) elif self.spec["blas"].name in ["blis", "amdblis"]: env.set("BLAS", "BLIS") env.set("WITH_BLAS", "blis") + env.set("BLIS_HOME", self.spec["blas"].prefix) elif self.spec["blas"].name == "eigen": env.set("BLAS", "Eigen") elif self.spec["lapack"].name in ["libflame", "amdlibflame"]: @@ -579,6 +588,7 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): elif self.spec["blas"].name == "openblas": env.set("BLAS", "OpenBLAS") env.set("WITH_BLAS", "open") + env.set("OpenBLAS_HOME", self.spec["openblas"].prefix) elif self.spec["blas"].name == "veclibfort": env.set("BLAS", "vecLib") env.set("WITH_BLAS", "veclib") -- cgit v1.2.3-70-g09d2