diff options
author | Todd Gamblin <tgamblin@llnl.gov> | 2024-02-26 14:26:01 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-26 22:26:01 +0000 |
commit | 48088ee24a53e40ac4aec212c3f763c2423faa89 (patch) | |
tree | 5513c6fd3c70c4bf9e9e2d2b56d9a84c0dfe52b7 /lib | |
parent | c7df258ca64809a1cfa7f1d5324e7a713b80fc99 (diff) | |
download | spack-48088ee24a53e40ac4aec212c3f763c2423faa89.tar.gz spack-48088ee24a53e40ac4aec212c3f763c2423faa89.tar.bz2 spack-48088ee24a53e40ac4aec212c3f763c2423faa89.tar.xz spack-48088ee24a53e40ac4aec212c3f763c2423faa89.zip |
refactor: add type annotations and refactor solver conditions (#42081)
Refactoring `SpackSolverSetup` is a bit easier with type annotations, so I started
adding some. This adds annotations for the (many) instance variables on
`SpackSolverSetup` as well as a few other places.
This also refactors `condition()` to reduce redundancy and to allow
`_get_condition_id()` to be called independently of the larger condition
function.
Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com>
Diffstat (limited to 'lib')
-rw-r--r-- | lib/spack/spack/package_base.py | 1 | ||||
-rw-r--r-- | lib/spack/spack/solver/asp.py | 211 |
2 files changed, 121 insertions, 91 deletions
diff --git a/lib/spack/spack/package_base.py b/lib/spack/spack/package_base.py index 8066c1b70f..9e9f4728e1 100644 --- a/lib/spack/spack/package_base.py +++ b/lib/spack/spack/package_base.py @@ -566,6 +566,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta): provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]] provided_together: Dict["spack.spec.Spec", List[Set[str]]] patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]] + variants: Dict[str, Tuple["spack.variant.Variant", "spack.spec.Spec"]] #: By default, packages are not virtual #: Virtual packages override this attribute diff --git a/lib/spack/spack/solver/asp.py b/lib/spack/spack/solver/asp.py index 59164e930c..18c7dd371c 100644 --- a/lib/spack/spack/solver/asp.py +++ b/lib/spack/spack/solver/asp.py @@ -15,7 +15,7 @@ import sys import types import typing import warnings -from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union +from typing import Callable, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Type, Union import archspec.cpu @@ -258,7 +258,7 @@ def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunc return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts)) -def _create_counter(specs, tests): +def _create_counter(specs: List[spack.spec.Spec], tests: bool): strategy = spack.config.CONFIG.get("concretizer:duplicates:strategy", "none") if strategy == "full": return FullDuplicatesCounter(specs, tests=tests) @@ -897,35 +897,41 @@ class ConcreteSpecsByHash(collections.abc.Mapping): return iter(self.data) +# types for condition caching in solver setup +ConditionSpecKey = Tuple[str, Optional[TransformFunction]] +ConditionIdFunctionPair = Tuple[int, List[AspFunction]] +ConditionSpecCache = Dict[str, Dict[ConditionSpecKey, ConditionIdFunctionPair]] + + class SpackSolverSetup: """Class to set up and run a Spack concretization solve.""" - def __init__(self, tests=False): - self.gen = None # set by setup() - - self.assumptions = [] - self.declared_versions = collections.defaultdict(list) - self.possible_versions = collections.defaultdict(set) - self.deprecated_versions = collections.defaultdict(set) + def __init__(self, tests: bool = False): + # these are all initialized in setup() + self.gen: "ProblemInstanceBuilder" = ProblemInstanceBuilder() + self.possible_virtuals: Set[str] = set() - self.possible_virtuals = None - self.possible_compilers = [] - self.possible_oses = set() - self.variant_values_from_specs = set() - self.version_constraints = set() - self.target_constraints = set() - self.default_targets = [] - self.compiler_version_constraints = set() - self.post_facts = [] + self.assumptions: List[Tuple["clingo.Symbol", bool]] = [] # type: ignore[name-defined] + self.declared_versions: Dict[str, List[DeclaredVersion]] = collections.defaultdict(list) + self.possible_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(set) + self.deprecated_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict( + set + ) - # (ID, CompilerSpec) -> dictionary of attributes - self.compiler_info = collections.defaultdict(dict) + self.possible_compilers: List = [] + self.possible_oses: Set = set() + self.variant_values_from_specs: Set = set() + self.version_constraints: Set = set() + self.target_constraints: Set = set() + self.default_targets: List = [] + self.compiler_version_constraints: Set = set() + self.post_facts: List = [] - self.reusable_and_possible = ConcreteSpecsByHash() + self.reusable_and_possible: ConcreteSpecsByHash = ConcreteSpecsByHash() - self._id_counter = itertools.count() - self._trigger_cache = collections.defaultdict(dict) - self._effect_cache = collections.defaultdict(dict) + self._id_counter: Iterator[int] = itertools.count() + self._trigger_cache: ConditionSpecCache = collections.defaultdict(dict) + self._effect_cache: ConditionSpecCache = collections.defaultdict(dict) # Caches to optimize the setup phase of the solver self.target_specs_cache = None @@ -937,8 +943,8 @@ class SpackSolverSetup: self.concretize_everything = True # Set during the call to setup - self.pkgs = None - self.explicitly_required_namespaces = {} + self.pkgs: Set[str] = set() + self.explicitly_required_namespaces: Dict[str, str] = {} def pkg_version_rules(self, pkg): """Output declared versions of a package. @@ -1222,6 +1228,38 @@ class SpackSolverSetup: self.gen.newline() + def _get_condition_id( + self, + named_cond: spack.spec.Spec, + cache: ConditionSpecCache, + body: bool, + transform: Optional[TransformFunction] = None, + ) -> int: + """Get the id for one half of a condition (either a trigger or an imposed constraint). + + Construct a key from the condition spec and any associated transformation, and + cache the ASP functions that they imply. The saved functions will be output + later in ``trigger_rules()`` and ``effect_rules()``. + + Returns: + The id of the cached trigger or effect. + + """ + pkg_cache = cache[named_cond.name] + + named_cond_key = (str(named_cond), transform) + result = pkg_cache.get(named_cond_key) + if result: + return result[0] + + cond_id = next(self._id_counter) + requirements = self.spec_clauses(named_cond, body=body) + if transform: + requirements = transform(named_cond, requirements) + pkg_cache[named_cond_key] = (cond_id, requirements) + + return cond_id + def condition( self, required_spec: spack.spec.Spec, @@ -1247,7 +1285,8 @@ class SpackSolverSetup: """ named_cond = required_spec.copy() named_cond.name = named_cond.name or name - assert named_cond.name, "must provide name for anonymous conditions!" + if not named_cond.name: + raise ValueError(f"Must provide a name for anonymous condition: '{named_cond}'") # Check if we can emit the requirements before updating the condition ID counter. # In this way, if a condition can't be emitted but the exception is handled in the caller, @@ -1257,35 +1296,19 @@ class SpackSolverSetup: self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id))) self.gen.fact(fn.condition_reason(condition_id, msg)) - cache = self._trigger_cache[named_cond.name] - - named_cond_key = (str(named_cond), transform_required) - if named_cond_key not in cache: - trigger_id = next(self._id_counter) - requirements = self.spec_clauses(named_cond, body=True, required_from=name) - - if transform_required: - requirements = transform_required(named_cond, requirements) - - cache[named_cond_key] = (trigger_id, requirements) - trigger_id, requirements = cache[named_cond_key] + trigger_id = self._get_condition_id( + named_cond, cache=self._trigger_cache, body=True, transform=transform_required + ) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id))) if not imposed_spec: return condition_id - cache = self._effect_cache[named_cond.name] - imposed_spec_key = (str(imposed_spec), transform_imposed) - if imposed_spec_key not in cache: - effect_id = next(self._id_counter) - requirements = self.spec_clauses(imposed_spec, body=False, required_from=name) - - if transform_imposed: - requirements = transform_imposed(imposed_spec, requirements) - - cache[imposed_spec_key] = (effect_id, requirements) - effect_id, requirements = cache[imposed_spec_key] + effect_id = self._get_condition_id( + imposed_spec, cache=self._effect_cache, body=False, transform=transform_imposed + ) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id))) + return condition_id def impose(self, condition_id, imposed_spec, node=True, name=None, body=False): @@ -1387,23 +1410,13 @@ class SpackSolverSetup: def provider_defaults(self): self.gen.h2("Default virtual providers") - msg = ( - "Internal Error: possible_virtuals is not populated. Please report to the spack" - " maintainers" - ) - assert self.possible_virtuals is not None, msg self.virtual_preferences( "all", lambda v, p, i: self.gen.fact(fn.default_provider_preference(v, p, i)) ) def provider_requirements(self): self.gen.h2("Requirements on virtual providers") - msg = ( - "Internal Error: possible_virtuals is not populated. Please report to the spack" - " maintainers" - ) parser = RequirementParser(spack.config.CONFIG) - assert self.possible_virtuals is not None, msg for virtual_str in sorted(self.possible_virtuals): rules = parser.rules_from_virtual(virtual_str) if rules: @@ -1602,35 +1615,57 @@ class SpackSolverSetup: fn.compiler_version_flag(compiler.name, compiler.version, name, flag) ) - def spec_clauses(self, *args, **kwargs): - """Wrap a call to `_spec_clauses()` into a try/except block that - raises a comprehensible error message in case of failure. + def spec_clauses( + self, + spec: spack.spec.Spec, + *, + body: bool = False, + transitive: bool = True, + expand_hashes: bool = False, + concrete_build_deps=False, + required_from: Optional[str] = None, + ) -> List[AspFunction]: + """Wrap a call to `_spec_clauses()` into a try/except block with better error handling. + + Arguments are as for ``_spec_clauses()`` except ``required_from``. + + Arguments: + required_from: name of package that caused this call. """ - requestor = kwargs.pop("required_from", None) try: - clauses = self._spec_clauses(*args, **kwargs) + clauses = self._spec_clauses( + spec, + body=body, + transitive=transitive, + expand_hashes=expand_hashes, + concrete_build_deps=concrete_build_deps, + ) except RuntimeError as exc: msg = str(exc) - if requestor: - msg += ' [required from package "{0}"]'.format(requestor) + if required_from: + msg += f" [required from package '{required_from}']" raise RuntimeError(msg) return clauses def _spec_clauses( - self, spec, body=False, transitive=True, expand_hashes=False, concrete_build_deps=False - ): + self, + spec: spack.spec.Spec, + *, + body: bool = False, + transitive: bool = True, + expand_hashes: bool = False, + concrete_build_deps: bool = False, + ) -> List[AspFunction]: """Return a list of clauses for a spec mandates are true. Arguments: - spec (spack.spec.Spec): the spec to analyze - body (bool): if True, generate clauses to be used in rule bodies - (final values) instead of rule heads (setters). - transitive (bool): if False, don't generate clauses from - dependencies (default True) - expand_hashes (bool): if True, descend into hashes of concrete specs - (default False) - concrete_build_deps (bool): if False, do not include pure build deps - of concrete specs (as they have no effect on runtime constraints) + spec: the spec to analyze + body: if True, generate clauses to be used in rule bodies (final values) instead + of rule heads (setters). + transitive: if False, don't generate clauses from dependencies (default True) + expand_hashes: if True, descend into hashes of concrete specs (default False) + concrete_build_deps: if False, do not include pure build deps of concrete specs + (as they have no effect on runtime constraints) Normally, if called with ``transitive=True``, ``spec_clauses()`` just generates hashes for the dependency requirements of concrete specs. If ``expand_hashes`` @@ -1640,7 +1675,7 @@ class SpackSolverSetup: """ clauses = [] - f = _Body if body else _Head + f: Union[Type[_Head], Type[_Body]] = _Body if body else _Head if spec.name: clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name)) @@ -1729,8 +1764,9 @@ class SpackSolverSetup: # dependencies if spec.concrete: # older specs do not have package hashes, so we have to do this carefully - if getattr(spec, "_package_hash", None): - clauses.append(fn.attr("package_hash", spec.name, spec._package_hash)) + package_hash = getattr(spec, "_package_hash", None) + if package_hash: + clauses.append(fn.attr("package_hash", spec.name, package_hash)) clauses.append(fn.attr("hash", spec.name, spec.dag_hash())) edges = spec.edges_from_dependents() @@ -1789,7 +1825,7 @@ class SpackSolverSetup: return clauses def define_package_versions_and_validate_preferences( - self, possible_pkgs, *, require_checksum: bool, allow_deprecated: bool + self, possible_pkgs: Set[str], *, require_checksum: bool, allow_deprecated: bool ): """Declare any versions in specs not declared in packages.""" packages_yaml = spack.config.get("packages") @@ -1822,7 +1858,7 @@ class SpackSolverSetup: if pkg_name not in packages_yaml or "version" not in packages_yaml[pkg_name]: continue - version_defs = [] + version_defs: List[GitOrStandardVersion] = [] for vstr in packages_yaml[pkg_name]["version"]: v = vn.ver(vstr) @@ -2033,13 +2069,6 @@ class SpackSolverSetup: def virtual_providers(self): self.gen.h2("Virtual providers") - msg = ( - "Internal Error: possible_virtuals is not populated. Please report to the spack" - " maintainers" - ) - assert self.possible_virtuals is not None, msg - - # what provides what for vspec in sorted(self.possible_virtuals): self.gen.fact(fn.virtual(vspec)) self.gen.newline() @@ -2236,7 +2265,7 @@ class SpackSolverSetup: def setup( self, - specs: Sequence[spack.spec.Spec], + specs: List[spack.spec.Spec], *, reuse: Optional[List[spack.spec.Spec]] = None, allow_deprecated: bool = False, |