diff options
Diffstat (limited to 'var')
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index bfe66074e9..6126b33111 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -419,6 +419,13 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): when="@2.0.0:2.0.1", ) + # Use correct OpenBLAS include path under prefix + patch( + "https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/110063.patch?full_index=1", + sha256="23fb4009f7337051fc5303927ff977186a5af960245e7212895406477d8b2f66", + when="@:2.1", + ) + @when("@1.5.0:") def patch(self): # https://github.com/pytorch/pytorch/issues/52208 @@ -560,9 +567,11 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): if self.spec["blas"].name == "atlas": env.set("BLAS", "ATLAS") env.set("WITH_BLAS", "atlas") + env.set("Atlas_ROOT_DIR", self.spec["atlas"].prefix) elif self.spec["blas"].name in ["blis", "amdblis"]: env.set("BLAS", "BLIS") env.set("WITH_BLAS", "blis") + env.set("BLIS_HOME", self.spec["blas"].prefix) elif self.spec["blas"].name == "eigen": env.set("BLAS", "Eigen") elif self.spec["lapack"].name in ["libflame", "amdlibflame"]: @@ -579,6 +588,7 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage): elif self.spec["blas"].name == "openblas": env.set("BLAS", "OpenBLAS") env.set("WITH_BLAS", "open") + env.set("OpenBLAS_HOME", self.spec["openblas"].prefix) elif self.spec["blas"].name == "veclibfort": env.set("BLAS", "vecLib") env.set("WITH_BLAS", "veclib") |