diff options
author | Massimiliano Culpo <massimiliano.culpo@gmail.com> | 2024-07-08 11:48:39 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-07-08 11:48:39 +0200 |
commit | 74398d74ace4b09ec9aabc9ce243b98ea4d7fada (patch) | |
tree | bef3de2ce71683c94b84a965e75b5cebba53f14b | |
parent | cef9c36183eb627898d5f12590fba4327198872e (diff) | |
download | spack-74398d74ace4b09ec9aabc9ce243b98ea4d7fada.tar.gz spack-74398d74ace4b09ec9aabc9ce243b98ea4d7fada.tar.bz2 spack-74398d74ace4b09ec9aabc9ce243b98ea4d7fada.tar.xz spack-74398d74ace4b09ec9aabc9ce243b98ea4d7fada.zip |
Add type-hints to RepoPath (#45068)
* Also, fix a bug with use_repositories + import spack.pkg
-rw-r--r-- | lib/spack/spack/cmd/create.py | 4 | ||||
-rw-r--r-- | lib/spack/spack/detection/path.py | 10 | ||||
-rw-r--r-- | lib/spack/spack/patch.py | 18 | ||||
-rw-r--r-- | lib/spack/spack/repo.py | 127 | ||||
-rw-r--r-- | lib/spack/spack/test/conftest.py | 1 | ||||
-rw-r--r-- | lib/spack/spack/test/repo.py | 50 |
6 files changed, 121 insertions, 89 deletions
diff --git a/lib/spack/spack/cmd/create.py b/lib/spack/spack/cmd/create.py index 0481b9d044..c380714297 100644 --- a/lib/spack/spack/cmd/create.py +++ b/lib/spack/spack/cmd/create.py @@ -941,9 +941,7 @@ def get_repository(args, name): ) else: if spec.namespace: - repo = spack.repo.PATH.get_repo(spec.namespace, None) - if not repo: - tty.die("Unknown namespace: '{0}'".format(spec.namespace)) + repo = spack.repo.PATH.get_repo(spec.namespace) else: repo = spack.repo.PATH.first_repo() diff --git a/lib/spack/spack/detection/path.py b/lib/spack/spack/detection/path.py index 711e17467e..943de16ee6 100644 --- a/lib/spack/spack/detection/path.py +++ b/lib/spack/spack/detection/path.py @@ -12,7 +12,7 @@ import os.path import re import sys import warnings -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, Type import llnl.util.filesystem import llnl.util.lang @@ -200,7 +200,7 @@ class Finder: def default_path_hints(self) -> List[str]: return [] - def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]: + def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]: """Returns the list of patterns used to match candidate files. Args: @@ -226,7 +226,7 @@ class Finder: raise NotImplementedError("must be implemented by derived classes") def detect_specs( - self, *, pkg: "spack.package_base.PackageBase", paths: List[str] + self, *, pkg: Type["spack.package_base.PackageBase"], paths: List[str] ) -> List[DetectedPackage]: """Given a list of files matching the search patterns, returns a list of detected specs. @@ -327,7 +327,7 @@ class ExecutablesFinder(Finder): def default_path_hints(self) -> List[str]: return spack.util.environment.get_path("PATH") - def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]: + def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]: result = [] if hasattr(pkg, "executables") and hasattr(pkg, "platform_executables"): result = pkg.platform_executables() @@ -356,7 +356,7 @@ class LibrariesFinder(Finder): DYLD_LIBRARY_PATH, DYLD_FALLBACK_LIBRARY_PATH, and standard system library paths """ - def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]: + def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]: result = [] if hasattr(pkg, "libraries"): result = pkg.libraries diff --git a/lib/spack/spack/patch.py b/lib/spack/spack/patch.py index 531445b4f9..795a274243 100644 --- a/lib/spack/spack/patch.py +++ b/lib/spack/spack/patch.py @@ -9,7 +9,7 @@ import os import os.path import pathlib import sys -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any, Dict, Optional, Tuple, Type, Union import llnl.util.filesystem from llnl.url import allowed_archive @@ -65,6 +65,9 @@ def apply_patch( patch(*args) +PatchPackageType = Union["spack.package_base.PackageBase", Type["spack.package_base.PackageBase"]] + + class Patch: """Base class for patches. @@ -77,7 +80,7 @@ class Patch: def __init__( self, - pkg: "spack.package_base.PackageBase", + pkg: PatchPackageType, path_or_url: str, level: int, working_dir: str, @@ -159,7 +162,7 @@ class FilePatch(Patch): def __init__( self, - pkg: "spack.package_base.PackageBase", + pkg: PatchPackageType, relative_path: str, level: int, working_dir: str, @@ -183,7 +186,7 @@ class FilePatch(Patch): abs_path: Optional[str] = None # At different times we call FilePatch on instances and classes pkg_cls = pkg if inspect.isclass(pkg) else pkg.__class__ - for cls in inspect.getmro(pkg_cls): + for cls in inspect.getmro(pkg_cls): # type: ignore if not hasattr(cls, "module"): # We've gone too far up the MRO break @@ -242,7 +245,7 @@ class UrlPatch(Patch): def __init__( self, - pkg: "spack.package_base.PackageBase", + pkg: PatchPackageType, url: str, level: int = 1, *, @@ -361,8 +364,9 @@ def from_dict( """ repository = repository or spack.repo.PATH owner = dictionary.get("owner") - if "owner" not in dictionary: - raise ValueError("Invalid patch dictionary: %s" % dictionary) + if owner is None: + raise ValueError(f"Invalid patch dictionary: {dictionary}") + assert isinstance(owner, str) pkg_cls = repository.get_pkg_class(owner) if "url" in dictionary: diff --git a/lib/spack/spack/repo.py b/lib/spack/spack/repo.py index f3394a118c..8fe587d3fd 100644 --- a/lib/spack/spack/repo.py +++ b/lib/spack/spack/repo.py @@ -675,15 +675,22 @@ class RepoPath: repository. Args: - repos (list): list Repo objects or paths to put in this RepoPath + repos: list Repo objects or paths to put in this RepoPath + cache: file cache associated with this repository + overrides: dict mapping package name to class attribute overrides for that package """ - def __init__(self, *repos, cache, overrides=None): - self.repos = [] + def __init__( + self, + *repos: Union[str, "Repo"], + cache: spack.caches.FileCacheType, + overrides: Optional[Dict[str, Any]] = None, + ) -> None: + self.repos: List[Repo] = [] self.by_namespace = nm.NamespaceTrie() - self._provider_index = None - self._patch_index = None - self._tag_index = None + self._provider_index: Optional[spack.provider_index.ProviderIndex] = None + self._patch_index: Optional[spack.patch.PatchCache] = None + self._tag_index: Optional[spack.tag.TagIndex] = None # Add each repo to this path. for repo in repos: @@ -694,13 +701,13 @@ class RepoPath: self.put_last(repo) except RepoError as e: tty.warn( - "Failed to initialize repository: '%s'." % repo, + f"Failed to initialize repository: '{repo}'.", e.message, "To remove the bad repository, run this command:", - " spack repo rm %s" % repo, + f" spack repo rm {repo}", ) - def put_first(self, repo): + def put_first(self, repo: "Repo") -> None: """Add repo first in the search path.""" if isinstance(repo, RepoPath): for r in reversed(repo.repos): @@ -728,50 +735,34 @@ class RepoPath: if repo in self.repos: self.repos.remove(repo) - def get_repo(self, namespace, default=NOT_PROVIDED): - """Get a repository by namespace. - - Arguments: - - namespace: - - Look up this namespace in the RepoPath, and return it if found. - - Optional Arguments: - - default: - - If default is provided, return it when the namespace - isn't found. If not, raise an UnknownNamespaceError. - """ + def get_repo(self, namespace: str) -> "Repo": + """Get a repository by namespace.""" full_namespace = python_package_for_repo(namespace) if full_namespace not in self.by_namespace: - if default == NOT_PROVIDED: - raise UnknownNamespaceError(namespace) - return default + raise UnknownNamespaceError(namespace) return self.by_namespace[full_namespace] - def first_repo(self): + def first_repo(self) -> Optional["Repo"]: """Get the first repo in precedence order.""" return self.repos[0] if self.repos else None @llnl.util.lang.memoized - def _all_package_names_set(self, include_virtuals): + def _all_package_names_set(self, include_virtuals) -> Set[str]: return {name for repo in self.repos for name in repo.all_package_names(include_virtuals)} @llnl.util.lang.memoized - def _all_package_names(self, include_virtuals): + def _all_package_names(self, include_virtuals: bool) -> List[str]: """Return all unique package names in all repositories.""" return sorted(self._all_package_names_set(include_virtuals), key=lambda n: n.lower()) - def all_package_names(self, include_virtuals=False): + def all_package_names(self, include_virtuals: bool = False) -> List[str]: return self._all_package_names(include_virtuals) - def package_path(self, name): + def package_path(self, name: str) -> str: """Get path to package.py file for this repo.""" return self.repo_for_pkg(name).package_path(name) - def all_package_paths(self): + def all_package_paths(self) -> Generator[str, None, None]: for name in self.all_package_names(): yield self.package_path(name) @@ -787,53 +778,52 @@ class RepoPath: for pkg in repo.packages_with_tags(*tags) } - def all_package_classes(self): + def all_package_classes(self) -> Generator[Type["spack.package_base.PackageBase"], None, None]: for name in self.all_package_names(): yield self.get_pkg_class(name) @property - def provider_index(self): + def provider_index(self) -> spack.provider_index.ProviderIndex: """Merged ProviderIndex from all Repos in the RepoPath.""" if self._provider_index is None: self._provider_index = spack.provider_index.ProviderIndex(repository=self) for repo in reversed(self.repos): self._provider_index.merge(repo.provider_index) - return self._provider_index @property - def tag_index(self): + def tag_index(self) -> spack.tag.TagIndex: """Merged TagIndex from all Repos in the RepoPath.""" if self._tag_index is None: self._tag_index = spack.tag.TagIndex(repository=self) for repo in reversed(self.repos): self._tag_index.merge(repo.tag_index) - return self._tag_index @property - def patch_index(self): + def patch_index(self) -> spack.patch.PatchCache: """Merged PatchIndex from all Repos in the RepoPath.""" if self._patch_index is None: self._patch_index = spack.patch.PatchCache(repository=self) for repo in reversed(self.repos): self._patch_index.update(repo.patch_index) - return self._patch_index @autospec - def providers_for(self, vpkg_spec): + def providers_for(self, virtual_spec: "spack.spec.Spec") -> List["spack.spec.Spec"]: providers = [ spec - for spec in self.provider_index.providers_for(vpkg_spec) + for spec in self.provider_index.providers_for(virtual_spec) if spec.name in self._all_package_names_set(include_virtuals=False) ] if not providers: - raise UnknownPackageError(vpkg_spec.fullname) + raise UnknownPackageError(virtual_spec.fullname) return providers @autospec - def extensions_for(self, extendee_spec): + def extensions_for( + self, extendee_spec: "spack.spec.Spec" + ) -> List["spack.package_base.PackageBase"]: return [ pkg_cls(spack.spec.Spec(pkg_cls.name)) for pkg_cls in self.all_package_classes() @@ -844,7 +834,7 @@ class RepoPath: """Time a package file in this repo was last updated.""" return max(repo.last_mtime() for repo in self.repos) - def repo_for_pkg(self, spec): + def repo_for_pkg(self, spec: Union[str, "spack.spec.Spec"]) -> "Repo": """Given a spec, get the repository for its package.""" # We don't @_autospec this function b/c it's called very frequently # and we want to avoid parsing str's into Specs unnecessarily. @@ -869,17 +859,20 @@ class RepoPath: return repo # If the package isn't in any repo, return the one with - # highest precedence. This is for commands like `spack edit` + # highest precedence. This is for commands like `spack edit` # that can operate on packages that don't exist yet. - return self.first_repo() + selected = self.first_repo() + if selected is None: + raise UnknownPackageError(name) + return selected - def get(self, spec): + def get(self, spec: "spack.spec.Spec") -> "spack.package_base.PackageBase": """Returns the package associated with the supplied spec.""" msg = "RepoPath.get can only be called on concrete specs" assert isinstance(spec, spack.spec.Spec) and spec.concrete, msg return self.repo_for_pkg(spec).get(spec) - def get_pkg_class(self, pkg_name): + def get_pkg_class(self, pkg_name: str) -> Type["spack.package_base.PackageBase"]: """Find a class for the spec's package and return the class object.""" return self.repo_for_pkg(pkg_name).get_pkg_class(pkg_name) @@ -892,26 +885,26 @@ class RepoPath: """ return self.repo_for_pkg(spec).dump_provenance(spec, path) - def dirname_for_package_name(self, pkg_name): + def dirname_for_package_name(self, pkg_name: str) -> str: return self.repo_for_pkg(pkg_name).dirname_for_package_name(pkg_name) - def filename_for_package_name(self, pkg_name): + def filename_for_package_name(self, pkg_name: str) -> str: return self.repo_for_pkg(pkg_name).filename_for_package_name(pkg_name) - def exists(self, pkg_name): + def exists(self, pkg_name: str) -> bool: """Whether package with the give name exists in the path's repos. Note that virtual packages do not "exist". """ return any(repo.exists(pkg_name) for repo in self.repos) - def _have_name(self, pkg_name): + def _have_name(self, pkg_name: str) -> bool: have_name = pkg_name is not None if have_name and not isinstance(pkg_name, str): - raise ValueError("is_virtual(): expected package name, got %s" % type(pkg_name)) + raise ValueError(f"is_virtual(): expected package name, got {type(pkg_name)}") return have_name - def is_virtual(self, pkg_name): + def is_virtual(self, pkg_name: str) -> bool: """Return True if the package with this name is virtual, False otherwise. This function use the provider index. If calling from a code block that @@ -923,7 +916,7 @@ class RepoPath: have_name = self._have_name(pkg_name) return have_name and pkg_name in self.provider_index - def is_virtual_safe(self, pkg_name): + def is_virtual_safe(self, pkg_name: str) -> bool: """Return True if the package with this name is virtual, False otherwise. This function doesn't use the provider index. @@ -1418,7 +1411,9 @@ def _path(configuration=None): return create(configuration=configuration) -def create(configuration): +def create( + configuration: Union["spack.config.Configuration", llnl.util.lang.Singleton] +) -> RepoPath: """Create a RepoPath from a configuration object. Args: @@ -1454,20 +1449,20 @@ def all_package_names(include_virtuals=False): @contextlib.contextmanager -def use_repositories(*paths_and_repos, **kwargs): +def use_repositories( + *paths_and_repos: Union[str, Repo], override: bool = True +) -> Generator[RepoPath, None, None]: """Use the repositories passed as arguments within the context manager. Args: *paths_and_repos: paths to the repositories to be used, or already constructed Repo objects - override (bool): if True use only the repositories passed as input, + override: if True use only the repositories passed as input, if False add them to the top of the list of current repositories. Returns: Corresponding RepoPath object """ global PATH - # TODO (Python 2.7): remove this kwargs on deprecation of Python 2.7 support - override = kwargs.get("override", True) paths = [getattr(x, "root", x) for x in paths_and_repos] scope_name = "use-repo-{}".format(uuid.uuid4()) repos_key = "repos:" if override else "repos" @@ -1476,7 +1471,8 @@ def use_repositories(*paths_and_repos, **kwargs): ) PATH, saved = create(configuration=spack.config.CONFIG), PATH try: - yield PATH + with REPOS_FINDER.switch_repo(PATH): # type: ignore + yield PATH finally: spack.config.CONFIG.remove_scope(scope_name=scope_name) PATH = saved @@ -1576,10 +1572,9 @@ class UnknownNamespaceError(UnknownEntityError): """Raised when we encounter an unknown namespace""" def __init__(self, namespace, name=None): - msg, long_msg = "Unknown namespace: {}".format(namespace), None + msg, long_msg = f"Unknown namespace: {namespace}", None if name == "yaml": - long_msg = "Did you mean to specify a filename with './{}.{}'?" - long_msg = long_msg.format(namespace, name) + long_msg = f"Did you mean to specify a filename with './{namespace}.{name}'?" super().__init__(msg, long_msg) diff --git a/lib/spack/spack/test/conftest.py b/lib/spack/spack/test/conftest.py index 1e3b336ad5..c95ce187d0 100644 --- a/lib/spack/spack/test/conftest.py +++ b/lib/spack/spack/test/conftest.py @@ -2069,4 +2069,5 @@ def _c_compiler_always_exists(): @pytest.fixture(scope="session") def mock_test_cache(tmp_path_factory): cache_dir = tmp_path_factory.mktemp("cache") + print(cache_dir) return spack.util.file_cache.FileCache(str(cache_dir)) diff --git a/lib/spack/spack/test/repo.py b/lib/spack/spack/test/repo.py index 3aa9f00698..6bb2c2625e 100644 --- a/lib/spack/spack/test/repo.py +++ b/lib/spack/spack/test/repo.py @@ -3,6 +3,7 @@ # # SPDX-License-Identifier: (Apache-2.0 OR MIT) import os +import pathlib import pytest @@ -205,6 +206,18 @@ def test_path_computation_with_names(method_name, mock_repo_path): assert qualified == unqualified +def test_use_repositories_and_import(): + """Tests that use_repositories changes the import search too""" + import spack.paths + + repo_dir = pathlib.Path(spack.paths.repos_path) + with spack.repo.use_repositories(str(repo_dir / "compiler_runtime.test")): + import spack.pkg.compiler_runtime.test.gcc_runtime + + with spack.repo.use_repositories(str(repo_dir / "builtin.mock")): + import spack.pkg.builtin.mock.cmake + + @pytest.mark.usefixtures("nullify_globals") class TestRepo: """Test that the Repo class work correctly, and does not depend on globals, @@ -219,8 +232,9 @@ class TestRepo: @pytest.mark.parametrize( "name,expected", [("mpi", True), ("mpich", False), ("mpileaks", False)] ) - def test_is_virtual(self, name, expected, mock_test_cache): - repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache) + @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath]) + def test_is_virtual(self, repo_cls, name, expected, mock_test_cache): + repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache) assert repo.is_virtual(name) is expected assert repo.is_virtual_safe(name) is expected @@ -258,13 +272,15 @@ class TestRepo: "extended,expected", [("python", ["py-extension1", "python-venv"]), ("perl", ["perl-extension"])], ) - def test_extensions(self, extended, expected, mock_test_cache): - repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache) + @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath]) + def test_extensions(self, repo_cls, extended, expected, mock_test_cache): + repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache) provider_names = {x.name for x in repo.extensions_for(extended)} assert provider_names.issuperset(expected) - def test_all_package_names(self, mock_test_cache): - repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache) + @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath]) + def test_all_package_names(self, repo_cls, mock_test_cache): + repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache) all_names = repo.all_package_names(include_virtuals=True) real_names = repo.all_package_names(include_virtuals=False) assert set(all_names).issuperset(real_names) @@ -272,10 +288,28 @@ class TestRepo: assert repo.is_virtual(name) assert repo.is_virtual_safe(name) - def test_packages_with_tags(self, mock_test_cache): - repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache) + @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath]) + def test_packages_with_tags(self, repo_cls, mock_test_cache): + repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache) r1 = repo.packages_with_tags("tag1") r2 = repo.packages_with_tags("tag1", "tag2") assert "mpich" in r1 and "mpich" in r2 assert "mpich2" in r1 and "mpich2" not in r2 assert set(r2).issubset(r1) + + +@pytest.mark.usefixtures("nullify_globals") +class TestRepoPath: + def test_creation_from_string(self, mock_test_cache): + repo = spack.repo.RepoPath(spack.paths.mock_packages_path, cache=mock_test_cache) + assert len(repo.repos) == 1 + assert repo.repos[0]._finder is repo + assert repo.by_namespace["spack.pkg.builtin.mock"] is repo.repos[0] + + def test_get_repo(self, mock_test_cache): + repo = spack.repo.RepoPath(spack.paths.mock_packages_path, cache=mock_test_cache) + # builtin.mock is there + assert repo.get_repo("builtin.mock") is repo.repos[0] + # foo is not there, raise + with pytest.raises(spack.repo.UnknownNamespaceError): + repo.get_repo("foo") |