summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorTodd Gamblin <tgamblin@llnl.gov>2024-02-26 14:26:01 -0800
committerGitHub <noreply@github.com>2024-02-26 22:26:01 +0000
commit48088ee24a53e40ac4aec212c3f763c2423faa89 (patch)
tree5513c6fd3c70c4bf9e9e2d2b56d9a84c0dfe52b7 /lib
parentc7df258ca64809a1cfa7f1d5324e7a713b80fc99 (diff)
downloadspack-48088ee24a53e40ac4aec212c3f763c2423faa89.tar.gz
spack-48088ee24a53e40ac4aec212c3f763c2423faa89.tar.bz2
spack-48088ee24a53e40ac4aec212c3f763c2423faa89.tar.xz
spack-48088ee24a53e40ac4aec212c3f763c2423faa89.zip
refactor: add type annotations and refactor solver conditions (#42081)
Refactoring `SpackSolverSetup` is a bit easier with type annotations, so I started adding some. This adds annotations for the (many) instance variables on `SpackSolverSetup` as well as a few other places. This also refactors `condition()` to reduce redundancy and to allow `_get_condition_id()` to be called independently of the larger condition function. Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com>
Diffstat (limited to 'lib')
-rw-r--r--lib/spack/spack/package_base.py1
-rw-r--r--lib/spack/spack/solver/asp.py211
2 files changed, 121 insertions, 91 deletions
diff --git a/lib/spack/spack/package_base.py b/lib/spack/spack/package_base.py
index 8066c1b70f..9e9f4728e1 100644
--- a/lib/spack/spack/package_base.py
+++ b/lib/spack/spack/package_base.py
@@ -566,6 +566,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]]
provided_together: Dict["spack.spec.Spec", List[Set[str]]]
patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]]
+ variants: Dict[str, Tuple["spack.variant.Variant", "spack.spec.Spec"]]
#: By default, packages are not virtual
#: Virtual packages override this attribute
diff --git a/lib/spack/spack/solver/asp.py b/lib/spack/spack/solver/asp.py
index 59164e930c..18c7dd371c 100644
--- a/lib/spack/spack/solver/asp.py
+++ b/lib/spack/spack/solver/asp.py
@@ -15,7 +15,7 @@ import sys
import types
import typing
import warnings
-from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
+from typing import Callable, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Type, Union
import archspec.cpu
@@ -258,7 +258,7 @@ def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunc
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
-def _create_counter(specs, tests):
+def _create_counter(specs: List[spack.spec.Spec], tests: bool):
strategy = spack.config.CONFIG.get("concretizer:duplicates:strategy", "none")
if strategy == "full":
return FullDuplicatesCounter(specs, tests=tests)
@@ -897,35 +897,41 @@ class ConcreteSpecsByHash(collections.abc.Mapping):
return iter(self.data)
+# types for condition caching in solver setup
+ConditionSpecKey = Tuple[str, Optional[TransformFunction]]
+ConditionIdFunctionPair = Tuple[int, List[AspFunction]]
+ConditionSpecCache = Dict[str, Dict[ConditionSpecKey, ConditionIdFunctionPair]]
+
+
class SpackSolverSetup:
"""Class to set up and run a Spack concretization solve."""
- def __init__(self, tests=False):
- self.gen = None # set by setup()
-
- self.assumptions = []
- self.declared_versions = collections.defaultdict(list)
- self.possible_versions = collections.defaultdict(set)
- self.deprecated_versions = collections.defaultdict(set)
+ def __init__(self, tests: bool = False):
+ # these are all initialized in setup()
+ self.gen: "ProblemInstanceBuilder" = ProblemInstanceBuilder()
+ self.possible_virtuals: Set[str] = set()
- self.possible_virtuals = None
- self.possible_compilers = []
- self.possible_oses = set()
- self.variant_values_from_specs = set()
- self.version_constraints = set()
- self.target_constraints = set()
- self.default_targets = []
- self.compiler_version_constraints = set()
- self.post_facts = []
+ self.assumptions: List[Tuple["clingo.Symbol", bool]] = [] # type: ignore[name-defined]
+ self.declared_versions: Dict[str, List[DeclaredVersion]] = collections.defaultdict(list)
+ self.possible_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(set)
+ self.deprecated_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(
+ set
+ )
- # (ID, CompilerSpec) -> dictionary of attributes
- self.compiler_info = collections.defaultdict(dict)
+ self.possible_compilers: List = []
+ self.possible_oses: Set = set()
+ self.variant_values_from_specs: Set = set()
+ self.version_constraints: Set = set()
+ self.target_constraints: Set = set()
+ self.default_targets: List = []
+ self.compiler_version_constraints: Set = set()
+ self.post_facts: List = []
- self.reusable_and_possible = ConcreteSpecsByHash()
+ self.reusable_and_possible: ConcreteSpecsByHash = ConcreteSpecsByHash()
- self._id_counter = itertools.count()
- self._trigger_cache = collections.defaultdict(dict)
- self._effect_cache = collections.defaultdict(dict)
+ self._id_counter: Iterator[int] = itertools.count()
+ self._trigger_cache: ConditionSpecCache = collections.defaultdict(dict)
+ self._effect_cache: ConditionSpecCache = collections.defaultdict(dict)
# Caches to optimize the setup phase of the solver
self.target_specs_cache = None
@@ -937,8 +943,8 @@ class SpackSolverSetup:
self.concretize_everything = True
# Set during the call to setup
- self.pkgs = None
- self.explicitly_required_namespaces = {}
+ self.pkgs: Set[str] = set()
+ self.explicitly_required_namespaces: Dict[str, str] = {}
def pkg_version_rules(self, pkg):
"""Output declared versions of a package.
@@ -1222,6 +1228,38 @@ class SpackSolverSetup:
self.gen.newline()
+ def _get_condition_id(
+ self,
+ named_cond: spack.spec.Spec,
+ cache: ConditionSpecCache,
+ body: bool,
+ transform: Optional[TransformFunction] = None,
+ ) -> int:
+ """Get the id for one half of a condition (either a trigger or an imposed constraint).
+
+ Construct a key from the condition spec and any associated transformation, and
+ cache the ASP functions that they imply. The saved functions will be output
+ later in ``trigger_rules()`` and ``effect_rules()``.
+
+ Returns:
+ The id of the cached trigger or effect.
+
+ """
+ pkg_cache = cache[named_cond.name]
+
+ named_cond_key = (str(named_cond), transform)
+ result = pkg_cache.get(named_cond_key)
+ if result:
+ return result[0]
+
+ cond_id = next(self._id_counter)
+ requirements = self.spec_clauses(named_cond, body=body)
+ if transform:
+ requirements = transform(named_cond, requirements)
+ pkg_cache[named_cond_key] = (cond_id, requirements)
+
+ return cond_id
+
def condition(
self,
required_spec: spack.spec.Spec,
@@ -1247,7 +1285,8 @@ class SpackSolverSetup:
"""
named_cond = required_spec.copy()
named_cond.name = named_cond.name or name
- assert named_cond.name, "must provide name for anonymous conditions!"
+ if not named_cond.name:
+ raise ValueError(f"Must provide a name for anonymous condition: '{named_cond}'")
# Check if we can emit the requirements before updating the condition ID counter.
# In this way, if a condition can't be emitted but the exception is handled in the caller,
@@ -1257,35 +1296,19 @@ class SpackSolverSetup:
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
self.gen.fact(fn.condition_reason(condition_id, msg))
- cache = self._trigger_cache[named_cond.name]
-
- named_cond_key = (str(named_cond), transform_required)
- if named_cond_key not in cache:
- trigger_id = next(self._id_counter)
- requirements = self.spec_clauses(named_cond, body=True, required_from=name)
-
- if transform_required:
- requirements = transform_required(named_cond, requirements)
-
- cache[named_cond_key] = (trigger_id, requirements)
- trigger_id, requirements = cache[named_cond_key]
+ trigger_id = self._get_condition_id(
+ named_cond, cache=self._trigger_cache, body=True, transform=transform_required
+ )
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
if not imposed_spec:
return condition_id
- cache = self._effect_cache[named_cond.name]
- imposed_spec_key = (str(imposed_spec), transform_imposed)
- if imposed_spec_key not in cache:
- effect_id = next(self._id_counter)
- requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
-
- if transform_imposed:
- requirements = transform_imposed(imposed_spec, requirements)
-
- cache[imposed_spec_key] = (effect_id, requirements)
- effect_id, requirements = cache[imposed_spec_key]
+ effect_id = self._get_condition_id(
+ imposed_spec, cache=self._effect_cache, body=False, transform=transform_imposed
+ )
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
+
return condition_id
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
@@ -1387,23 +1410,13 @@ class SpackSolverSetup:
def provider_defaults(self):
self.gen.h2("Default virtual providers")
- msg = (
- "Internal Error: possible_virtuals is not populated. Please report to the spack"
- " maintainers"
- )
- assert self.possible_virtuals is not None, msg
self.virtual_preferences(
"all", lambda v, p, i: self.gen.fact(fn.default_provider_preference(v, p, i))
)
def provider_requirements(self):
self.gen.h2("Requirements on virtual providers")
- msg = (
- "Internal Error: possible_virtuals is not populated. Please report to the spack"
- " maintainers"
- )
parser = RequirementParser(spack.config.CONFIG)
- assert self.possible_virtuals is not None, msg
for virtual_str in sorted(self.possible_virtuals):
rules = parser.rules_from_virtual(virtual_str)
if rules:
@@ -1602,35 +1615,57 @@ class SpackSolverSetup:
fn.compiler_version_flag(compiler.name, compiler.version, name, flag)
)
- def spec_clauses(self, *args, **kwargs):
- """Wrap a call to `_spec_clauses()` into a try/except block that
- raises a comprehensible error message in case of failure.
+ def spec_clauses(
+ self,
+ spec: spack.spec.Spec,
+ *,
+ body: bool = False,
+ transitive: bool = True,
+ expand_hashes: bool = False,
+ concrete_build_deps=False,
+ required_from: Optional[str] = None,
+ ) -> List[AspFunction]:
+ """Wrap a call to `_spec_clauses()` into a try/except block with better error handling.
+
+ Arguments are as for ``_spec_clauses()`` except ``required_from``.
+
+ Arguments:
+ required_from: name of package that caused this call.
"""
- requestor = kwargs.pop("required_from", None)
try:
- clauses = self._spec_clauses(*args, **kwargs)
+ clauses = self._spec_clauses(
+ spec,
+ body=body,
+ transitive=transitive,
+ expand_hashes=expand_hashes,
+ concrete_build_deps=concrete_build_deps,
+ )
except RuntimeError as exc:
msg = str(exc)
- if requestor:
- msg += ' [required from package "{0}"]'.format(requestor)
+ if required_from:
+ msg += f" [required from package '{required_from}']"
raise RuntimeError(msg)
return clauses
def _spec_clauses(
- self, spec, body=False, transitive=True, expand_hashes=False, concrete_build_deps=False
- ):
+ self,
+ spec: spack.spec.Spec,
+ *,
+ body: bool = False,
+ transitive: bool = True,
+ expand_hashes: bool = False,
+ concrete_build_deps: bool = False,
+ ) -> List[AspFunction]:
"""Return a list of clauses for a spec mandates are true.
Arguments:
- spec (spack.spec.Spec): the spec to analyze
- body (bool): if True, generate clauses to be used in rule bodies
- (final values) instead of rule heads (setters).
- transitive (bool): if False, don't generate clauses from
- dependencies (default True)
- expand_hashes (bool): if True, descend into hashes of concrete specs
- (default False)
- concrete_build_deps (bool): if False, do not include pure build deps
- of concrete specs (as they have no effect on runtime constraints)
+ spec: the spec to analyze
+ body: if True, generate clauses to be used in rule bodies (final values) instead
+ of rule heads (setters).
+ transitive: if False, don't generate clauses from dependencies (default True)
+ expand_hashes: if True, descend into hashes of concrete specs (default False)
+ concrete_build_deps: if False, do not include pure build deps of concrete specs
+ (as they have no effect on runtime constraints)
Normally, if called with ``transitive=True``, ``spec_clauses()`` just generates
hashes for the dependency requirements of concrete specs. If ``expand_hashes``
@@ -1640,7 +1675,7 @@ class SpackSolverSetup:
"""
clauses = []
- f = _Body if body else _Head
+ f: Union[Type[_Head], Type[_Body]] = _Body if body else _Head
if spec.name:
clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
@@ -1729,8 +1764,9 @@ class SpackSolverSetup:
# dependencies
if spec.concrete:
# older specs do not have package hashes, so we have to do this carefully
- if getattr(spec, "_package_hash", None):
- clauses.append(fn.attr("package_hash", spec.name, spec._package_hash))
+ package_hash = getattr(spec, "_package_hash", None)
+ if package_hash:
+ clauses.append(fn.attr("package_hash", spec.name, package_hash))
clauses.append(fn.attr("hash", spec.name, spec.dag_hash()))
edges = spec.edges_from_dependents()
@@ -1789,7 +1825,7 @@ class SpackSolverSetup:
return clauses
def define_package_versions_and_validate_preferences(
- self, possible_pkgs, *, require_checksum: bool, allow_deprecated: bool
+ self, possible_pkgs: Set[str], *, require_checksum: bool, allow_deprecated: bool
):
"""Declare any versions in specs not declared in packages."""
packages_yaml = spack.config.get("packages")
@@ -1822,7 +1858,7 @@ class SpackSolverSetup:
if pkg_name not in packages_yaml or "version" not in packages_yaml[pkg_name]:
continue
- version_defs = []
+ version_defs: List[GitOrStandardVersion] = []
for vstr in packages_yaml[pkg_name]["version"]:
v = vn.ver(vstr)
@@ -2033,13 +2069,6 @@ class SpackSolverSetup:
def virtual_providers(self):
self.gen.h2("Virtual providers")
- msg = (
- "Internal Error: possible_virtuals is not populated. Please report to the spack"
- " maintainers"
- )
- assert self.possible_virtuals is not None, msg
-
- # what provides what
for vspec in sorted(self.possible_virtuals):
self.gen.fact(fn.virtual(vspec))
self.gen.newline()
@@ -2236,7 +2265,7 @@ class SpackSolverSetup:
def setup(
self,
- specs: Sequence[spack.spec.Spec],
+ specs: List[spack.spec.Spec],
*,
reuse: Optional[List[spack.spec.Spec]] = None,
allow_deprecated: bool = False,