summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorMassimiliano Culpo <massimiliano.culpo@gmail.com>2020-03-25 17:48:05 +0100
committerGitHub <noreply@github.com>2020-03-25 09:48:05 -0700
commitb42a96df980458602c6c5314b3d3fac08f2d7442 (patch)
tree1b8c6755709c4b4d2801ed5566e0c07d070c48ce /lib
parent3aa225cd5cc14b1e108fc75b128759ad4d6bb1eb (diff)
downloadspack-b42a96df980458602c6c5314b3d3fac08f2d7442.tar.gz
spack-b42a96df980458602c6c5314b3d3fac08f2d7442.tar.bz2
spack-b42a96df980458602c6c5314b3d3fac08f2d7442.tar.xz
spack-b42a96df980458602c6c5314b3d3fac08f2d7442.zip
provider index: removed import from + refactored a few parts (#15570)
Removed provider_index use of 'import from' and refactored a few routines to a further subclassing of _IndexBase for implementing user defined bindings of provider specs.
Diffstat (limited to 'lib')
-rw-r--r--lib/spack/spack/provider_index.py307
-rw-r--r--lib/spack/spack/test/spec_dag.py2
2 files changed, 179 insertions, 130 deletions
diff --git a/lib/spack/spack/provider_index.py b/lib/spack/spack/provider_index.py
index 9bf4af8911..326f6aa8f1 100644
--- a/lib/spack/spack/provider_index.py
+++ b/lib/spack/spack/provider_index.py
@@ -2,54 +2,147 @@
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+"""Classes and functions to manage providers of virtual dependencies"""
+import itertools
-"""
-The ``virtual`` module contains utility classes for virtual dependencies.
-"""
-
-from itertools import product as iproduct
-from six import iteritems
-from pprint import pformat
-
+import six
import spack.error
import spack.util.spack_json as sjson
-class ProviderIndex(object):
- """This is a dict of dicts used for finding providers of particular
- virtual dependencies. The dict of dicts looks like:
+def _cross_provider_maps(lmap, rmap):
+ """Return a dictionary that combines constraint requests from both input.
- { vpkg name :
- { full vpkg spec : set(packages providing spec) } }
+ Args:
+ lmap: main provider map
+ rmap: provider map with additional constraints
+ """
+ # TODO: this is pretty darned nasty, and inefficient, but there
+ # TODO: are not that many vdeps in most specs.
+ result = {}
+ for lspec, rspec in itertools.product(lmap, rmap):
+ try:
+ constrained = lspec.constrained(rspec)
+ except spack.error.UnsatisfiableSpecError:
+ continue
+
+ # lp and rp are left and right provider specs.
+ for lp_spec, rp_spec in itertools.product(lmap[lspec], rmap[rspec]):
+ if lp_spec.name == rp_spec.name:
+ try:
+ const = lp_spec.constrained(rp_spec, deps=False)
+ result.setdefault(constrained, set()).add(const)
+ except spack.error.UnsatisfiableSpecError:
+ continue
+ return result
+
+
+class _IndexBase(object):
+ #: This is a dict of dicts used for finding providers of particular
+ #: virtual dependencies. The dict of dicts looks like:
+ #:
+ #: { vpkg name :
+ #: { full vpkg spec : set(packages providing spec) } }
+ #:
+ #: Callers can use this to first find which packages provide a vpkg,
+ #: then find a matching full spec. e.g., in this scenario:
+ #:
+ #: { 'mpi' :
+ #: { mpi@:1.1 : set([mpich]),
+ #: mpi@:2.3 : set([mpich2@1.9:]) } }
+ #:
+ #: Calling providers_for(spec) will find specs that provide a
+ #: matching implementation of MPI. Derived class need to construct
+ #: this attribute according to the semantics above.
+ providers = None
+
+ def providers_for(self, virtual_spec):
+ """Return a list of specs of all packages that provide virtual
+ packages with the supplied spec.
+
+ Args:
+ virtual_spec: virtual spec to be provided
+ """
+ result = set()
+ # Allow string names to be passed as input, as well as specs
+ if isinstance(virtual_spec, six.string_types):
+ virtual_spec = spack.spec.Spec(virtual_spec)
- Callers can use this to first find which packages provide a vpkg,
- then find a matching full spec. e.g., in this scenario:
+ # Add all the providers that satisfy the vpkg spec.
+ if virtual_spec.name in self.providers:
+ for p_spec, spec_set in self.providers[virtual_spec.name].items():
+ if p_spec.satisfies(virtual_spec, deps=False):
+ result.update(spec_set)
- { 'mpi' :
- { mpi@:1.1 : set([mpich]),
- mpi@:2.3 : set([mpich2@1.9:]) } }
+ # Return providers in order. Defensively copy.
+ return sorted(s.copy() for s in result)
- Calling providers_for(spec) will find specs that provide a
- matching implementation of MPI.
+ def __contains__(self, name):
+ return name in self.providers
- """
+ def satisfies(self, other):
+ """Determine if the providers of virtual specs are compatible.
- def __init__(self, specs=None, restrict=False):
- """Create a new ProviderIndex.
+ Args:
+ other: another provider index
+
+ Returns:
+ True if the providers are compatible, False otherwise.
+ """
+ common = set(self.providers) & set(other.providers)
+ if not common:
+ return True
+
+ # This ensures that some provider in other COULD satisfy the
+ # vpkg constraints on self.
+ result = {}
+ for name in common:
+ crossed = _cross_provider_maps(
+ self.providers[name], other.providers[name]
+ )
+ if crossed:
+ result[name] = crossed
+
+ return all(c in result for c in common)
+
+ def __eq__(self, other):
+ return self.providers == other.providers
- Optional arguments:
+ def _transform(self, transform_fun, out_mapping_type=dict):
+ """Transform this provider index dictionary and return it.
+
+ Args:
+ transform_fun: transform_fun takes a (vpkg, pset) mapping and runs
+ it on each pair in nested dicts.
+ out_mapping_type: type to be used internally on the
+ transformed (vpkg, pset)
+
+ Returns:
+ Transformed mapping
+ """
+ return _transform(self.providers, transform_fun, out_mapping_type)
+
+ def __str__(self):
+ return str(self.providers)
- specs
- List (or sequence) of specs. If provided, will call
- `update` on this ProviderIndex with each spec in the list.
+ def __repr__(self):
+ return repr(self.providers)
- restrict
- "restricts" values to the verbatim input specs; do not
- pre-apply package's constraints.
- TODO: rename this. It is intended to keep things as broad
- as possible without overly restricting results, so it is
- not the best name.
+class ProviderIndex(_IndexBase):
+ def __init__(self, specs=None, restrict=False):
+ """Provider index based on a single mapping of providers.
+
+ Args:
+ specs (list of specs): if provided, will call update on each
+ single spec to initialize this provider index.
+
+ restrict: "restricts" values to the verbatim input specs; do not
+ pre-apply package's constraints.
+
+ TODO: rename this. It is intended to keep things as broad
+ TODO: as possible without overly restricting results, so it is
+ TODO: not the best name.
"""
if specs is None:
specs = []
@@ -67,6 +160,11 @@ class ProviderIndex(object):
self.update(spec)
def update(self, spec):
+ """Update the provider index with additional virtual specs.
+
+ Args:
+ spec: spec potentially providing additional virtual specs
+ """
if not isinstance(spec, spack.spec.Spec):
spec = spack.spec.Spec(spec)
@@ -74,10 +172,10 @@ class ProviderIndex(object):
# Empty specs do not have a package
return
- assert(not spec.virtual)
+ assert not spec.virtual, "cannot update an index using a virtual spec"
pkg_provided = spec.package_class.provided
- for provided_spec, provider_specs in iteritems(pkg_provided):
+ for provided_spec, provider_specs in six.iteritems(pkg_provided):
for provider_spec in provider_specs:
# TODO: fix this comment.
# We want satisfaction other than flags
@@ -110,94 +208,24 @@ class ProviderIndex(object):
constrained.constrain(provider_spec)
provider_map[provided_spec].add(constrained)
- def providers_for(self, *vpkg_specs):
- """Gives specs of all packages that provide virtual packages
- with the supplied specs."""
- providers = set()
- for vspec in vpkg_specs:
- # Allow string names to be passed as input, as well as specs
- if type(vspec) == str:
- vspec = spack.spec.Spec(vspec)
-
- # Add all the providers that satisfy the vpkg spec.
- if vspec.name in self.providers:
- for p_spec, spec_set in self.providers[vspec.name].items():
- if p_spec.satisfies(vspec, deps=False):
- providers.update(spec_set)
-
- # Return providers in order. Defensively copy.
- return sorted(s.copy() for s in providers)
-
- # TODO: this is pretty darned nasty, and inefficient, but there
- # are not that many vdeps in most specs.
- def _cross_provider_maps(self, lmap, rmap):
- result = {}
- for lspec, rspec in iproduct(lmap, rmap):
- try:
- constrained = lspec.constrained(rspec)
- except spack.error.UnsatisfiableSpecError:
- continue
-
- # lp and rp are left and right provider specs.
- for lp_spec, rp_spec in iproduct(lmap[lspec], rmap[rspec]):
- if lp_spec.name == rp_spec.name:
- try:
- const = lp_spec.constrained(rp_spec, deps=False)
- result.setdefault(constrained, set()).add(const)
- except spack.error.UnsatisfiableSpecError:
- continue
- return result
-
- def __contains__(self, name):
- """Whether a particular vpkg name is in the index."""
- return name in self.providers
-
- def satisfies(self, other):
- """Check that providers of virtual specs are compatible."""
- common = set(self.providers) & set(other.providers)
- if not common:
- return True
-
- # This ensures that some provider in other COULD satisfy the
- # vpkg constraints on self.
- result = {}
- for name in common:
- crossed = self._cross_provider_maps(self.providers[name],
- other.providers[name])
- if crossed:
- result[name] = crossed
-
- return all(c in result for c in common)
-
def to_json(self, stream=None):
+ """Dump a JSON representation of this object.
+
+ Args:
+ stream: stream where to dump
+ """
provider_list = self._transform(
lambda vpkg, pset: [
vpkg.to_node_dict(), [p.to_node_dict() for p in pset]], list)
sjson.dump({'provider_index': {'providers': provider_list}}, stream)
- @staticmethod
- def from_json(stream):
- data = sjson.load(stream)
-
- if not isinstance(data, dict):
- raise ProviderIndexError("JSON ProviderIndex data was not a dict.")
-
- if 'provider_index' not in data:
- raise ProviderIndexError(
- "YAML ProviderIndex does not start with 'provider_index'")
-
- index = ProviderIndex()
- providers = data['provider_index']['providers']
- index.providers = _transform(
- providers,
- lambda vpkg, plist: (
- spack.spec.Spec.from_node_dict(vpkg),
- set(spack.spec.Spec.from_node_dict(p) for p in plist)))
- return index
-
def merge(self, other):
- """Merge `other` ProviderIndex into this one."""
+ """Merge another provider index into this one.
+
+ Args:
+ other (ProviderIndex): provider index to be merged
+ """
other = other.copy() # defensive copy.
for pkg in other.providers:
@@ -236,40 +264,61 @@ class ProviderIndex(object):
del self.providers[pkg]
def copy(self):
- """Deep copy of this ProviderIndex."""
+ """Return a deep copy of this index."""
clone = ProviderIndex()
clone.providers = self._transform(
lambda vpkg, pset: (vpkg, set((p.copy() for p in pset))))
return clone
- def __eq__(self, other):
- return self.providers == other.providers
+ @staticmethod
+ def from_json(stream):
+ """Construct a provider index from its JSON representation.
- def _transform(self, transform_fun, out_mapping_type=dict):
- return _transform(self.providers, transform_fun, out_mapping_type)
+ Args:
+ stream: stream where to read from the JSON data
+ """
+ data = sjson.load(stream)
- def __str__(self):
- return pformat(
- _transform(self.providers,
- lambda k, v: (k, list(v))))
+ if not isinstance(data, dict):
+ raise ProviderIndexError("JSON ProviderIndex data was not a dict.")
+
+ if 'provider_index' not in data:
+ raise ProviderIndexError(
+ "YAML ProviderIndex does not start with 'provider_index'")
+
+ index = ProviderIndex()
+ providers = data['provider_index']['providers']
+ index.providers = _transform(
+ providers,
+ lambda vpkg, plist: (
+ spack.spec.Spec.from_node_dict(vpkg),
+ set(spack.spec.Spec.from_node_dict(p) for p in plist)))
+ return index
def _transform(providers, transform_fun, out_mapping_type=dict):
"""Syntactic sugar for transforming a providers dict.
- transform_fun takes a (vpkg, pset) mapping and runs it on each
- pair in nested dicts.
+ Args:
+ providers: provider dictionary
+ transform_fun: transform_fun takes a (vpkg, pset) mapping and runs
+ it on each pair in nested dicts.
+ out_mapping_type: type to be used internally on the
+ transformed (vpkg, pset)
+ Returns:
+ Transformed mapping
"""
def mapiter(mappings):
if isinstance(mappings, dict):
- return iteritems(mappings)
+ return six.iteritems(mappings)
else:
return iter(mappings)
return dict(
- (name, out_mapping_type([
- transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)]))
+ (name, out_mapping_type(
+ [transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)]
+ ))
for name, mappings in providers.items())
diff --git a/lib/spack/spack/test/spec_dag.py b/lib/spack/spack/test/spec_dag.py
index 419a39968e..25917f9424 100644
--- a/lib/spack/spack/test/spec_dag.py
+++ b/lib/spack/spack/test/spec_dag.py
@@ -189,7 +189,7 @@ def test_conditional_dep_with_user_constraints():
assert ('y@3' in spec)
-@pytest.mark.usefixtures('mutable_mock_repo')
+@pytest.mark.usefixtures('mutable_mock_repo', 'config')
class TestSpecDag(object):
def test_conflicting_package_constraints(self, set_dependency):