diff options
-rw-r--r-- | lib/spack/spack/test/util/package_hash.py | 130 | ||||
-rw-r--r-- | lib/spack/spack/util/package_hash.py | 166 |
2 files changed, 261 insertions, 35 deletions
diff --git a/lib/spack/spack/test/util/package_hash.py b/lib/spack/spack/test/util/package_hash.py index aedb42e20d..395007f41e 100644 --- a/lib/spack/spack/test/util/package_hash.py +++ b/lib/spack/spack/test/util/package_hash.py @@ -150,3 +150,133 @@ def test_remove_directives(): for name in spack.directives.directive_names: assert name not in unparsed + + +many_multimethods = """\ +class Pkg: + def foo(self): + print("ONE") + + @when("@1.0") + def foo(self): + print("TWO") + + @when("@2.0") + @when(sys.platform == "darwin") + def foo(self): + print("THREE") + + @when("@3.0") + def foo(self): + print("FOUR") + + # this one should always stay + @run_after("install") + def some_function(self): + print("FIVE") +""" + + +def test_multimethod_resolution(tmpdir): + when_pkg = tmpdir.join("pkg.py") + with when_pkg.open("w") as f: + f.write(many_multimethods) + + # all are false but the default + filtered = ph.canonical_source("pkg@4.0", str(when_pkg)) + assert "ONE" in filtered + assert "TWO" not in filtered + assert "THREE" not in filtered + assert "FOUR" not in filtered + assert "FIVE" in filtered + + # we know first @when overrides default and others are false + filtered = ph.canonical_source("pkg@1.0", str(when_pkg)) + assert "ONE" not in filtered + assert "TWO" in filtered + assert "THREE" not in filtered + assert "FOUR" not in filtered + assert "FIVE" in filtered + + # we know last @when overrides default and others are false + filtered = ph.canonical_source("pkg@3.0", str(when_pkg)) + assert "ONE" not in filtered + assert "TWO" not in filtered + assert "THREE" not in filtered + assert "FOUR" in filtered + assert "FIVE" in filtered + + # we don't know if default or THREE will win, include both + filtered = ph.canonical_source("pkg@2.0", str(when_pkg)) + assert "ONE" in filtered + assert "TWO" not in filtered + assert "THREE" in filtered + assert "FOUR" not in filtered + assert "FIVE" in filtered + + +more_dynamic_multimethods = """\ +class Pkg: + @when(sys.platform == "darwin") + def foo(self): + print("ONE") + + @when("@1.0") + def foo(self): + print("TWO") + + # this one isn't dynamic, but an int fails the Spec parse, + # so it's kept because it has to be evaluated at runtime. + @when("@2.0") + @when(1) + def foo(self): + print("THREE") + + @when("@3.0") + def foo(self): + print("FOUR") + + # this one should always stay + @run_after("install") + def some_function(self): + print("FIVE") +""" + + +def test_more_dynamic_multimethod_resolution(tmpdir): + when_pkg = tmpdir.join("pkg.py") + with when_pkg.open("w") as f: + f.write(more_dynamic_multimethods) + + # we know the first one is the only one that can win. + filtered = ph.canonical_source("pkg@4.0", str(when_pkg)) + assert "ONE" in filtered + assert "TWO" not in filtered + assert "THREE" not in filtered + assert "FOUR" not in filtered + assert "FIVE" in filtered + + # now we have to include ONE and TWO because ONE may win dynamically. + filtered = ph.canonical_source("pkg@1.0", str(when_pkg)) + assert "ONE" in filtered + assert "TWO" in filtered + assert "THREE" not in filtered + assert "FOUR" not in filtered + assert "FIVE" in filtered + + # we know FOUR is true and TWO and THREE are false, but ONE may + # still win dynamically. + filtered = ph.canonical_source("pkg@3.0", str(when_pkg)) + assert "ONE" in filtered + assert "TWO" not in filtered + assert "THREE" not in filtered + assert "FOUR" in filtered + assert "FIVE" in filtered + + # TWO and FOUR can't be satisfied, but ONE or THREE could win + filtered = ph.canonical_source("pkg@2.0", str(when_pkg)) + assert "ONE" in filtered + assert "TWO" not in filtered + assert "THREE" in filtered + assert "FOUR" not in filtered + assert "FIVE" in filtered diff --git a/lib/spack/spack/util/package_hash.py b/lib/spack/spack/util/package_hash.py index 3421d90d23..cabb5f8613 100644 --- a/lib/spack/spack/util/package_hash.py +++ b/lib/spack/spack/util/package_hash.py @@ -11,7 +11,9 @@ import spack.error import spack.package import spack.repo import spack.spec +import spack.util.hash import spack.util.naming +from spack.util.unparse import unparse class RemoveDocstrings(ast.NodeTransformer): @@ -82,70 +84,164 @@ class RemoveDirectives(ast.NodeTransformer): class TagMultiMethods(ast.NodeVisitor): - """Tag @when-decorated methods in a spec.""" + """Tag @when-decorated methods in a package AST.""" def __init__(self, spec): self.spec = spec + # map from function name to (implementation, condition_list) tuples self.methods = {} - def visit_FunctionDef(self, node): # noqa - nodes = self.methods.setdefault(node.name, []) - if node.decorator_list: - dec = node.decorator_list[0] + def visit_FunctionDef(self, func): # noqa + conditions = [] + for dec in func.decorator_list: if isinstance(dec, ast.Call) and dec.func.id == 'when': try: + # evaluate spec condition for any when's cond = dec.args[0].s - nodes.append( - (node, self.spec.satisfies(cond, strict=True))) + conditions.append(self.spec.satisfies(cond, strict=True)) except AttributeError: # In this case the condition for the 'when' decorator is # not a string literal (for example it may be a Python - # variable name). Therefore the function is added - # unconditionally since we don't know whether the - # constraint applies or not. - nodes.append((node, None)) - else: - nodes.append((node, None)) + # variable name). We append None because we don't know + # whether the constraint applies or not, and it should be included + # unless some other constraint is False. + conditions.append(None) + + # anything defined without conditions will overwrite prior definitions + if not conditions: + self.methods[func.name] = [] + + # add all discovered conditions on this node to the node list + impl_conditions = self.methods.setdefault(func.name, []) + impl_conditions.append((func, conditions)) + + # don't modify the AST -- return the untouched function node + return func class ResolveMultiMethods(ast.NodeTransformer): - """Remove methods which do not exist if their @when is not satisfied.""" + """Remove multi-methods when we know statically that they won't be used. + + Say we have multi-methods like this:: + + class SomePackage: + def foo(self): print("implementation 1") + + @when("@1.0") + def foo(self): print("implementation 2") + + @when("@2.0") + @when(sys.platform == "darwin") + def foo(self): print("implementation 3") + + @when("@3.0") + def foo(self): print("implementation 4") + + The multimethod that will be chosen at runtime depends on the package spec and on + whether we're on the darwin platform *at build time* (the darwin condition for + implementation 3 is dynamic). We know the package spec statically; we don't know + statically what the runtime environment will be. We need to include things that can + possibly affect package behavior in the package hash, and we want to exclude things + when we know that they will not affect package behavior. + + If we're at version 4.0, we know that implementation 1 will win, because some @when + for 2, 3, and 4 will be `False`. We should only include implementation 1. + + If we're at version 1.0, we know that implementation 2 will win, because it + overrides implementation 1. We should only include implementation 2. + + If we're at version 3.0, we know that implementation 4 will win, because it + overrides implementation 1 (the default), and some @when on all others will be + False. + + If we're at version 2.0, it's a bit more complicated. We know we can remove + implementations 2 and 4, because their @when's will never be satisfied. But, the + choice between implementations 1 and 3 will happen at runtime (this is a bad example + because the spec itself has platform information, and we should prefer to use that, + but we allow arbitrary boolean expressions in @when's, so this example suffices). + For this case, we end up needing to include *both* implementation 1 and 3 in the + package hash, because either could be chosen. + + """ def __init__(self, methods): self.methods = methods - def resolve(self, node): - if node.name not in self.methods: - raise PackageHashError( - "Future traversal visited new node: %s" % node.name) - - result = None - for n, cond in self.methods[node.name]: - if cond: - return n - if cond is None: - result = n + def resolve(self, impl_conditions): + """Given list of nodes and conditions, figure out which node will be chosen.""" + result = [] + default = None + for impl, conditions in impl_conditions: + # if there's a default implementation with no conditions, remember that. + if not conditions: + default = impl + result.append(default) + continue + + # any known-false @when means the method won't be used + if any(c is False for c in conditions): + continue + + # anything with all known-true conditions will be picked if it's first + if all(c is True for c in conditions): + if result and result[0] is default: + return [impl] # we know the first MM will always win + # if anything dynamic comes before it we don't know if it'll win, + # so just let this result get appended + + # anything else has to be determined dynamically, so add it to a list + result.append(impl) + + # if nothing was picked, the last definition wins. return result - def visit_FunctionDef(self, node): # noqa - if self.resolve(node) is node: - node.decorator_list = [] - return node - return None + def visit_FunctionDef(self, func): # noqa + # if the function def wasn't visited on the first traversal there is a problem + assert func.name in self.methods, "Inconsistent package traversal!" + + # if the function is a multimethod, need to resolve it statically + impl_conditions = self.methods[func.name] + + resolutions = self.resolve(impl_conditions) + if not any(r is func for r in resolutions): + # multimethod did not resolve to this function; remove it + return None + + # if we get here, this function is a possible resolution for a multi-method. + # it might be the only one, or there might be several that have to be evaluated + # dynamcially. Either way, we include the function. + + # strip the when decorators (preserve the rest) + func.decorator_list = [ + dec for dec in func.decorator_list + if not (isinstance(dec, ast.Call) and dec.func.id == 'when') + ] + return func def package_content(spec): return ast.dump(package_ast(spec)) +def canonical_source(spec, filename=None): + return unparse(package_ast(spec, filename=filename), py_ver_consistent=True) + + +def canonical_source_hash(spec, filename=None): + source = canonical_source(spec, filename) + return spack.util.hash.b32_hash(source) + + def package_hash(spec, content=None): if content is None: content = package_content(spec) return hashlib.sha256(content.encode('utf-8')).digest().lower() -def package_ast(spec): +def package_ast(spec, filename=None): spec = spack.spec.Spec(spec) - filename = spack.repo.path.filename_for_package_name(spec.name) + if not filename: + filename = spack.repo.path.filename_for_package_name(spec.name) + with open(filename) as f: text = f.read() root = ast.parse(text) @@ -154,10 +250,10 @@ def package_ast(spec): RemoveDirectives(spec).visit(root) - fmm = TagMultiMethods(spec) - fmm.visit(root) + tagger = TagMultiMethods(spec) + tagger.visit(root) - root = ResolveMultiMethods(fmm.methods).visit(root) + root = ResolveMultiMethods(tagger.methods).visit(root) return root |