From 6753cc0b814d80ab240d6aebf6d756497dd046d3 Mon Sep 17 00:00:00 2001 From: Todd Gamblin Date: Wed, 21 Jun 2023 00:20:32 -0700 Subject: refactor: Index provided virtuals by `when` spec Part 4 of reworking all package metadata to key by `when` conditions. Changes conflict dictionary structure from this: { provided_spec: {when_spec, ...} } to this: { when_spec: {provided_spec, ...} } --- lib/spack/spack/audit.py | 34 +++++++++++++++++++++++----------- lib/spack/spack/cmd/info.py | 8 +------- lib/spack/spack/directives.py | 10 +++++----- lib/spack/spack/package_base.py | 20 ++++++++++++++------ lib/spack/spack/provider_index.py | 4 ++-- lib/spack/spack/solver/asp.py | 23 ++++++++++++----------- lib/spack/spack/spec.py | 16 +++++++--------- 7 files changed, 64 insertions(+), 51 deletions(-) diff --git a/lib/spack/spack/audit.py b/lib/spack/spack/audit.py index fedb6c1138..a3591346be 100644 --- a/lib/spack/spack/audit.py +++ b/lib/spack/spack/audit.py @@ -702,15 +702,13 @@ def _unknown_variants_in_directives(pkgs, error_cls): ) ) - # Check "patch" directive - for _, triggers in pkg_cls.provided.items(): - triggers = [spack.spec.Spec(x) for x in triggers] - for vrn in triggers: - errors.extend( - _analyze_variants_in_directive( - pkg_cls, vrn, directive="patch", error_cls=error_cls - ) + # Check "provides" directive + for when_spec in pkg_cls.provided: + errors.extend( + _analyze_variants_in_directive( + pkg_cls, when_spec, directive="provides", error_cls=error_cls ) + ) # Check "resource" directive for vrn in pkg_cls.resources: @@ -752,6 +750,18 @@ def _issues_in_depends_on_directive(pkgs, error_cls): ] errors.append(error_cls(summary=summary, details=details)) + def check_virtual_with_variants(spec, msg): + if not spec.virtual or not spec.variants: + return + error = error_cls( + f"{pkg_name}: {msg}", + f"remove variants from '{spec}' in depends_on directive in {filename}", + ) + errors.append(error) + + check_virtual_with_variants(dep.spec, "virtual dependency cannot have variants") + check_virtual_with_variants(dep.spec, "virtual when= spec cannot have variants") + # No need to analyze virtual packages if spack.repo.PATH.is_virtual(dep_name): continue @@ -963,9 +973,11 @@ def _named_specs_in_when_arguments(pkgs, error_cls): summary = f"{pkg_name}: wrong 'when=' condition for the '{vname}' variant" errors.extend(_extracts_errors(triggers, summary)) - for provided, triggers in pkg_cls.provided.items(): - summary = f"{pkg_name}: wrong 'when=' condition for the '{provided}' virtual" - errors.extend(_extracts_errors(triggers, summary)) + for when, providers, details in _error_items(pkg_cls.provided): + errors.extend( + error_cls(f"{pkg_name}: wrong 'when=' condition for '{provided}' virtual", details) + for provided in providers + ) for when, requirements, details in _error_items(pkg_cls.requirements): errors.append( diff --git a/lib/spack/spack/cmd/info.py b/lib/spack/spack/cmd/info.py index fd5ccc5087..b007c60516 100644 --- a/lib/spack/spack/cmd/info.py +++ b/lib/spack/spack/cmd/info.py @@ -474,13 +474,7 @@ def print_virtuals(pkg, args): color.cprint("") color.cprint(section_title("Virtual Packages: ")) if pkg.provided: - inverse_map = {} - for spec, whens in pkg.provided.items(): - for when in whens: - if when not in inverse_map: - inverse_map[when] = set() - inverse_map[when].add(spec) - for when, specs in reversed(sorted(inverse_map.items())): + for when, specs in reversed(sorted(pkg.provided.items())): line = " %s provides %s" % ( when.colorized(), ", ".join(s.colorized() for s in specs), diff --git a/lib/spack/spack/directives.py b/lib/spack/spack/directives.py index 80aee968c8..2524562428 100644 --- a/lib/spack/spack/directives.py +++ b/lib/spack/spack/directives.py @@ -613,7 +613,7 @@ def extends(spec, when=None, type=("build", "run"), patches=None): @directive(dicts=("provided", "provided_together")) -def provides(*specs, when: Optional[str] = None): +def provides(*specs: SpecType, when: WhenType = None): """Allows packages to provide a virtual dependency. If a package provides "mpi", other packages can declare that they depend on "mpi", @@ -624,7 +624,7 @@ def provides(*specs, when: Optional[str] = None): when: condition when this provides clause needs to be considered """ - def _execute_provides(pkg): + def _execute_provides(pkg: "spack.package_base.PackageBase"): import spack.parser # Avoid circular dependency when_spec = _make_when_spec(when) @@ -634,6 +634,7 @@ def provides(*specs, when: Optional[str] = None): # ``when`` specs for ``provides()`` need a name, as they are used # to build the ProviderIndex. when_spec.name = pkg.name + spec_objs = [spack.spec.Spec(x) for x in specs] spec_names = [x.name for x in spec_objs] if len(spec_names) > 1: @@ -643,9 +644,8 @@ def provides(*specs, when: Optional[str] = None): if pkg.name == provided_spec.name: raise CircularReferenceError("Package '%s' cannot provide itself." % pkg.name) - if provided_spec not in pkg.provided: - pkg.provided[provided_spec] = set() - pkg.provided[provided_spec].add(when_spec) + provided_set = pkg.provided.setdefault(when_spec, set()) + provided_set.add(provided_spec) return _execute_provides diff --git a/lib/spack/spack/package_base.py b/lib/spack/spack/package_base.py index bb9fb6408c..92c456f7fe 100644 --- a/lib/spack/spack/package_base.py +++ b/lib/spack/spack/package_base.py @@ -25,7 +25,7 @@ import textwrap import time import traceback import warnings -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union import llnl.util.filesystem as fsys import llnl.util.tty as tty @@ -565,6 +565,8 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta): requirements: Dict[ "spack.spec.Spec", List[Tuple[Tuple["spack.spec.Spec", ...], str, Optional[str]]] ] + provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]] + provided_together: Dict["spack.spec.Spec", List[Set[str]]] patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]] #: By default, packages are not virtual @@ -1342,9 +1344,9 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta): True if this package provides a virtual package with the specified name """ return any( - any(self.spec.intersects(c) for c in constraints) - for s, constraints in self.provided.items() - if s.name == vpkg_name + any(spec.name == vpkg_name for spec in provided) + for when_spec, provided in self.provided.items() + if self.spec.intersects(when_spec) ) @property @@ -1354,10 +1356,16 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta): """ return [ vspec - for vspec, constraints in self.provided.items() - if any(self.spec.satisfies(c) for c in constraints) + for when_spec, provided in self.provided.items() + for vspec in provided + if self.spec.satisfies(when_spec) ] + @classmethod + def provided_virtual_names(cls): + """Return sorted list of names of virtuals that can be provided by this package.""" + return sorted(set(vpkg.name for virtuals in cls.provided.values() for vpkg in virtuals)) + @property def prefix(self): """Get the prefix into which this package should be installed.""" diff --git a/lib/spack/spack/provider_index.py b/lib/spack/spack/provider_index.py index a59cb1be80..29c32ce1b5 100644 --- a/lib/spack/spack/provider_index.py +++ b/lib/spack/spack/provider_index.py @@ -128,8 +128,8 @@ class ProviderIndex(_IndexBase): assert not self.repository.is_virtual_safe(spec.name), msg pkg_provided = self.repository.get_pkg_class(spec.name).provided - for provided_spec, provider_specs in pkg_provided.items(): - for provider_spec_readonly in provider_specs: + for provider_spec_readonly, provided_specs in pkg_provided.items(): + for provided_spec in provided_specs: # TODO: fix this comment. # We want satisfaction other than flags provider_spec = provider_spec_readonly.copy() diff --git a/lib/spack/spack/solver/asp.py b/lib/spack/spack/solver/asp.py index 82eff23119..0f38331412 100644 --- a/lib/spack/spack/solver/asp.py +++ b/lib/spack/spack/solver/asp.py @@ -1628,19 +1628,20 @@ class SpackSolverSetup: self.gen.fact(fn.imposed_constraint(condition_id, *pred.args)) def package_provider_rules(self, pkg): - for provider_name in sorted(set(s.name for s in pkg.provided.keys())): - if provider_name not in self.possible_virtuals: + for vpkg_name in pkg.provided_virtual_names(): + if vpkg_name not in self.possible_virtuals: continue - self.gen.fact(fn.pkg_fact(pkg.name, fn.possible_provider(provider_name))) + self.gen.fact(fn.pkg_fact(pkg.name, fn.possible_provider(vpkg_name))) - for provided, whens in pkg.provided.items(): - if provided.name not in self.possible_virtuals: - continue - for when in whens: - msg = "%s provides %s when %s" % (pkg.name, provided, when) - condition_id = self.condition(when, provided, pkg.name, msg) + for when, provided in pkg.provided.items(): + for vpkg in provided: + if vpkg.name not in self.possible_virtuals: + continue + + msg = f"{pkg.name} provides {vpkg} when {when}" + condition_id = self.condition(when, vpkg, pkg.name, msg) self.gen.fact( - fn.pkg_fact(when.name, fn.provider_condition(condition_id, provided.name)) + fn.pkg_fact(when.name, fn.provider_condition(condition_id, vpkg.name)) ) self.gen.newline() @@ -3383,7 +3384,7 @@ def _is_reusable(spec: spack.spec.Spec, packages, local: bool) -> bool: return True try: - provided = [p.name for p in spec.package.provided] + provided = spack.repo.PATH.get(spec).provided_virtual_names() except spack.repo.RepoError: provided = [] diff --git a/lib/spack/spack/spec.py b/lib/spack/spack/spec.py index b27272e8e1..767a69ee67 100644 --- a/lib/spack/spack/spec.py +++ b/lib/spack/spack/spec.py @@ -2788,7 +2788,7 @@ class Spec: for dep in self.traverse(): visited_user_specs.add(dep.name) pkg_cls = spack.repo.PATH.get_pkg_class(dep.name) - visited_user_specs.update(x.name for x in pkg_cls(dep).provided) + visited_user_specs.update(pkg_cls(dep).provided_virtual_names()) extra = set(user_spec_deps.keys()).difference(visited_user_specs) if extra: @@ -3774,11 +3774,9 @@ class Spec: return False if pkg.provides(virtual_spec.name): - for provided, when_specs in pkg.provided.items(): - if any( - non_virtual_spec.intersects(when, deps=False) for when in when_specs - ): - if provided.intersects(virtual_spec): + for when_spec, provided in pkg.provided.items(): + if non_virtual_spec.intersects(when_spec, deps=False): + if any(vpkg.intersects(virtual_spec) for vpkg in provided): return True return False @@ -3881,9 +3879,9 @@ class Spec: return False if pkg.provides(other.name): - for provided, when_specs in pkg.provided.items(): - if any(self.satisfies(when, deps=False) for when in when_specs): - if provided.intersects(other): + for when_spec, provided in pkg.provided.items(): + if self.satisfies(when_spec, deps=False): + if any(vpkg.intersects(other) for vpkg in provided): return True return False -- cgit v1.2.3-70-g09d2