summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorTodd Gamblin <tgamblin@llnl.gov>2020-06-07 15:54:58 -0700
committerTodd Gamblin <tgamblin@llnl.gov>2020-11-17 10:04:13 -0800
commit8a6207aa70221070724944529956b49fbfbbb4e2 (patch)
treee7d4c91f80b69988f9e2db32bdc7160f9745967e /lib
parent6e31430bec0d72169ec6844abda96894b4b143be (diff)
downloadspack-8a6207aa70221070724944529956b49fbfbbb4e2.tar.gz
spack-8a6207aa70221070724944529956b49fbfbbb4e2.tar.bz2
spack-8a6207aa70221070724944529956b49fbfbbb4e2.tar.xz
spack-8a6207aa70221070724944529956b49fbfbbb4e2.zip
concretizer: handle versions with choice construct rather than conflicts
Use '1 { version(x); version(y); version(z) } 1.' instead of declaring conflicts for non-matching versions. This keeps the sense of version clauses positive, which will allow them to be used more easily in conditionals later. Also refactor `spec_clauses()` method to return clauses that can be used in conditions, etc. instead of just printing out facts.
Diffstat (limited to 'lib')
-rw-r--r--lib/spack/spack/solver/asp.py135
1 files changed, 95 insertions, 40 deletions
diff --git a/lib/spack/spack/solver/asp.py b/lib/spack/spack/solver/asp.py
index d573d8d074..9a80dbe83f 100644
--- a/lib/spack/spack/solver/asp.py
+++ b/lib/spack/spack/solver/asp.py
@@ -30,11 +30,6 @@ from spack.version import ver
_max_line = 80
-def _id(thing):
- """Quote string if needed for it to be a valid identifier."""
- return '"%s"' % str(thing)
-
-
def issequence(obj):
if isinstance(obj, string_types):
return False
@@ -59,7 +54,16 @@ def specify(spec):
return spack.spec.Spec(spec)
-class AspFunction(object):
+class AspObject(object):
+ """Object representing a piece of ASP code."""
+
+
+def _id(thing):
+ """Quote string if needed for it to be a valid identifier."""
+ return thing if isinstance(thing, AspObject) else '"%s"' % str(thing)
+
+
+class AspFunction(AspObject):
def __init__(self, name):
self.name = name
self.args = []
@@ -73,7 +77,7 @@ class AspFunction(object):
self.name, ', '.join(_id(arg) for arg in self.args))
-class AspAnd(object):
+class AspAnd(AspObject):
def __init__(self, *args):
args = listify(args)
self.args = args
@@ -83,7 +87,7 @@ class AspAnd(object):
return s
-class AspOr(object):
+class AspOr(AspObject):
def __init__(self, *args):
args = listify(args)
self.args = args
@@ -92,7 +96,7 @@ class AspOr(object):
return " | ".join(str(arg) for arg in self.args)
-class AspNot(object):
+class AspNot(AspObject):
def __init__(self, arg):
self.arg = arg
@@ -100,6 +104,16 @@ class AspNot(object):
return "not %s" % self.arg
+class AspOneOf(AspObject):
+ def __init__(self, *args):
+ args = listify(args)
+ self.args = args
+
+ def __str__(self):
+ body = "; ".join(str(arg) for arg in self.args)
+ return "1 { %s } 1" % body
+
+
class AspFunctionBuilder(object):
def __getattr__(self, name):
return AspFunction(name)
@@ -112,6 +126,7 @@ class AspGenerator(object):
def __init__(self, out):
self.out = out
self.func = AspFunctionBuilder()
+ self.possible_versions = {}
def title(self, name, char):
self.out.write('\n')
@@ -133,6 +148,9 @@ class AspGenerator(object):
self.out.write("%% %s\n" % name)
self.out.write("%\n")
+ def one_of(self, *args):
+ return AspOneOf(*args)
+
def _or(self, *args):
return AspOr(*args)
@@ -158,31 +176,34 @@ class AspGenerator(object):
self.out.write(":- %s.\n" % body)
def pkg_version_rules(self, pkg):
+ """Output declared versions of a package.
+
+ This uses self.possible_versions so that we include any versions
+ that arise from a spec.
+ """
pkg = packagize(pkg)
- for v in pkg.versions:
+ for v in self.possible_versions[pkg.name]:
self.fact(fn.version_declared(pkg.name, v))
def spec_versions(self, spec):
+ """Return list of clauses expressing spec's version constraints."""
spec = specify(spec)
+ assert spec.name
if spec.concrete:
- self.rule(fn.version(spec.name, spec.version),
- fn.node(spec.name))
- else:
- versions = list(spec.package.versions)
-
- # if the spec declares a new version, add it to the
- # possibilities.
- if spec.versions.concrete and spec.version not in versions:
- self.fact(fn.version_declared(spec.name, spec.version))
- versions.append(spec.version)
-
- # conflict with any versions that do not satisfy the spec
- # TODO: need to traverse allspecs beforehand and ensure all
- # TODO: versions are known so we can disallow them.
- for v in versions:
- if not v.satisfies(spec.versions):
- self.fact(fn.version_conflict(spec.name, v))
+ return [fn.version(spec.name, spec.version)]
+
+ # version must be *one* of the ones the spec allows.
+ allowed_versions = [
+ v for v in self.possible_versions[spec.name]
+ if v.satisfies(spec.versions)
+ ]
+ predicates = [fn.version(spec.name, v) for v in allowed_versions]
+
+ # conflict with any versions that do not satisfy the spec
+ if predicates:
+ return [self.one_of(*predicates)]
+ return []
def compiler_defaults(self):
"""Facts about available compilers."""
@@ -259,30 +280,39 @@ class AspGenerator(object):
for cond, dep in conditions.items():
self.fact(fn.depends_on(dep.pkg.name, dep.spec.name))
- def spec_rules(self, spec):
- self.fact(fn.node(spec.name))
- self.spec_versions(spec)
+ def spec_clauses(self, spec):
+ """Return a list of clauses the spec mandates are true.
+
+ Arguments:
+ spec (Spec): the spec to analyze
+ """
+ clauses = []
+
+ if spec.name:
+ clauses.append(fn.node(spec.name))
+
+ clauses.extend(self.spec_versions(spec))
# seed architecture at the root (we'll propagate later)
# TODO: use better semantics.
arch = spec.architecture
if arch:
if arch.platform:
- self.fact(fn.arch_platform_set(spec.name, arch.platform))
+ clauses.append(fn.arch_platform_set(spec.name, arch.platform))
if arch.os:
- self.fact(fn.arch_os_set(spec.name, arch.os))
+ clauses.append(fn.arch_os_set(spec.name, arch.os))
if arch.target:
- self.fact(fn.arch_target_set(spec.name, arch.target))
+ clauses.append(fn.arch_target_set(spec.name, arch.target))
# variants
for vname, variant in spec.variants.items():
- self.fact(fn.variant_set(spec.name, vname, variant.value))
+ clauses.append(fn.variant_set(spec.name, vname, variant.value))
# compiler and compiler version
if spec.compiler:
- self.fact(fn.node_compiler_set(spec.name, spec.compiler.name))
+ clauses.append(fn.node_compiler_set(spec.name, spec.compiler.name))
if spec.compiler.concrete:
- self.fact(fn.node_compiler_version_set(
+ clauses.append(fn.node_compiler_version_set(
spec.name, spec.compiler.name, spec.compiler.version))
# TODO
@@ -292,6 +322,22 @@ class AspGenerator(object):
# compiler_flags
# namespace
+ return clauses
+
+ def build_version_dict(self, possible_pkgs, specs):
+ """Declare any versions in specs not declared in packages."""
+ self.possible_versions = collections.defaultdict(lambda: set())
+
+ for pkg_name in possible_pkgs:
+ pkg = spack.repo.get(pkg_name)
+ for v in pkg.versions:
+ self.possible_versions[pkg_name].add(v)
+
+ for spec in specs:
+ for dep in spec.traverse():
+ if dep.versions.concrete:
+ self.possible_versions[dep.name].add(dep.version)
+
def arch_defaults(self):
"""Add facts about the default architecture for a package."""
self.h2('Default architecture')
@@ -310,13 +356,20 @@ class AspGenerator(object):
"""
# get list of all possible dependencies
pkg_names = set(spec.fullname for spec in specs)
- pkgs = [spack.repo.path.get_pkg_class(name) for name in pkg_names]
- pkgs = list(set(spack.package.possible_dependencies(*pkgs))
- | set(pkg_names))
+
+ possible = set()
+ for name in pkg_names:
+ pkg = spack.repo.path.get_pkg_class(name)
+ possible.update(pkg.possible_dependencies())
+
+ pkgs = set(possible) | set(pkg_names)
concretize_lp = pkgutil.get_data('spack.solver', 'concretize.lp')
self.out.write(concretize_lp.decode("utf-8"))
+ # traverse all specs and packages to build dict of possible versions
+ self.build_version_dict(possible, specs)
+
self.h1('General Constraints')
self.compiler_defaults()
self.arch_defaults()
@@ -330,7 +383,9 @@ class AspGenerator(object):
for spec in specs:
for dep in spec.traverse():
self.h2('Spec: %s' % str(dep))
- self.spec_rules(dep)
+ self.fact(fn.node(dep.name))
+ for clause in self.spec_clauses(dep):
+ self.rule(clause, fn.node(dep.name))
self.out.write('\n')
display_lp = pkgutil.get_data('spack.solver', 'display.lp')