diff options
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 31 | ||||
-rw-r--r-- | var/spack/repos/builtin/packages/python/package.py | 1 |
2 files changed, 32 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index c1737b0f1f..0ae2fb8177 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -269,6 +269,37 @@ class PyTorch(PythonPackage, CudaPackage): patch('https://github.com/pytorch/pytorch/commit/c74c0c571880df886474be297c556562e95c00e0.patch?full_index=1', sha256='8ff7d285e52e4718bad1ca01ceb3bb6471d7828329036bb94222717fcaa237da', when='@:1.9.1 ^cuda@11.4.100:') + @property + def headers(self): + """Discover header files in platlib.""" + + # Headers may be in either location + include = join_path(self.prefix, self.spec['python'].package.include) + platlib = join_path(self.prefix, self.spec['python'].package.platlib) + headers = find_all_headers(include) + find_all_headers(platlib) + + if headers: + return headers + + msg = 'Unable to locate {} headers in {} or {}' + raise NoHeadersError(msg.format(self.spec.name, include, platlib)) + + @property + def libs(self): + """Discover libraries in platlib.""" + + # Remove py- prefix in package name + library = 'lib' + self.spec.name[3:].replace('-', '?') + root = join_path(self.prefix, self.spec['python'].package.platlib) + + for shared in [True, False]: + libs = find_libraries(library, root, shared=shared, recursive=True) + if libs: + return libs + + msg = 'Unable to recursively locate {} libraries in {}' + raise NoLibrariesError(msg.format(self.spec.name, root)) + @when('@1.5.0:') def patch(self): # https://github.com/pytorch/pytorch/issues/52208 diff --git a/var/spack/repos/builtin/packages/python/package.py b/var/spack/repos/builtin/packages/python/package.py index d0d18b4de0..f526969f0c 100644 --- a/var/spack/repos/builtin/packages/python/package.py +++ b/var/spack/repos/builtin/packages/python/package.py @@ -1297,6 +1297,7 @@ config.update(get_paths()) module.python = self.command + module.python_include = join_path(dependent_spec.prefix, self.include) module.python_platlib = join_path(dependent_spec.prefix, self.platlib) module.python_purelib = join_path(dependent_spec.prefix, self.purelib) |