From db07c7f611da4ddc6127dde28d48e5624e4f1172 Mon Sep 17 00:00:00 2001 From: Todd Gamblin Date: Tue, 15 Oct 2013 03:04:25 -0700 Subject: Spec constraints and normalization now work. - Specs can be "constrained" by other specs, throw exceptions when constraint can't be satisfied. - Normalize will put a spec in DAG form and merge all package constraints with the spec. - Ready to add concretization policies for abstract specs now. --- lib/spack/spack/cmd/spec.py | 5 +- lib/spack/spack/globals.py | 2 +- lib/spack/spack/package.py | 2 +- lib/spack/spack/packages/__init__.py | 12 +- lib/spack/spack/packages/callpath.py | 14 ++ lib/spack/spack/packages/dyninst.py | 14 ++ lib/spack/spack/packages/libdwarf.py | 2 +- lib/spack/spack/packages/mpich.py | 11 ++ lib/spack/spack/packages/mpileaks.py | 14 ++ lib/spack/spack/parse.py | 30 ++-- lib/spack/spack/spec.py | 336 +++++++++++++++++++---------------- lib/spack/spack/test/concretize.py | 6 +- lib/spack/spack/test/specs.py | 31 ++++ lib/spack/spack/test/versions.py | 19 ++ lib/spack/spack/util/lang.py | 63 +++++++ lib/spack/spack/util/string.py | 23 +++ lib/spack/spack/version.py | 80 +++++++-- 17 files changed, 473 insertions(+), 191 deletions(-) create mode 100644 lib/spack/spack/packages/callpath.py create mode 100644 lib/spack/spack/packages/dyninst.py create mode 100644 lib/spack/spack/packages/mpich.py create mode 100644 lib/spack/spack/packages/mpileaks.py create mode 100644 lib/spack/spack/util/string.py (limited to 'lib') diff --git a/lib/spack/spack/cmd/spec.py b/lib/spack/spack/cmd/spec.py index 5c389bd04a..f87e94dae6 100644 --- a/lib/spack/spack/cmd/spec.py +++ b/lib/spack/spack/cmd/spec.py @@ -12,6 +12,5 @@ def setup_parser(subparser): def spec(parser, args): specs = spack.cmd.parse_specs(args.specs) for spec in specs: - print spec.colorized() - print " --> ", spec.concretized().colorized() - print spec.concretized().concrete() + spec.normalize() + print spec.tree() diff --git a/lib/spack/spack/globals.py b/lib/spack/spack/globals.py index a9fa25c784..3b2359e21d 100644 --- a/lib/spack/spack/globals.py +++ b/lib/spack/spack/globals.py @@ -32,7 +32,7 @@ install_path = new_path(prefix, "opt") install_layout = DefaultDirectoryLayout(install_path) # Version information -spack_version = Version("0.2") +spack_version = Version("0.5") # User's editor from the environment editor = Executable(os.environ.get("EDITOR", "")) diff --git a/lib/spack/spack/package.py b/lib/spack/spack/package.py index 7e11c92d2d..9eece8afcf 100644 --- a/lib/spack/spack/package.py +++ b/lib/spack/spack/package.py @@ -381,7 +381,7 @@ class Package(object): @property @memoized def all_dependencies(self): - """Set of all transitive dependencies of this package.""" + """Dict(str -> Package) of all transitive dependencies of this package.""" all_deps = set(self.dependencies) for dep in self.dependencies: dep_pkg = packages.get(dep.name) diff --git a/lib/spack/spack/packages/__init__.py b/lib/spack/spack/packages/__init__.py index b9987aa040..6d515274b6 100644 --- a/lib/spack/spack/packages/__init__.py +++ b/lib/spack/spack/packages/__init__.py @@ -9,6 +9,7 @@ import spack import spack.error import spack.spec from spack.util.filesystem import new_path +from spack.util.lang import list_modules import spack.arch as arch # Valid package names can contain '-' but can't start with it. @@ -19,13 +20,12 @@ invalid_package_re = r'[_-][_-]+' instances = {} -def get(spec): - spec = spack.spec.make_spec(spec) - if not spec in instances: - package_class = get_class_for_package_name(spec.name) - instances[spec] = package_class(spec) +def get(pkg_name): + if not pkg_name in instances: + package_class = get_class_for_package_name(pkg_name) + instances[pkg_name] = package_class(pkg_name) - return instances[spec] + return instances[pkg_name] def valid_package_name(pkg_name): diff --git a/lib/spack/spack/packages/callpath.py b/lib/spack/spack/packages/callpath.py new file mode 100644 index 0000000000..958960e0ab --- /dev/null +++ b/lib/spack/spack/packages/callpath.py @@ -0,0 +1,14 @@ +from spack import * + +class Callpath(Package): + homepage = "https://github.com/tgamblin/callpath" + url = "http://github.com/tgamblin/callpath-0.2.tar.gz" + md5 = "foobarbaz" + + depends_on("dyninst") + depends_on("mpich") + + def install(self, prefix): + configure("--prefix=%s" % prefix) + make() + make("install") diff --git a/lib/spack/spack/packages/dyninst.py b/lib/spack/spack/packages/dyninst.py new file mode 100644 index 0000000000..f550cde54f --- /dev/null +++ b/lib/spack/spack/packages/dyninst.py @@ -0,0 +1,14 @@ +from spack import * + +class Dyninst(Package): + homepage = "https://paradyn.org" + url = "http://www.dyninst.org/sites/default/files/downloads/dyninst/8.1.2/DyninstAPI-8.1.2.tgz" + md5 = "bf03b33375afa66fe0efa46ce3f4b17a" + + depends_on("libelf") + depends_on("libdwarf") + + def install(self, prefix): + configure("--prefix=%s" % prefix) + make() + make("install") diff --git a/lib/spack/spack/packages/libdwarf.py b/lib/spack/spack/packages/libdwarf.py index bae701b38b..edaba6a216 100644 --- a/lib/spack/spack/packages/libdwarf.py +++ b/lib/spack/spack/packages/libdwarf.py @@ -11,7 +11,7 @@ class Libdwarf(Package): list_url = "http://reality.sgiweb.org/davea/dwarf.html" - depends_on("libelf") + depends_on("libelf@0:1") def clean(self): diff --git a/lib/spack/spack/packages/mpich.py b/lib/spack/spack/packages/mpich.py new file mode 100644 index 0000000000..d8cd67d528 --- /dev/null +++ b/lib/spack/spack/packages/mpich.py @@ -0,0 +1,11 @@ +from spack import * + +class Mpich(Package): + homepage = "http://www.mpich.org" + url = "http://www.mpich.org/static/downloads/3.0.4/mpich-3.0.4.tar.gz" + md5 = "9c5d5d4fe1e17dd12153f40bc5b6dbc0" + + def install(self, prefix): + configure("--prefix=%s" % prefix) + make() + make("install") diff --git a/lib/spack/spack/packages/mpileaks.py b/lib/spack/spack/packages/mpileaks.py new file mode 100644 index 0000000000..224557cc52 --- /dev/null +++ b/lib/spack/spack/packages/mpileaks.py @@ -0,0 +1,14 @@ +from spack import * + +class Mpileaks(Package): + homepage = "http://www.llnl.gov" + url = "http://www.llnl.gov/mpileaks-1.0.tar.gz" + md5 = "foobarbaz" + + depends_on("mpich") + depends_on("callpath") + + def install(self, prefix): + configure("--prefix=%s" % prefix) + make() + make("install") diff --git a/lib/spack/spack/parse.py b/lib/spack/spack/parse.py index 5bcdced6ca..5431aa454d 100644 --- a/lib/spack/spack/parse.py +++ b/lib/spack/spack/parse.py @@ -1,20 +1,6 @@ import re -import spack.error as err import itertools - - -class ParseError(err.SpackError): - """Raised when we don't hit an error while parsing.""" - def __init__(self, message, string, pos): - super(ParseError, self).__init__(message) - self.string = string - self.pos = pos - - -class LexError(ParseError): - """Raised when we don't know how to lex something.""" - def __init__(self, message, string, pos): - super(LexError, self).__init__(message, string, pos) +import spack.error class Token: @@ -109,3 +95,17 @@ class Parser(object): self.text = text self.push_tokens(self.lexer.lex(text)) return self.do_parse() + + +class ParseError(spack.error.SpackError): + """Raised when we don't hit an error while parsing.""" + def __init__(self, message, string, pos): + super(ParseError, self).__init__(message) + self.string = string + self.pos = pos + + +class LexError(ParseError): + """Raised when we don't know how to lex something.""" + def __init__(self, message, string, pos): + super(LexError, self).__init__(message, string, pos) diff --git a/lib/spack/spack/spec.py b/lib/spack/spack/spec.py index d7998a8fb1..edc97c7c3b 100644 --- a/lib/spack/spack/spec.py +++ b/lib/spack/spack/spec.py @@ -62,7 +62,6 @@ specs to avoid ambiguity. Both are provided because ~ can cause shell expansion when it is the first character in an id typed on the command line. """ import sys -from functools import total_ordering from StringIO import StringIO import tty @@ -72,8 +71,11 @@ import spack.compilers import spack.compilers.gcc import spack.packages as packages import spack.arch as arch + from spack.version import * from spack.color import * +from spack.util.lang import * +from spack.util.string import * """This map determines the coloring of specs when using color output. We make the fields different colors to enhance readability. @@ -109,6 +111,7 @@ def colorize_spec(spec): return colorize(re.sub(separators, insert_color(), str(spec)) + '@.') +@key_ordering class Compiler(object): """The Compiler field represents the compiler or range of compiler versions that a package should be built with. Compilers have a @@ -128,6 +131,19 @@ class Compiler(object): self.versions.add(version) + def satisfies(self, other): + return (self.name == other.name and + self.versions.overlaps(other.versions)) + + + def constrain(self, other): + if not self.satisfies(other.compiler): + raise UnsatisfiableCompilerSpecError( + "%s does not satisfy %s" % (self.compiler, other.compiler)) + + self.versions.intersect(other.versions) + + @property def concrete(self): return self.versions.concrete @@ -163,16 +179,8 @@ class Compiler(object): return clone - def __eq__(self, other): - return (self.name, self.versions) == (other.name, other.versions) - - - def __ne__(self, other): - return not (self == other) - - - def __hash__(self): - return hash((self.name, self.versions)) + def _cmp_key(self): + return (self.name, self.versions) def __str__(self): @@ -183,7 +191,7 @@ class Compiler(object): return out -@total_ordering +@key_ordering class Variant(object): """Variants are named, build-time options for a package. Names depend on the particular package being built, and each named variant can @@ -194,67 +202,21 @@ class Variant(object): self.enabled = enabled - def __eq__(self, other): - return self.name == other.name and self.enabled == other.enabled - - - def __ne__(self, other): - return not (self == other) - - - @property - def tuple(self): + def _cmp_key(self): return (self.name, self.enabled) - def __hash__(self): - return hash(self.tuple) - - - def __lt__(self, other): - return self.tuple < other.tuple - - def __str__(self): out = '+' if self.enabled else '~' return out + self.name - -@total_ordering -class HashableMap(dict): - """This is a hashable, comparable dictionary. Hash is performed on - a tuple of the values in the dictionary.""" - def __eq__(self, other): - return (len(self) == len(other) and - sorted(self.values()) == sorted(other.values())) - - - def __ne__(self, other): - return not (self == other) - - - def __lt__(self, other): - return tuple(sorted(self.values())) < tuple(sorted(other.values())) - - - def __hash__(self): - return hash(tuple(sorted(self.values()))) - - - def copy(self): - """Type-agnostic clone method. Preserves subclass type.""" - # Construct a new dict of my type - T = type(self) - clone = T() - - # Copy everything from this dict into it. - for key in self: - clone[key] = self[key] - return clone +class VariantMap(HashableMap): + def satisfies(self, other): + return all(self[key].enabled == other[key].enabled + for key in other if key in self) -class VariantMap(HashableMap): def __str__(self): sorted_keys = sorted(self.keys()) return ''.join(str(self[key]) for key in sorted_keys) @@ -268,13 +230,18 @@ class DependencyMap(HashableMap): return all(d.concrete for d in self.values()) + def satisfies(self, other): + return all(self[name].satisfies(other[name]) for name in self + if name in other) + + def __str__(self): - sorted_keys = sorted(self.keys()) + sorted_dep_names = sorted(self.keys()) return ''.join( - ["^" + str(self[name]) for name in sorted_keys]) + ["^" + str(self[name]) for name in sorted_dep_names]) -@total_ordering +@key_ordering class Spec(object): def __init__(self, name): self.name = name @@ -322,11 +289,11 @@ class Spec(object): @property def concrete(self): - return (self.versions.concrete - # TODO: support variants - and self.architecture - and self.compiler and self.compiler.concrete - and self.dependencies.concrete) + return bool(self.versions.concrete + # TODO: support variants + and self.architecture + and self.compiler and self.compiler.concrete + and self.dependencies.concrete) def _concretize(self): @@ -349,7 +316,7 @@ class Spec(object): """ # TODO: modularize the process of selecting concrete versions. # There should be a set of user-configurable policies for these decisions. - self.check_sanity() + self.validate() # take the system's architecture for starters if not self.architecture: @@ -370,60 +337,118 @@ class Spec(object): # Ensure dependencies have right versions + @property + def traverse_deps(self, visited=None): + """Yields dependencies in depth-first order""" + if not visited: + visited = set() + + for name in sorted(self.dependencies.keys()): + dep = dependencies[name] + if dep in visited: + continue - def check_sanity(self): - """Check names of packages and dependency validity.""" - self.check_package_name_sanity() - self.check_dependency_sanity() - self.check_dependence_constraint_sanity() - + for d in dep.traverse_deps(seen): + yield d + yield dep - def check_package_name_sanity(self): - """Ensure that all packages mentioned in the spec exist.""" - packages.get(self.name) - for dep in self.dependencies.values(): - packages.get(dep.name) + def _normalize_helper(self, visited, spec_deps): + """Recursive helper function for _normalize.""" + if self.name in visited: + return + visited.add(self.name) - def check_dependency_sanity(self): - """Ensure that dependencies specified on the spec are actual - dependencies of the package it represents. - """ + # Combine constraints from package dependencies with + # information in this spec's dependencies. pkg = packages.get(self.name) - dep_names = set(dep.name for dep in pkg.all_dependencies) - invalid_dependencies = [d.name for d in self.dependencies.values() - if d.name not in dep_names] - if invalid_dependencies: + for pkg_dep in pkg.dependencies: + name = pkg_dep.name + + if name not in spec_deps: + # Clone the spec from the package + spec_deps[name] = pkg_dep.copy() + + try: + # intersect package information with spec info + spec_deps[name].constrain(pkg_dep) + except UnsatisfiableSpecError, e: + error_type = type(e) + raise error_type( + "Violated depends_on constraint from package %s: %s" + % (self.name, e.message)) + + # Add merged spec to my deps and recurse + self.dependencies[name] = spec_deps[name] + self.dependencies[name]._normalize_helper(visited, spec_deps) + + + def normalize(self): + if any(dep.dependencies for dep in self.dependencies.values()): + raise SpecError("Spec has already been normalized.") + + self.validate_package_names() + + spec_deps = self.dependencies + self.dependencies = DependencyMap() + + visited = set() + self._normalize_helper(visited, spec_deps) + + # If there are deps specified but not visited, they're not + # actually deps of this package. Raise an error. + extra = set(spec_deps.viewkeys()).difference(visited) + if extra: raise InvalidDependencyException( - "The packages (%s) are not dependencies of %s" % - (','.join(invalid_dependencies), self.name)) + self.name + " does not depend on " + comma_or(extra)) - def check_dependence_constraint_sanity(self): - """Ensure that package's dependencies have consistent constraints on - their dependencies. - """ - pkg = packages.get(self.name) - specs = {} - for spec in pkg.all_dependencies: - if not spec.name in specs: - specs[spec.name] = spec - continue + def validate_package_names(self): + for name in self.dependencies: + packages.get(name) - merged = specs[spec.name] - # Specs in deps can't be disjoint. - if not spec.versions.overlaps(merged.versions): - raise InvalidConstraintException( - "One package %s, version constraint %s conflicts with %s" - % (pkg.name, spec.versions, merged.versions)) + def constrain(self, other): + if not self.versions.overlaps(other.versions): + raise UnsatisfiableVersionSpecError( + "%s does not satisfy %s" % (self.versions, other.versions)) + conflicting_variants = [ + v for v in other.variants if v in self.variants and + self.variants[v].enabled != other.variants[v].enabled] - def merge(self, other): - """Considering these specs as constraints, attempt to merge. - Raise an exception if specs are disjoint. - """ - pass + if conflicting_variants: + raise UnsatisfiableVariantSpecError(comma_and( + "%s does not satisfy %s" % (self.variants[v], other.variants[v]) + for v in conflicting_variants)) + + if self.architecture is not None and other.architecture is not None: + if self.architecture != other.architecture: + raise UnsatisfiableArchitectureSpecError( + "Asked for architecture %s, but required %s" + % (self.architecture, other.architecture)) + + if self.compiler is not None and other.compiler is not None: + self.compiler.constrain(other.compiler) + elif self.compiler is None: + self.compiler = other.compiler + + self.versions.intersect(other.versions) + self.variants.update(other.variants) + self.architecture = self.architecture or other.architecture + + + def satisfies(self, other): + def sat(attribute): + s = getattr(self, attribute) + o = getattr(other, attribute) + return not s or not o or s.satisfies(o) + + return (self.name == other.name and + all(sat(attr) for attr in + ('versions', 'variants', 'compiler', 'architecture')) and + # TODO: what does it mean to satisfy deps? + self.dependencies.satisfies(other.dependencies)) def concretized(self): @@ -451,43 +476,16 @@ class Spec(object): return self.versions[0] - @property - def tuple(self): + def _cmp_key(self): return (self.name, self.versions, self.variants, - self.architecture, self.compiler, self.dependencies) - - - @property - def tuple(self): - return (self.name, self.versions, self.variants, self.architecture, - self.compiler, self.dependencies) - - - def __eq__(self, other): - return self.tuple == other.tuple - - - def __ne__(self, other): - return not (self == other) - - - def __lt__(self, other): - return self.tuple < other.tuple - - - def __hash__(self): - return hash(self.tuple) + self.architecture, self.compiler) def colorized(self): return colorize_spec(self) - def __repr__(self): - return str(self) - - - def __str__(self): + def str_without_deps(self): out = self.name # If the version range is entirely open, omit it @@ -502,10 +500,26 @@ class Spec(object): if self.architecture: out += "=%s" % self.architecture - out += str(self.dependencies) return out + def tree(self, indent=""): + """Prints out this spec and its dependencies, tree-formatted + with indentation.""" + out = indent + self.str_without_deps() + for dep in sorted(self.dependencies.keys()): + out += "\n" + self.dependencies[dep].tree(indent + " ") + return out + + + def __repr__(self): + return str(self) + + + def __str__(self): + return self.str_without_deps() + str(self.dependencies) + + # # These are possible token types in the spec grammar. # @@ -580,7 +594,7 @@ class SpecParser(spack.parse.Parser): # If there was no version in the spec, consier it an open range if not added_version: - spec.versions = VersionList([':']) + spec.versions = VersionList(':') return spec @@ -721,7 +735,31 @@ class InvalidDependencyException(SpecError): super(InvalidDependencyException, self).__init__(message) -class InvalidConstraintException(SpecError): - """Raised when a package dependencies conflict.""" +class UnsatisfiableSpecError(SpecError): + """Raised when a spec conflicts with package constraints.""" + def __init__(self, message): + super(UnsatisfiableSpecError, self).__init__(message) + + +class UnsatisfiableVersionSpecError(UnsatisfiableSpecError): + """Raised when a spec version conflicts with package constraints.""" + def __init__(self, message): + super(UnsatisfiableVersionSpecError, self).__init__(message) + + +class UnsatisfiableCompilerSpecError(UnsatisfiableSpecError): + """Raised when a spec comiler conflicts with package constraints.""" + def __init__(self, message): + super(UnsatisfiableCompilerSpecError, self).__init__(message) + + +class UnsatisfiableVariantSpecError(UnsatisfiableSpecError): + """Raised when a spec variant conflicts with package constraints.""" + def __init__(self, message): + super(UnsatisfiableVariantSpecError, self).__init__(message) + + +class UnsatisfiableArchitectureSpecError(UnsatisfiableSpecError): + """Raised when a spec architecture conflicts with package constraints.""" def __init__(self, message): - super(InvalidConstraintException, self).__init__(message) + super(UnsatisfiableArchitectureSpecError, self).__init__(message) diff --git a/lib/spack/spack/test/concretize.py b/lib/spack/spack/test/concretize.py index 05a2f4811c..3a528f1b16 100644 --- a/lib/spack/spack/test/concretize.py +++ b/lib/spack/spack/test/concretize.py @@ -6,8 +6,12 @@ class ConcretizeTest(unittest.TestCase): def check_concretize(self, abstract_spec): abstract = spack.spec.parse_one(abstract_spec) + print abstract + print abstract.concretized() + print abstract.concretized().concrete self.assertTrue(abstract.concretized().concrete) def test_packages(self): - self.check_concretize("libelf") + pass + #self.check_concretize("libelf") diff --git a/lib/spack/spack/test/specs.py b/lib/spack/spack/test/specs.py index f495738a72..cb8bf79ff8 100644 --- a/lib/spack/spack/test/specs.py +++ b/lib/spack/spack/test/specs.py @@ -59,6 +59,25 @@ class SpecTest(unittest.TestCase): # Only check the type for non-identifiers. self.assertEqual(tok.type, spec_tok.type) + + def check_satisfies(self, lspec, rspec): + l = spack.spec.parse_one(lspec) + r = spack.spec.parse_one(rspec) + self.assertTrue(l.satisfies(r) and r.satisfies(l)) + + # These should not raise + l.constrain(r) + r.constrain(l) + + + def check_constrain(self, expected, constrained, constraint): + exp = spack.spec.parse_one(expected) + constrained = spack.spec.parse_one(constrained) + constraint = spack.spec.parse_one(constraint) + constrained.constrain(constraint) + self.assertEqual(exp, constrained) + + # ================================================================================ # Parse checks # =============================================================================== @@ -117,6 +136,18 @@ class SpecTest(unittest.TestCase): self.assertRaises(DuplicateCompilerError, self.check_parse, "x ^y%gcc%intel") + # ================================================================================ + # Satisfiability and constraints + # ================================================================================ + def test_satisfies(self): + self.check_satisfies('libelf@0.8.13', 'libelf@0:1') + self.check_satisfies('libdwarf^libelf@0.8.13', 'libdwarf^libelf@0:1') + + + def test_constrain(self): + self.check_constrain('libelf@0:1', 'libelf', 'libelf@0:1') + + # ================================================================================ # Lex checks # ================================================================================ diff --git a/lib/spack/spack/test/versions.py b/lib/spack/spack/test/versions.py index 09d74549ef..135e3a031f 100644 --- a/lib/spack/spack/test/versions.py +++ b/lib/spack/spack/test/versions.py @@ -59,6 +59,10 @@ class VersionsTest(unittest.TestCase): self.assertFalse(ver(v1).overlaps(ver(v2))) + def check_intersection(self, expected, a, b): + self.assertEqual(ver(expected), ver(a).intersection(ver(b))) + + def test_two_segments(self): self.assert_ver_eq('1.0', '1.0') self.assert_ver_lt('1.0', '2.0') @@ -215,6 +219,7 @@ class VersionsTest(unittest.TestCase): self.assert_overlaps('1.2:', '1.6:') self.assert_overlaps(':', ':') self.assert_overlaps(':', '1.6:1.9') + self.assert_overlaps('1.6:1.9', ':') def test_lists_overlap(self): @@ -258,3 +263,17 @@ class VersionsTest(unittest.TestCase): self.assert_canonical([':'], [':,1.3, 1.3.1,1.3.9,1.4 : 1.5 , 1.3 : 1.4']) + + + def test_intersection(self): + self.check_intersection('2.5', + '1.0:2.5', '2.5:3.0') + self.check_intersection('2.5:2.7', + '1.0:2.7', '2.5:3.0') + self.check_intersection('0:1', ':', '0:1') + + self.check_intersection(['1.0', '2.5:2.7'], + ['1.0:2.7'], ['2.5:3.0','1.0']) + self.check_intersection(['2.5:2.7'], + ['1.1:2.7'], ['2.5:3.0','1.0']) + self.check_intersection(['0:1'], [':'], ['0:1']) diff --git a/lib/spack/spack/util/lang.py b/lib/spack/spack/util/lang.py index 0d9b7e32bb..92532c109f 100644 --- a/lib/spack/spack/util/lang.py +++ b/lib/spack/spack/util/lang.py @@ -1,8 +1,20 @@ import os import re +import sys import functools +import inspect from spack.util.filesystem import new_path + +def has_method(cls, name): + for base in inspect.getmro(cls): + if base is object: + continue + if name in base.__dict__: + return True + return False + + def memoized(obj): """Decorator that caches the results of a function, storing them in an attribute of that function.""" @@ -30,3 +42,54 @@ def list_modules(directory): elif name.endswith('.py'): yield re.sub('.py$', '', name) + + +def key_ordering(cls): + """Decorates a class with extra methods that implement rich comparison + operations and __hash__. The decorator assumes that the class + implements a function called _cmp_key(). The rich comparison operations + will compare objects using this key, and the __hash__ function will + return the hash of this key. + + If a class already has __eq__, __ne__, __lt__, __le__, __gt__, or __ge__ + defined, this decorator will overwrite them. If the class does not + have a _cmp_key method, then this will raise a TypeError. + """ + def setter(name, value): + value.__name__ = name + setattr(cls, name, value) + + if not has_method(cls, '_cmp_key'): + raise TypeError("'%s' doesn't define _cmp_key()." % cls.__name__) + + setter('__eq__', lambda s,o: o is not None and s._cmp_key() == o._cmp_key()) + setter('__lt__', lambda s,o: o is not None and s._cmp_key() < o._cmp_key()) + setter('__le__', lambda s,o: o is not None and s._cmp_key() <= o._cmp_key()) + + setter('__ne__', lambda s,o: o is None or s._cmp_key() != o._cmp_key()) + setter('__gt__', lambda s,o: o is None or s._cmp_key() > o._cmp_key()) + setter('__ge__', lambda s,o: o is None or s._cmp_key() >= o._cmp_key()) + + setter('__hash__', lambda self: hash(self._cmp_key())) + + return cls + + +@key_ordering +class HashableMap(dict): + """This is a hashable, comparable dictionary. Hash is performed on + a tuple of the values in the dictionary.""" + def _cmp_key(self): + return tuple(sorted(self.values())) + + + def copy(self): + """Type-agnostic clone method. Preserves subclass type.""" + # Construct a new dict of my type + T = type(self) + clone = T() + + # Copy everything from this dict into it. + for key in self: + clone[key] = self[key] + return clone diff --git a/lib/spack/spack/util/string.py b/lib/spack/spack/util/string.py new file mode 100644 index 0000000000..466ea91148 --- /dev/null +++ b/lib/spack/spack/util/string.py @@ -0,0 +1,23 @@ + +def comma_list(sequence, article=''): + if type(sequence) != list: + sequence = list(sequence) + + if not sequence: + return + elif len(sequence) == 1: + return sequence[0] + else: + out = ', '.join(str(s) for s in sequence[:-1]) + out += ', ' + if article: + out += article + ' ' + out += str(sequence[-1]) + return out + +def comma_or(sequence): + return comma_list(sequence, 'or') + + +def comma_and(sequence): + return comma_list(sequence, 'and') diff --git a/lib/spack/spack/version.py b/lib/spack/spack/version.py index 9ed7b465bb..15b4027606 100644 --- a/lib/spack/spack/version.py +++ b/lib/spack/spack/version.py @@ -13,8 +13,10 @@ be called on any of the types: __eq__, __ne__, __lt__, __gt__, __ge__, __le__, __hash__ __contains__ + satisfies overlaps - merge + union + intersection concrete True if the Version, VersionRange or VersionList represents a single version. @@ -161,6 +163,7 @@ class Version(object): def concrete(self): return self + @coerced def __lt__(self, other): """Version comparison is designed for consistency with the way RPM @@ -219,13 +222,21 @@ class Version(object): @coerced - def merge(self, other): + def union(self, other): if self == other: return self else: return VersionList([self, other]) + @coerced + def intersection(self, other): + if self == other: + return self + else: + return VersionList() + + @total_ordering class VersionRange(object): def __init__(self, start, end): @@ -295,9 +306,21 @@ class VersionRange(object): @coerced - def merge(self, other): - return VersionRange(none_low.min(self.start, other.start), - none_high.max(self.end, other.end)) + def union(self, other): + if self.overlaps(other): + return VersionRange(none_low.min(self.start, other.start), + none_high.max(self.end, other.end)) + else: + return VersionList([self, other]) + + + @coerced + def intersection(self, other): + if self.overlaps(other): + return VersionRange(none_low.max(self.start, other.start), + none_high.min(self.end, other.end)) + else: + return VersionList() def __hash__(self): @@ -338,12 +361,12 @@ class VersionList(object): i = bisect_left(self, version) while i-1 >= 0 and version.overlaps(self[i-1]): - version = version.merge(self[i-1]) + version = version.union(self[i-1]) del self.versions[i-1] i -= 1 while i < len(self) and version.overlaps(self[i]): - version = version.merge(self[i]) + version = version.union(self[i]) del self.versions[i] self.versions.insert(i, version) @@ -384,25 +407,54 @@ class VersionList(object): return self[-1].highest() + def satisfies(self, other): + """Synonym for overlaps.""" + return self.overlaps(other) + + @coerced def overlaps(self, other): if not other or not self: return False - i = o = 0 - while i < len(self) and o < len(other): - if self[i].overlaps(other[o]): + s = o = 0 + while s < len(self) and o < len(other): + if self[s].overlaps(other[o]): return True - elif self[i] < other[o]: - i += 1 + elif self[s] < other[o]: + s += 1 else: o += 1 return False @coerced - def merge(self, other): - return VersionList(self.versions + other.versions) + def update(self, other): + for v in other.versions: + self.add(v) + + + @coerced + def union(self, other): + result = self.copy() + result.update(other) + return result + + + @coerced + def intersection(self, other): + # TODO: make this faster. This is O(n^2). + result = VersionList() + for s in self: + for o in other: + result.add(s.intersection(o)) + return result + + + @coerced + def intersect(self, other): + isection = self.intersection(other) + self.versions = isection.versions @coerced -- cgit v1.2.3-70-g09d2