diff options
-rw-r--r-- | lib/spack/spack/provider_index.py | 82 | ||||
-rw-r--r-- | lib/spack/spack/test/concretize.py | 3 | ||||
-rw-r--r-- | lib/spack/spack/test/provider_index.py | 59 |
3 files changed, 117 insertions, 27 deletions
diff --git a/lib/spack/spack/provider_index.py b/lib/spack/spack/provider_index.py index 6cd2134e96..ecdc25c4d2 100644 --- a/lib/spack/spack/provider_index.py +++ b/lib/spack/spack/provider_index.py @@ -26,6 +26,8 @@ The ``virtual`` module contains utility classes for virtual dependencies. """ import itertools +from pprint import pformat + import yaml from yaml.error import MarkedYAMLError @@ -48,15 +50,30 @@ class ProviderIndex(object): Calling providers_for(spec) will find specs that provide a matching implementation of MPI. + """ - def __init__(self, specs=None, **kwargs): - # TODO: come up with another name for this. This "restricts" - # values to the verbatim impu specs (i.e., it doesn't - # pre-apply package's constraints, and keeps things as broad - # as possible, so it's really the wrong name) + + + def __init__(self, specs=None, restrict=False): + """Create a new ProviderIndex. + + Optional arguments: + + specs + List (or sequence) of specs. If provided, will call + `update` on this ProviderIndex with each spec in the list. + + restrict + "restricts" values to the verbatim input specs; do not + pre-apply package's constraints. + + TODO: rename this. It is intended to keep things as broad + as possible without overly restricting results, so it is + not the best name. + """ if specs is None: specs = [] - self.restrict = kwargs.setdefault('restrict', False) + self.restrict = restrict self.providers = {} for spec in specs: @@ -174,10 +191,9 @@ class ProviderIndex(object): def to_yaml(self, stream=None): - provider_list = dict( - (name, [[vpkg.to_node_dict(), [p.to_node_dict() for p in pset]] - for vpkg, pset in pdict.items()]) - for name, pdict in self.providers.items()) + provider_list = self._transform( + lambda vpkg, pset: [ + vpkg.to_node_dict(), [p.to_node_dict() for p in pset]], list) yaml.dump({'provider_index': {'providers': provider_list}}, stream=stream) @@ -201,12 +217,11 @@ class ProviderIndex(object): index = ProviderIndex() providers = yfile['provider_index']['providers'] - index.providers = dict( - (name, dict((spack.spec.Spec.from_node_dict(vpkg), - set(spack.spec.Spec.from_node_dict(p) for p in plist)) - for vpkg, plist in pdict_list)) - for name, pdict_list in providers.items()) - + index.providers = _transform( + providers, + lambda vpkg, plist: ( + spack.spec.Spec.from_node_dict(vpkg), + set(spack.spec.Spec.from_node_dict(p) for p in plist))) return index @@ -253,12 +268,39 @@ class ProviderIndex(object): def copy(self): """Deep copy of this ProviderIndex.""" clone = ProviderIndex() - clone.providers = dict( - (name, dict((vpkg, set((p.copy() for p in pset))) - for vpkg, pset in pdict.items())) - for name, pdict in self.providers.items()) + clone.providers = self._transform( + lambda vpkg, pset: (vpkg, set((p.copy() for p in pset)))) return clone def __eq__(self, other): return self.providers == other.providers + + + def _transform(self, transform_fun, out_mapping_type=dict): + return _transform(self.providers, transform_fun, out_mapping_type) + + + def __str__(self): + return pformat( + _transform(self.providers, + lambda k, v: (k, list(v)))) + + +def _transform(providers, transform_fun, out_mapping_type=dict): + """Syntactic sugar for transforming a providers dict. + + transform_fun takes a (vpkg, pset) mapping and runs it on each + pair in nested dicts. + + """ + def mapiter(mappings): + if isinstance(mappings, dict): + return mappings.iteritems() + else: + return iter(mappings) + + return dict( + (name, out_mapping_type([ + transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)])) + for name, mappings in providers.items()) diff --git a/lib/spack/spack/test/concretize.py b/lib/spack/spack/test/concretize.py index ae3ceecfc8..ec0a2ec244 100644 --- a/lib/spack/spack/test/concretize.py +++ b/lib/spack/spack/test/concretize.py @@ -153,9 +153,6 @@ class ConcretizeTest(MockPackagesTest): self.assertTrue(not any(spec.satisfies('mpich2@:1.1') for spec in spack.repo.providers_for('mpi@2.2'))) - self.assertTrue(not any(spec.satisfies('mpich2@:1.1') - for spec in spack.repo.providers_for('mpi@2.2'))) - self.assertTrue(not any(spec.satisfies('mpich@:1') for spec in spack.repo.providers_for('mpi@2'))) diff --git a/lib/spack/spack/test/provider_index.py b/lib/spack/spack/test/provider_index.py index 15fb9acff2..7d5f997b0a 100644 --- a/lib/spack/spack/test/provider_index.py +++ b/lib/spack/spack/test/provider_index.py @@ -26,12 +26,27 @@ from StringIO import StringIO import unittest import spack +from spack.spec import Spec from spack.provider_index import ProviderIndex +from spack.test.mock_packages_test import * +# Test assume that mock packages provide this: +# +# {'blas': { +# blas: set([netlib-blas, openblas, openblas-with-lapack])}, +# 'lapack': {lapack: set([netlib-lapack, openblas-with-lapack])}, +# 'mpi': {mpi@:1: set([mpich@:1]), +# mpi@:2.0: set([mpich2]), +# mpi@:2.1: set([mpich2@1.1:]), +# mpi@:2.2: set([mpich2@1.2:]), +# mpi@:3: set([mpich@3:]), +# mpi@:10.0: set([zmpi])}, +# 'stuff': {stuff: set([externalvirtual])}} +# -class ProviderIndexTest(unittest.TestCase): +class ProviderIndexTest(MockPackagesTest): - def test_write_and_read(self): + def test_yaml_round_trip(self): p = ProviderIndex(spack.repo.all_package_names()) ostream = StringIO() @@ -40,10 +55,46 @@ class ProviderIndexTest(unittest.TestCase): istream = StringIO(ostream.getvalue()) q = ProviderIndex.from_yaml(istream) - self.assertTrue(p == q) + self.assertEqual(p, q) + + + def test_providers_for_simple(self): + p = ProviderIndex(spack.repo.all_package_names()) + + blas_providers = p.providers_for('blas') + self.assertTrue(Spec('netlib-blas') in blas_providers) + self.assertTrue(Spec('openblas') in blas_providers) + self.assertTrue(Spec('openblas-with-lapack') in blas_providers) + + lapack_providers = p.providers_for('lapack') + self.assertTrue(Spec('netlib-lapack') in lapack_providers) + self.assertTrue(Spec('openblas-with-lapack') in lapack_providers) + + + def test_mpi_providers(self): + p = ProviderIndex(spack.repo.all_package_names()) + + mpi_2_providers = p.providers_for('mpi@2') + self.assertTrue(Spec('mpich2') in mpi_2_providers) + self.assertTrue(Spec('mpich@3:') in mpi_2_providers) + + mpi_3_providers = p.providers_for('mpi@3') + self.assertTrue(Spec('mpich2') not in mpi_3_providers) + self.assertTrue(Spec('mpich@3:') in mpi_3_providers) + self.assertTrue(Spec('zmpi') in mpi_3_providers) + + + def test_equal(self): + p = ProviderIndex(spack.repo.all_package_names()) + q = ProviderIndex(spack.repo.all_package_names()) + self.assertEqual(p, q) def test_copy(self): p = ProviderIndex(spack.repo.all_package_names()) q = p.copy() - self.assertTrue(p == q) + self.assertEqual(p, q) + + + def test_copy(self): + pass |