summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorTodd Gamblin <tgamblin@llnl.gov>2015-02-15 01:58:35 -0800
committerTodd Gamblin <tgamblin@llnl.gov>2015-02-15 01:58:35 -0800
commitc0c08799249fb56c281f62b3659e7cf7d7080188 (patch)
tree1df0dc57d97d85b947356c1b0cab7da6f591a902 /lib
parent82dc935a50874e899380f32a9a35b7cc4f76df87 (diff)
downloadspack-c0c08799249fb56c281f62b3659e7cf7d7080188.tar.gz
spack-c0c08799249fb56c281f62b3659e7cf7d7080188.tar.bz2
spack-c0c08799249fb56c281f62b3659e7cf7d7080188.tar.xz
spack-c0c08799249fb56c281f62b3659e7cf7d7080188.zip
Better extension activation/deactivation
Diffstat (limited to 'lib')
-rw-r--r--lib/spack/spack/cmd/extensions.py4
-rw-r--r--lib/spack/spack/directory_layout.py150
-rw-r--r--lib/spack/spack/package.py19
3 files changed, 125 insertions, 48 deletions
diff --git a/lib/spack/spack/cmd/extensions.py b/lib/spack/spack/cmd/extensions.py
index ae73d8ac55..fc8e6842c3 100644
--- a/lib/spack/spack/cmd/extensions.py
+++ b/lib/spack/spack/cmd/extensions.py
@@ -89,10 +89,10 @@ def extensions(parser, args):
spack.cmd.find.display_specs(installed, mode=args.mode)
# List specs of activated extensions.
- activated = spack.install_layout.get_extensions(spec)
+ activated = spack.install_layout.extension_map(spec)
print
if not activated:
tty.msg("None activated.")
return
tty.msg("%d currently activated:" % len(activated))
- spack.cmd.find.display_specs(activated, mode=args.mode)
+ spack.cmd.find.display_specs(activated.values(), mode=args.mode)
diff --git a/lib/spack/spack/directory_layout.py b/lib/spack/spack/directory_layout.py
index 37740720a2..562c0bd3ed 100644
--- a/lib/spack/spack/directory_layout.py
+++ b/lib/spack/spack/directory_layout.py
@@ -27,6 +27,7 @@ import os
import exceptions
import hashlib
import shutil
+import tempfile
from contextlib import closing
import llnl.util.tty as tty
@@ -84,17 +85,38 @@ class DirectoryLayout(object):
raise NotImplementedError()
- def get_extensions(self, spec):
- """Get a set of currently installed extension packages for a spec."""
+ def extension_map(self, spec):
+ """Get a dict of currently installed extension packages for a spec.
+
+ Dict maps { name : extension_spec }
+ Modifying dict does not affect internals of this layout.
+ """
+ raise NotImplementedError()
+
+
+ def check_extension_conflict(self, spec, ext_spec):
+ """Ensure that ext_spec can be activated in spec.
+
+ If not, raise ExtensionAlreadyInstalledError or
+ ExtensionConflictError.
+ """
+ raise NotImplementedError()
+
+
+ def check_activated(self, spec, ext_spec):
+ """Ensure that ext_spec can be removed from spec.
+
+ If not, raise NoSuchExtensionError.
+ """
raise NotImplementedError()
- def add_extension(self, spec, extension_spec):
+ def add_extension(self, spec, ext_spec):
"""Add to the list of currently installed extensions."""
raise NotImplementedError()
- def remove_extension(self, spec, extension_spec):
+ def remove_extension(self, spec, ext_spec):
"""Remove from the list of currently installed extensions."""
raise NotImplementedError()
@@ -173,6 +195,8 @@ class SpecHashDirectoryLayout(DirectoryLayout):
self.spec_file_name = spec_file_name
self.extension_file_name = extension_file_name
+ # Cache of already written/read extension maps.
+ self._extension_maps = {}
@property
def hidden_file_paths(self):
@@ -271,54 +295,94 @@ class SpecHashDirectoryLayout(DirectoryLayout):
return join_path(self.path_for_spec(spec), self.extension_file_name)
- def get_extensions(self, spec):
+ def _extension_map(self, spec):
+ """Get a dict<name -> spec> for all extensions currnetly
+ installed for this package."""
_check_concrete(spec)
- extensions = set()
- path = self.extension_file_path(spec)
- if os.path.exists(path):
- with closing(open(path)) as ext_file:
- for line in ext_file:
- try:
- extensions.add(Spec(line.strip()))
- except spack.error.SpackError, e:
- raise InvalidExtensionSpecError(str(e))
- return extensions
+ if not spec in self._extension_maps:
+ path = self.extension_file_path(spec)
+ if not os.path.exists(path):
+ self._extension_maps[spec] = {}
+
+ else:
+ exts = {}
+ with closing(open(path)) as ext_file:
+ for line in ext_file:
+ try:
+ spec = Spec(line.strip())
+ exts[spec.name] = spec
+ except spack.error.SpackError, e:
+ # TODO: do something better here -- should be
+ # resilient to corrupt files.
+ raise InvalidExtensionSpecError(str(e))
+ self._extension_maps[spec] = exts
+
+ return self._extension_maps[spec]
+
+
+ def extension_map(self, spec):
+ """Defensive copying version of _extension_map() for external API."""
+ return self._extension_map(spec).copy()
+
+
+ def check_extension_conflict(self, spec, ext_spec):
+ exts = self._extension_map(spec)
+ if ext_spec.name in exts:
+ installed_spec = exts[ext_spec.name]
+ if ext_spec == installed_spec:
+ raise ExtensionAlreadyInstalledError(spec, ext_spec)
+ else:
+ raise ExtensionConflictError(spec, ext_spec, installed_spec)
+
+ def check_activated(self, spec, ext_spec):
+ exts = self._extension_map(spec)
+ if (not ext_spec.name in exts) or (ext_spec != exts[ext_spec.name]):
+ raise NoSuchExtensionError(spec, ext_spec)
- def write_extensions(self, spec, extensions):
+
+ def _write_extensions(self, spec, extensions):
path = self.extension_file_path(spec)
- with closing(open(path, 'w')) as spec_file:
- for extension in sorted(extensions):
- spec_file.write("%s\n" % extension)
+
+ # Create a temp file in the same directory as the actual file.
+ dirname, basename = os.path.split(path)
+ tmp = tempfile.NamedTemporaryFile(
+ prefix=basename, dir=dirname, delete=False)
+
+ # Write temp file.
+ with closing(tmp):
+ for extension in sorted(extensions.values()):
+ tmp.write("%s\n" % extension)
+
+ # Atomic update by moving tmpfile on top of old one.
+ os.rename(tmp.name, path)
- def add_extension(self, spec, extension_spec):
+ def add_extension(self, spec, ext_spec):
_check_concrete(spec)
- _check_concrete(extension_spec)
+ _check_concrete(ext_spec)
- exts = self.get_extensions(spec)
- if extension_spec in exts:
- raise ExtensionAlreadyInstalledError(spec, extension_spec)
- else:
- for already_installed in exts:
- if spec.name == extension_spec.name:
- raise ExtensionConflictError(spec, extension_spec, already_installed)
+ # Check whether it's already installed or if it's a conflict.
+ exts = self.extension_map(spec)
+ self.check_extension_conflict(spec, ext_spec)
- exts.add(extension_spec)
- self.write_extensions(spec, exts)
+ # do the actual adding.
+ exts[ext_spec.name] = ext_spec
+ self._write_extensions(spec, exts)
- def remove_extension(self, spec, extension_spec):
+ def remove_extension(self, spec, ext_spec):
_check_concrete(spec)
- _check_concrete(extension_spec)
+ _check_concrete(ext_spec)
- exts = self.get_extensions(spec)
- if not extension_spec in exts:
- raise NoSuchExtensionError(spec, extension_spec)
+ # Make sure it's installed before removing.
+ exts = self.extension_map(spec)
+ self.check_activated(spec, ext_spec)
- exts.remove(extension_spec)
- self.write_extensions(spec, exts)
+ # do the actual removing.
+ del exts[ext_spec.name]
+ self._write_extensions(spec, exts)
class DirectoryLayoutError(SpackError):
@@ -365,24 +429,24 @@ class InvalidExtensionSpecError(DirectoryLayoutError):
class ExtensionAlreadyInstalledError(DirectoryLayoutError):
"""Raised when an extension is added to a package that already has it."""
- def __init__(self, spec, extension_spec):
+ def __init__(self, spec, ext_spec):
super(ExtensionAlreadyInstalledError, self).__init__(
- "%s is already installed in %s" % (extension_spec.short_spec, spec.short_spec))
+ "%s is already installed in %s" % (ext_spec.short_spec, spec.short_spec))
class ExtensionConflictError(DirectoryLayoutError):
"""Raised when an extension is added to a package that already has it."""
- def __init__(self, spec, extension_spec, conflict):
+ def __init__(self, spec, ext_spec, conflict):
super(ExtensionConflictError, self).__init__(
"%s cannot be installed in %s because it conflicts with %s."% (
- extension_spec.short_spec, spec.short_spec, conflict.short_spec))
+ ext_spec.short_spec, spec.short_spec, conflict.short_spec))
class NoSuchExtensionError(DirectoryLayoutError):
"""Raised when an extension isn't there on remove."""
- def __init__(self, spec, extension_spec):
+ def __init__(self, spec, ext_spec):
super(NoSuchExtensionError, self).__init__(
"%s cannot be removed from %s because it's not installed."% (
- extension_spec.short_spec, spec.short_spec))
+ ext_spec.short_spec, spec.short_spec))
diff --git a/lib/spack/spack/package.py b/lib/spack/spack/package.py
index b18d054990..a624c1ebf5 100644
--- a/lib/spack/spack/package.py
+++ b/lib/spack/spack/package.py
@@ -534,7 +534,8 @@ class Package(object):
if not self.is_extension:
raise ValueError("is_extension called on package that is not an extension.")
- return self.spec in spack.install_layout.get_extensions(self.extendee_spec)
+ exts = spack.install_layout.extension_map(self.extendee_spec)
+ return (self.name in exts) and (exts[self.name] == self.spec)
def preorder_traversal(self, visited=None, **kwargs):
@@ -987,6 +988,8 @@ class Package(object):
activate() directly.
"""
self._sanity_check_extension()
+ spack.install_layout.check_extension_conflict(self.extendee_spec, self.spec)
+
self.extendee_spec.package.activate(self, **self.extendee_args)
spack.install_layout.add_extension(self.extendee_spec, self.spec)
@@ -1014,12 +1017,22 @@ class Package(object):
tree.merge(self.prefix, ignore=ignore)
- def do_deactivate(self):
+ def do_deactivate(self, **kwargs):
"""Called on the extension to invoke extendee's deactivate() method."""
+ force = kwargs.get('force', False)
+
self._sanity_check_extension()
+
+ # Allow a force deactivate to happen. This can unlink
+ # spurious files if something was corrupted.
+ if not force:
+ spack.install_layout.check_activated(self.extendee_spec, self.spec)
+
self.extendee_spec.package.deactivate(self, **self.extendee_args)
- if self.spec in spack.install_layout.get_extensions(self.extendee_spec):
+ # redundant activation check -- makes SURE the spec is not
+ # still activated even if something was wrong above.
+ if self.activated:
spack.install_layout.remove_extension(self.extendee_spec, self.spec)
tty.msg("Deactivated extension %s for %s."