summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMassimiliano Culpo <massimiliano.culpo@gmail.com>2024-07-08 11:48:39 +0200
committerGitHub <noreply@github.com>2024-07-08 11:48:39 +0200
commit74398d74ace4b09ec9aabc9ce243b98ea4d7fada (patch)
treebef3de2ce71683c94b84a965e75b5cebba53f14b
parentcef9c36183eb627898d5f12590fba4327198872e (diff)
downloadspack-74398d74ace4b09ec9aabc9ce243b98ea4d7fada.tar.gz
spack-74398d74ace4b09ec9aabc9ce243b98ea4d7fada.tar.bz2
spack-74398d74ace4b09ec9aabc9ce243b98ea4d7fada.tar.xz
spack-74398d74ace4b09ec9aabc9ce243b98ea4d7fada.zip
Add type-hints to RepoPath (#45068)
* Also, fix a bug with use_repositories + import spack.pkg
-rw-r--r--lib/spack/spack/cmd/create.py4
-rw-r--r--lib/spack/spack/detection/path.py10
-rw-r--r--lib/spack/spack/patch.py18
-rw-r--r--lib/spack/spack/repo.py127
-rw-r--r--lib/spack/spack/test/conftest.py1
-rw-r--r--lib/spack/spack/test/repo.py50
6 files changed, 121 insertions, 89 deletions
diff --git a/lib/spack/spack/cmd/create.py b/lib/spack/spack/cmd/create.py
index 0481b9d044..c380714297 100644
--- a/lib/spack/spack/cmd/create.py
+++ b/lib/spack/spack/cmd/create.py
@@ -941,9 +941,7 @@ def get_repository(args, name):
)
else:
if spec.namespace:
- repo = spack.repo.PATH.get_repo(spec.namespace, None)
- if not repo:
- tty.die("Unknown namespace: '{0}'".format(spec.namespace))
+ repo = spack.repo.PATH.get_repo(spec.namespace)
else:
repo = spack.repo.PATH.first_repo()
diff --git a/lib/spack/spack/detection/path.py b/lib/spack/spack/detection/path.py
index 711e17467e..943de16ee6 100644
--- a/lib/spack/spack/detection/path.py
+++ b/lib/spack/spack/detection/path.py
@@ -12,7 +12,7 @@ import os.path
import re
import sys
import warnings
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Type
import llnl.util.filesystem
import llnl.util.lang
@@ -200,7 +200,7 @@ class Finder:
def default_path_hints(self) -> List[str]:
return []
- def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]:
+ def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]:
"""Returns the list of patterns used to match candidate files.
Args:
@@ -226,7 +226,7 @@ class Finder:
raise NotImplementedError("must be implemented by derived classes")
def detect_specs(
- self, *, pkg: "spack.package_base.PackageBase", paths: List[str]
+ self, *, pkg: Type["spack.package_base.PackageBase"], paths: List[str]
) -> List[DetectedPackage]:
"""Given a list of files matching the search patterns, returns a list of detected specs.
@@ -327,7 +327,7 @@ class ExecutablesFinder(Finder):
def default_path_hints(self) -> List[str]:
return spack.util.environment.get_path("PATH")
- def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]:
+ def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]:
result = []
if hasattr(pkg, "executables") and hasattr(pkg, "platform_executables"):
result = pkg.platform_executables()
@@ -356,7 +356,7 @@ class LibrariesFinder(Finder):
DYLD_LIBRARY_PATH, DYLD_FALLBACK_LIBRARY_PATH, and standard system library paths
"""
- def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]:
+ def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]:
result = []
if hasattr(pkg, "libraries"):
result = pkg.libraries
diff --git a/lib/spack/spack/patch.py b/lib/spack/spack/patch.py
index 531445b4f9..795a274243 100644
--- a/lib/spack/spack/patch.py
+++ b/lib/spack/spack/patch.py
@@ -9,7 +9,7 @@ import os
import os.path
import pathlib
import sys
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Dict, Optional, Tuple, Type, Union
import llnl.util.filesystem
from llnl.url import allowed_archive
@@ -65,6 +65,9 @@ def apply_patch(
patch(*args)
+PatchPackageType = Union["spack.package_base.PackageBase", Type["spack.package_base.PackageBase"]]
+
+
class Patch:
"""Base class for patches.
@@ -77,7 +80,7 @@ class Patch:
def __init__(
self,
- pkg: "spack.package_base.PackageBase",
+ pkg: PatchPackageType,
path_or_url: str,
level: int,
working_dir: str,
@@ -159,7 +162,7 @@ class FilePatch(Patch):
def __init__(
self,
- pkg: "spack.package_base.PackageBase",
+ pkg: PatchPackageType,
relative_path: str,
level: int,
working_dir: str,
@@ -183,7 +186,7 @@ class FilePatch(Patch):
abs_path: Optional[str] = None
# At different times we call FilePatch on instances and classes
pkg_cls = pkg if inspect.isclass(pkg) else pkg.__class__
- for cls in inspect.getmro(pkg_cls):
+ for cls in inspect.getmro(pkg_cls): # type: ignore
if not hasattr(cls, "module"):
# We've gone too far up the MRO
break
@@ -242,7 +245,7 @@ class UrlPatch(Patch):
def __init__(
self,
- pkg: "spack.package_base.PackageBase",
+ pkg: PatchPackageType,
url: str,
level: int = 1,
*,
@@ -361,8 +364,9 @@ def from_dict(
"""
repository = repository or spack.repo.PATH
owner = dictionary.get("owner")
- if "owner" not in dictionary:
- raise ValueError("Invalid patch dictionary: %s" % dictionary)
+ if owner is None:
+ raise ValueError(f"Invalid patch dictionary: {dictionary}")
+ assert isinstance(owner, str)
pkg_cls = repository.get_pkg_class(owner)
if "url" in dictionary:
diff --git a/lib/spack/spack/repo.py b/lib/spack/spack/repo.py
index f3394a118c..8fe587d3fd 100644
--- a/lib/spack/spack/repo.py
+++ b/lib/spack/spack/repo.py
@@ -675,15 +675,22 @@ class RepoPath:
repository.
Args:
- repos (list): list Repo objects or paths to put in this RepoPath
+ repos: list Repo objects or paths to put in this RepoPath
+ cache: file cache associated with this repository
+ overrides: dict mapping package name to class attribute overrides for that package
"""
- def __init__(self, *repos, cache, overrides=None):
- self.repos = []
+ def __init__(
+ self,
+ *repos: Union[str, "Repo"],
+ cache: spack.caches.FileCacheType,
+ overrides: Optional[Dict[str, Any]] = None,
+ ) -> None:
+ self.repos: List[Repo] = []
self.by_namespace = nm.NamespaceTrie()
- self._provider_index = None
- self._patch_index = None
- self._tag_index = None
+ self._provider_index: Optional[spack.provider_index.ProviderIndex] = None
+ self._patch_index: Optional[spack.patch.PatchCache] = None
+ self._tag_index: Optional[spack.tag.TagIndex] = None
# Add each repo to this path.
for repo in repos:
@@ -694,13 +701,13 @@ class RepoPath:
self.put_last(repo)
except RepoError as e:
tty.warn(
- "Failed to initialize repository: '%s'." % repo,
+ f"Failed to initialize repository: '{repo}'.",
e.message,
"To remove the bad repository, run this command:",
- " spack repo rm %s" % repo,
+ f" spack repo rm {repo}",
)
- def put_first(self, repo):
+ def put_first(self, repo: "Repo") -> None:
"""Add repo first in the search path."""
if isinstance(repo, RepoPath):
for r in reversed(repo.repos):
@@ -728,50 +735,34 @@ class RepoPath:
if repo in self.repos:
self.repos.remove(repo)
- def get_repo(self, namespace, default=NOT_PROVIDED):
- """Get a repository by namespace.
-
- Arguments:
-
- namespace:
-
- Look up this namespace in the RepoPath, and return it if found.
-
- Optional Arguments:
-
- default:
-
- If default is provided, return it when the namespace
- isn't found. If not, raise an UnknownNamespaceError.
- """
+ def get_repo(self, namespace: str) -> "Repo":
+ """Get a repository by namespace."""
full_namespace = python_package_for_repo(namespace)
if full_namespace not in self.by_namespace:
- if default == NOT_PROVIDED:
- raise UnknownNamespaceError(namespace)
- return default
+ raise UnknownNamespaceError(namespace)
return self.by_namespace[full_namespace]
- def first_repo(self):
+ def first_repo(self) -> Optional["Repo"]:
"""Get the first repo in precedence order."""
return self.repos[0] if self.repos else None
@llnl.util.lang.memoized
- def _all_package_names_set(self, include_virtuals):
+ def _all_package_names_set(self, include_virtuals) -> Set[str]:
return {name for repo in self.repos for name in repo.all_package_names(include_virtuals)}
@llnl.util.lang.memoized
- def _all_package_names(self, include_virtuals):
+ def _all_package_names(self, include_virtuals: bool) -> List[str]:
"""Return all unique package names in all repositories."""
return sorted(self._all_package_names_set(include_virtuals), key=lambda n: n.lower())
- def all_package_names(self, include_virtuals=False):
+ def all_package_names(self, include_virtuals: bool = False) -> List[str]:
return self._all_package_names(include_virtuals)
- def package_path(self, name):
+ def package_path(self, name: str) -> str:
"""Get path to package.py file for this repo."""
return self.repo_for_pkg(name).package_path(name)
- def all_package_paths(self):
+ def all_package_paths(self) -> Generator[str, None, None]:
for name in self.all_package_names():
yield self.package_path(name)
@@ -787,53 +778,52 @@ class RepoPath:
for pkg in repo.packages_with_tags(*tags)
}
- def all_package_classes(self):
+ def all_package_classes(self) -> Generator[Type["spack.package_base.PackageBase"], None, None]:
for name in self.all_package_names():
yield self.get_pkg_class(name)
@property
- def provider_index(self):
+ def provider_index(self) -> spack.provider_index.ProviderIndex:
"""Merged ProviderIndex from all Repos in the RepoPath."""
if self._provider_index is None:
self._provider_index = spack.provider_index.ProviderIndex(repository=self)
for repo in reversed(self.repos):
self._provider_index.merge(repo.provider_index)
-
return self._provider_index
@property
- def tag_index(self):
+ def tag_index(self) -> spack.tag.TagIndex:
"""Merged TagIndex from all Repos in the RepoPath."""
if self._tag_index is None:
self._tag_index = spack.tag.TagIndex(repository=self)
for repo in reversed(self.repos):
self._tag_index.merge(repo.tag_index)
-
return self._tag_index
@property
- def patch_index(self):
+ def patch_index(self) -> spack.patch.PatchCache:
"""Merged PatchIndex from all Repos in the RepoPath."""
if self._patch_index is None:
self._patch_index = spack.patch.PatchCache(repository=self)
for repo in reversed(self.repos):
self._patch_index.update(repo.patch_index)
-
return self._patch_index
@autospec
- def providers_for(self, vpkg_spec):
+ def providers_for(self, virtual_spec: "spack.spec.Spec") -> List["spack.spec.Spec"]:
providers = [
spec
- for spec in self.provider_index.providers_for(vpkg_spec)
+ for spec in self.provider_index.providers_for(virtual_spec)
if spec.name in self._all_package_names_set(include_virtuals=False)
]
if not providers:
- raise UnknownPackageError(vpkg_spec.fullname)
+ raise UnknownPackageError(virtual_spec.fullname)
return providers
@autospec
- def extensions_for(self, extendee_spec):
+ def extensions_for(
+ self, extendee_spec: "spack.spec.Spec"
+ ) -> List["spack.package_base.PackageBase"]:
return [
pkg_cls(spack.spec.Spec(pkg_cls.name))
for pkg_cls in self.all_package_classes()
@@ -844,7 +834,7 @@ class RepoPath:
"""Time a package file in this repo was last updated."""
return max(repo.last_mtime() for repo in self.repos)
- def repo_for_pkg(self, spec):
+ def repo_for_pkg(self, spec: Union[str, "spack.spec.Spec"]) -> "Repo":
"""Given a spec, get the repository for its package."""
# We don't @_autospec this function b/c it's called very frequently
# and we want to avoid parsing str's into Specs unnecessarily.
@@ -869,17 +859,20 @@ class RepoPath:
return repo
# If the package isn't in any repo, return the one with
- # highest precedence. This is for commands like `spack edit`
+ # highest precedence. This is for commands like `spack edit`
# that can operate on packages that don't exist yet.
- return self.first_repo()
+ selected = self.first_repo()
+ if selected is None:
+ raise UnknownPackageError(name)
+ return selected
- def get(self, spec):
+ def get(self, spec: "spack.spec.Spec") -> "spack.package_base.PackageBase":
"""Returns the package associated with the supplied spec."""
msg = "RepoPath.get can only be called on concrete specs"
assert isinstance(spec, spack.spec.Spec) and spec.concrete, msg
return self.repo_for_pkg(spec).get(spec)
- def get_pkg_class(self, pkg_name):
+ def get_pkg_class(self, pkg_name: str) -> Type["spack.package_base.PackageBase"]:
"""Find a class for the spec's package and return the class object."""
return self.repo_for_pkg(pkg_name).get_pkg_class(pkg_name)
@@ -892,26 +885,26 @@ class RepoPath:
"""
return self.repo_for_pkg(spec).dump_provenance(spec, path)
- def dirname_for_package_name(self, pkg_name):
+ def dirname_for_package_name(self, pkg_name: str) -> str:
return self.repo_for_pkg(pkg_name).dirname_for_package_name(pkg_name)
- def filename_for_package_name(self, pkg_name):
+ def filename_for_package_name(self, pkg_name: str) -> str:
return self.repo_for_pkg(pkg_name).filename_for_package_name(pkg_name)
- def exists(self, pkg_name):
+ def exists(self, pkg_name: str) -> bool:
"""Whether package with the give name exists in the path's repos.
Note that virtual packages do not "exist".
"""
return any(repo.exists(pkg_name) for repo in self.repos)
- def _have_name(self, pkg_name):
+ def _have_name(self, pkg_name: str) -> bool:
have_name = pkg_name is not None
if have_name and not isinstance(pkg_name, str):
- raise ValueError("is_virtual(): expected package name, got %s" % type(pkg_name))
+ raise ValueError(f"is_virtual(): expected package name, got {type(pkg_name)}")
return have_name
- def is_virtual(self, pkg_name):
+ def is_virtual(self, pkg_name: str) -> bool:
"""Return True if the package with this name is virtual, False otherwise.
This function use the provider index. If calling from a code block that
@@ -923,7 +916,7 @@ class RepoPath:
have_name = self._have_name(pkg_name)
return have_name and pkg_name in self.provider_index
- def is_virtual_safe(self, pkg_name):
+ def is_virtual_safe(self, pkg_name: str) -> bool:
"""Return True if the package with this name is virtual, False otherwise.
This function doesn't use the provider index.
@@ -1418,7 +1411,9 @@ def _path(configuration=None):
return create(configuration=configuration)
-def create(configuration):
+def create(
+ configuration: Union["spack.config.Configuration", llnl.util.lang.Singleton]
+) -> RepoPath:
"""Create a RepoPath from a configuration object.
Args:
@@ -1454,20 +1449,20 @@ def all_package_names(include_virtuals=False):
@contextlib.contextmanager
-def use_repositories(*paths_and_repos, **kwargs):
+def use_repositories(
+ *paths_and_repos: Union[str, Repo], override: bool = True
+) -> Generator[RepoPath, None, None]:
"""Use the repositories passed as arguments within the context manager.
Args:
*paths_and_repos: paths to the repositories to be used, or
already constructed Repo objects
- override (bool): if True use only the repositories passed as input,
+ override: if True use only the repositories passed as input,
if False add them to the top of the list of current repositories.
Returns:
Corresponding RepoPath object
"""
global PATH
- # TODO (Python 2.7): remove this kwargs on deprecation of Python 2.7 support
- override = kwargs.get("override", True)
paths = [getattr(x, "root", x) for x in paths_and_repos]
scope_name = "use-repo-{}".format(uuid.uuid4())
repos_key = "repos:" if override else "repos"
@@ -1476,7 +1471,8 @@ def use_repositories(*paths_and_repos, **kwargs):
)
PATH, saved = create(configuration=spack.config.CONFIG), PATH
try:
- yield PATH
+ with REPOS_FINDER.switch_repo(PATH): # type: ignore
+ yield PATH
finally:
spack.config.CONFIG.remove_scope(scope_name=scope_name)
PATH = saved
@@ -1576,10 +1572,9 @@ class UnknownNamespaceError(UnknownEntityError):
"""Raised when we encounter an unknown namespace"""
def __init__(self, namespace, name=None):
- msg, long_msg = "Unknown namespace: {}".format(namespace), None
+ msg, long_msg = f"Unknown namespace: {namespace}", None
if name == "yaml":
- long_msg = "Did you mean to specify a filename with './{}.{}'?"
- long_msg = long_msg.format(namespace, name)
+ long_msg = f"Did you mean to specify a filename with './{namespace}.{name}'?"
super().__init__(msg, long_msg)
diff --git a/lib/spack/spack/test/conftest.py b/lib/spack/spack/test/conftest.py
index 1e3b336ad5..c95ce187d0 100644
--- a/lib/spack/spack/test/conftest.py
+++ b/lib/spack/spack/test/conftest.py
@@ -2069,4 +2069,5 @@ def _c_compiler_always_exists():
@pytest.fixture(scope="session")
def mock_test_cache(tmp_path_factory):
cache_dir = tmp_path_factory.mktemp("cache")
+ print(cache_dir)
return spack.util.file_cache.FileCache(str(cache_dir))
diff --git a/lib/spack/spack/test/repo.py b/lib/spack/spack/test/repo.py
index 3aa9f00698..6bb2c2625e 100644
--- a/lib/spack/spack/test/repo.py
+++ b/lib/spack/spack/test/repo.py
@@ -3,6 +3,7 @@
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import os
+import pathlib
import pytest
@@ -205,6 +206,18 @@ def test_path_computation_with_names(method_name, mock_repo_path):
assert qualified == unqualified
+def test_use_repositories_and_import():
+ """Tests that use_repositories changes the import search too"""
+ import spack.paths
+
+ repo_dir = pathlib.Path(spack.paths.repos_path)
+ with spack.repo.use_repositories(str(repo_dir / "compiler_runtime.test")):
+ import spack.pkg.compiler_runtime.test.gcc_runtime
+
+ with spack.repo.use_repositories(str(repo_dir / "builtin.mock")):
+ import spack.pkg.builtin.mock.cmake
+
+
@pytest.mark.usefixtures("nullify_globals")
class TestRepo:
"""Test that the Repo class work correctly, and does not depend on globals,
@@ -219,8 +232,9 @@ class TestRepo:
@pytest.mark.parametrize(
"name,expected", [("mpi", True), ("mpich", False), ("mpileaks", False)]
)
- def test_is_virtual(self, name, expected, mock_test_cache):
- repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache)
+ @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
+ def test_is_virtual(self, repo_cls, name, expected, mock_test_cache):
+ repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
assert repo.is_virtual(name) is expected
assert repo.is_virtual_safe(name) is expected
@@ -258,13 +272,15 @@ class TestRepo:
"extended,expected",
[("python", ["py-extension1", "python-venv"]), ("perl", ["perl-extension"])],
)
- def test_extensions(self, extended, expected, mock_test_cache):
- repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache)
+ @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
+ def test_extensions(self, repo_cls, extended, expected, mock_test_cache):
+ repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
provider_names = {x.name for x in repo.extensions_for(extended)}
assert provider_names.issuperset(expected)
- def test_all_package_names(self, mock_test_cache):
- repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache)
+ @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
+ def test_all_package_names(self, repo_cls, mock_test_cache):
+ repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
all_names = repo.all_package_names(include_virtuals=True)
real_names = repo.all_package_names(include_virtuals=False)
assert set(all_names).issuperset(real_names)
@@ -272,10 +288,28 @@ class TestRepo:
assert repo.is_virtual(name)
assert repo.is_virtual_safe(name)
- def test_packages_with_tags(self, mock_test_cache):
- repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache)
+ @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
+ def test_packages_with_tags(self, repo_cls, mock_test_cache):
+ repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
r1 = repo.packages_with_tags("tag1")
r2 = repo.packages_with_tags("tag1", "tag2")
assert "mpich" in r1 and "mpich" in r2
assert "mpich2" in r1 and "mpich2" not in r2
assert set(r2).issubset(r1)
+
+
+@pytest.mark.usefixtures("nullify_globals")
+class TestRepoPath:
+ def test_creation_from_string(self, mock_test_cache):
+ repo = spack.repo.RepoPath(spack.paths.mock_packages_path, cache=mock_test_cache)
+ assert len(repo.repos) == 1
+ assert repo.repos[0]._finder is repo
+ assert repo.by_namespace["spack.pkg.builtin.mock"] is repo.repos[0]
+
+ def test_get_repo(self, mock_test_cache):
+ repo = spack.repo.RepoPath(spack.paths.mock_packages_path, cache=mock_test_cache)
+ # builtin.mock is there
+ assert repo.get_repo("builtin.mock") is repo.repos[0]
+ # foo is not there, raise
+ with pytest.raises(spack.repo.UnknownNamespaceError):
+ repo.get_repo("foo")