summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py31
-rw-r--r--var/spack/repos/builtin/packages/python/package.py1
2 files changed, 32 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index c1737b0f1f..0ae2fb8177 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -269,6 +269,37 @@ class PyTorch(PythonPackage, CudaPackage):
patch('https://github.com/pytorch/pytorch/commit/c74c0c571880df886474be297c556562e95c00e0.patch?full_index=1',
sha256='8ff7d285e52e4718bad1ca01ceb3bb6471d7828329036bb94222717fcaa237da', when='@:1.9.1 ^cuda@11.4.100:')
+ @property
+ def headers(self):
+ """Discover header files in platlib."""
+
+ # Headers may be in either location
+ include = join_path(self.prefix, self.spec['python'].package.include)
+ platlib = join_path(self.prefix, self.spec['python'].package.platlib)
+ headers = find_all_headers(include) + find_all_headers(platlib)
+
+ if headers:
+ return headers
+
+ msg = 'Unable to locate {} headers in {} or {}'
+ raise NoHeadersError(msg.format(self.spec.name, include, platlib))
+
+ @property
+ def libs(self):
+ """Discover libraries in platlib."""
+
+ # Remove py- prefix in package name
+ library = 'lib' + self.spec.name[3:].replace('-', '?')
+ root = join_path(self.prefix, self.spec['python'].package.platlib)
+
+ for shared in [True, False]:
+ libs = find_libraries(library, root, shared=shared, recursive=True)
+ if libs:
+ return libs
+
+ msg = 'Unable to recursively locate {} libraries in {}'
+ raise NoLibrariesError(msg.format(self.spec.name, root))
+
@when('@1.5.0:')
def patch(self):
# https://github.com/pytorch/pytorch/issues/52208
diff --git a/var/spack/repos/builtin/packages/python/package.py b/var/spack/repos/builtin/packages/python/package.py
index d0d18b4de0..f526969f0c 100644
--- a/var/spack/repos/builtin/packages/python/package.py
+++ b/var/spack/repos/builtin/packages/python/package.py
@@ -1297,6 +1297,7 @@ config.update(get_paths())
module.python = self.command
+ module.python_include = join_path(dependent_spec.prefix, self.include)
module.python_platlib = join_path(dependent_spec.prefix, self.platlib)
module.python_purelib = join_path(dependent_spec.prefix, self.purelib)