summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index bfe66074e9..6126b33111 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -419,6 +419,13 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
when="@2.0.0:2.0.1",
)
+ # Use correct OpenBLAS include path under prefix
+ patch(
+ "https://patch-diff.githubusercontent.com/raw/pytorch/pytorch/pull/110063.patch?full_index=1",
+ sha256="23fb4009f7337051fc5303927ff977186a5af960245e7212895406477d8b2f66",
+ when="@:2.1",
+ )
+
@when("@1.5.0:")
def patch(self):
# https://github.com/pytorch/pytorch/issues/52208
@@ -560,9 +567,11 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
if self.spec["blas"].name == "atlas":
env.set("BLAS", "ATLAS")
env.set("WITH_BLAS", "atlas")
+ env.set("Atlas_ROOT_DIR", self.spec["atlas"].prefix)
elif self.spec["blas"].name in ["blis", "amdblis"]:
env.set("BLAS", "BLIS")
env.set("WITH_BLAS", "blis")
+ env.set("BLIS_HOME", self.spec["blas"].prefix)
elif self.spec["blas"].name == "eigen":
env.set("BLAS", "Eigen")
elif self.spec["lapack"].name in ["libflame", "amdlibflame"]:
@@ -579,6 +588,7 @@ class PyTorch(PythonPackage, CudaPackage, ROCmPackage):
elif self.spec["blas"].name == "openblas":
env.set("BLAS", "OpenBLAS")
env.set("WITH_BLAS", "open")
+ env.set("OpenBLAS_HOME", self.spec["openblas"].prefix)
elif self.spec["blas"].name == "veclibfort":
env.set("BLAS", "vecLib")
env.set("WITH_BLAS", "veclib")