summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
authorAdam J. Stewart <ajstewart426@gmail.com>2022-07-07 07:31:09 -0700
committerGitHub <noreply@github.com>2022-07-07 16:31:09 +0200
commit386f08c1b49cab85513af339d9a97151d79061be (patch)
tree9589bfcd67111d72a78d8429fbcd15a6a6755506 /var
parent9c437e2a107312ad5ac68f29d3dedec2eb5f16fe (diff)
downloadspack-386f08c1b49cab85513af339d9a97151d79061be.tar.gz
spack-386f08c1b49cab85513af339d9a97151d79061be.tar.bz2
spack-386f08c1b49cab85513af339d9a97151d79061be.tar.xz
spack-386f08c1b49cab85513af339d9a97151d79061be.zip
py-torch: re-add headers/libs properties (#31446)
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py31
-rw-r--r--var/spack/repos/builtin/packages/python/package.py1
2 files changed, 32 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index c1737b0f1f..0ae2fb8177 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -269,6 +269,37 @@ class PyTorch(PythonPackage, CudaPackage):
patch('https://github.com/pytorch/pytorch/commit/c74c0c571880df886474be297c556562e95c00e0.patch?full_index=1',
sha256='8ff7d285e52e4718bad1ca01ceb3bb6471d7828329036bb94222717fcaa237da', when='@:1.9.1 ^cuda@11.4.100:')
+ @property
+ def headers(self):
+ """Discover header files in platlib."""
+
+ # Headers may be in either location
+ include = join_path(self.prefix, self.spec['python'].package.include)
+ platlib = join_path(self.prefix, self.spec['python'].package.platlib)
+ headers = find_all_headers(include) + find_all_headers(platlib)
+
+ if headers:
+ return headers
+
+ msg = 'Unable to locate {} headers in {} or {}'
+ raise NoHeadersError(msg.format(self.spec.name, include, platlib))
+
+ @property
+ def libs(self):
+ """Discover libraries in platlib."""
+
+ # Remove py- prefix in package name
+ library = 'lib' + self.spec.name[3:].replace('-', '?')
+ root = join_path(self.prefix, self.spec['python'].package.platlib)
+
+ for shared in [True, False]:
+ libs = find_libraries(library, root, shared=shared, recursive=True)
+ if libs:
+ return libs
+
+ msg = 'Unable to recursively locate {} libraries in {}'
+ raise NoLibrariesError(msg.format(self.spec.name, root))
+
@when('@1.5.0:')
def patch(self):
# https://github.com/pytorch/pytorch/issues/52208
diff --git a/var/spack/repos/builtin/packages/python/package.py b/var/spack/repos/builtin/packages/python/package.py
index d0d18b4de0..f526969f0c 100644
--- a/var/spack/repos/builtin/packages/python/package.py
+++ b/var/spack/repos/builtin/packages/python/package.py
@@ -1297,6 +1297,7 @@ config.update(get_paths())
module.python = self.command
+ module.python_include = join_path(dependent_spec.prefix, self.include)
module.python_platlib = join_path(dependent_spec.prefix, self.platlib)
module.python_purelib = join_path(dependent_spec.prefix, self.purelib)