From 98577e3af5cf571d5408fcb1da32de3586238ead Mon Sep 17 00:00:00 2001 From: Todd Gamblin Date: Sat, 21 Dec 2019 16:23:54 -0800 Subject: lock transactions: fix non-transactional writes Lock transactions were actually writing *after* the lock was released. The code was looking at the result of `release_write()` before writing, then writing based on whether the lock was released. This is pretty obviously wrong. - [x] Refactor `Lock` so that a release function can be passed to the `Lock` and called *only* when a lock is really released. - [x] Refactor `LockTransaction` classes to use the release function instead of checking the return value of `release_read()` / `release_write()` --- lib/spack/llnl/util/lock.py | 100 ++++++---- lib/spack/spack/database.py | 7 +- lib/spack/spack/test/llnl/util/lock.py | 340 ++++++++++++++++++++------------- lib/spack/spack/util/file_cache.py | 10 +- 4 files changed, 285 insertions(+), 172 deletions(-) diff --git a/lib/spack/llnl/util/lock.py b/lib/spack/llnl/util/lock.py index 66cb067c88..c675c7c452 100644 --- a/lib/spack/llnl/util/lock.py +++ b/lib/spack/llnl/util/lock.py @@ -95,10 +95,6 @@ class Lock(object): The lock is implemented as a spin lock using a nonblocking call to ``lockf()``. - On acquiring an exclusive lock, the lock writes this process's - pid and host to the lock file, in case the holding process needs - to be killed later. - If the lock times out, it raises a ``LockError``. If the lock is successfully acquired, the total wait time and the number of attempts is returned. @@ -284,11 +280,19 @@ class Lock(object): self._writes += 1 return False - def release_read(self): + def release_read(self, release_fn=None): """Releases a read lock. - Returns True if the last recursive lock was released, False if - there are still outstanding locks. + Arguments: + release_fn (callable): function to call *before* the last recursive + lock (read or write) is released. + + If the last recursive lock will be released, then this will call + release_fn and return its result (if provided), or return True + (if release_fn was not provided). + + Otherwise, we are still nested inside some other lock, so do not + call the release_fn and, return False. Does limited correctness checking: if a read lock is released when none are held, this will raise an assertion error. @@ -300,18 +304,30 @@ class Lock(object): self._debug( 'READ LOCK: {0.path}[{0._start}:{0._length}] [Released]' .format(self)) + + result = True + if release_fn is not None: + result = release_fn() + self._unlock() # can raise LockError. self._reads -= 1 - return True + return result else: self._reads -= 1 return False - def release_write(self): + def release_write(self, release_fn=None): """Releases a write lock. - Returns True if the last recursive lock was released, False if - there are still outstanding locks. + Arguments: + release_fn (callable): function to call before the last recursive + write is released. + + If the last recursive *write* lock will be released, then this + will call release_fn and return its result (if provided), or + return True (if release_fn was not provided). Otherwise, we are + still nested inside some other write lock, so do not call the + release_fn, and return False. Does limited correctness checking: if a read lock is released when none are held, this will raise an assertion error. @@ -323,9 +339,16 @@ class Lock(object): self._debug( 'WRITE LOCK: {0.path}[{0._start}:{0._length}] [Released]' .format(self)) + + # we need to call release_fn before releasing the lock + result = True + if release_fn is not None: + result = release_fn() + self._unlock() # can raise LockError. self._writes -= 1 - return True + return result + else: self._writes -= 1 return False @@ -349,28 +372,36 @@ class Lock(object): class LockTransaction(object): """Simple nested transaction context manager that uses a file lock. - This class can trigger actions when the lock is acquired for the - first time and released for the last. + Arguments: + lock (Lock): underlying lock for this transaction to be accquired on + enter and released on exit + acquire (callable or contextmanager): function to be called after lock + is acquired, or contextmanager to enter after acquire and leave + before release. + release (callable): function to be called before release. If + ``acquire`` is a contextmanager, this will be called *after* + exiting the nexted context and before the lock is released. + timeout (float): number of seconds to set for the timeout when + accquiring the lock (default no timeout) If the ``acquire_fn`` returns a value, it is used as the return value for ``__enter__``, allowing it to be passed as the ``as`` argument of a ``with`` statement. If ``acquire_fn`` returns a context manager, *its* ``__enter__`` function - will be called in ``__enter__`` after ``acquire_fn``, and its ``__exit__`` - funciton will be called before ``release_fn`` in ``__exit__``, allowing you - to nest a context manager to be used along with the lock. + will be called after the lock is acquired, and its ``__exit__`` funciton + will be called before ``release_fn`` in ``__exit__``, allowing you to + nest a context manager inside this one. Timeout for lock is customizable. """ - def __init__(self, lock, acquire_fn=None, release_fn=None, - timeout=None): + def __init__(self, lock, acquire=None, release=None, timeout=None): self._lock = lock self._timeout = timeout - self._acquire_fn = acquire_fn - self._release_fn = release_fn + self._acquire_fn = acquire + self._release_fn = release self._as = None def __enter__(self): @@ -383,13 +414,18 @@ class LockTransaction(object): def __exit__(self, type, value, traceback): suppress = False - if self._exit(): - if self._as and hasattr(self._as, '__exit__'): - if self._as.__exit__(type, value, traceback): - suppress = True - if self._release_fn: - if self._release_fn(type, value, traceback): - suppress = True + + def release_fn(): + if self._release_fn is not None: + return self._release_fn(type, value, traceback) + + if self._as and hasattr(self._as, '__exit__'): + if self._as.__exit__(type, value, traceback): + suppress = True + + if self._exit(release_fn): + suppress = True + return suppress @@ -398,8 +434,8 @@ class ReadTransaction(LockTransaction): def _enter(self): return self._lock.acquire_read(self._timeout) - def _exit(self): - return self._lock.release_read() + def _exit(self, release_fn): + return self._lock.release_read(release_fn) class WriteTransaction(LockTransaction): @@ -407,8 +443,8 @@ class WriteTransaction(LockTransaction): def _enter(self): return self._lock.acquire_write(self._timeout) - def _exit(self): - return self._lock.release_write() + def _exit(self, release_fn): + return self._lock.release_write(release_fn) class LockError(Exception): diff --git a/lib/spack/spack/database.py b/lib/spack/spack/database.py index e6e82f9803..243f1a20d5 100644 --- a/lib/spack/spack/database.py +++ b/lib/spack/spack/database.py @@ -332,11 +332,12 @@ class Database(object): def write_transaction(self): """Get a write lock context manager for use in a `with` block.""" - return WriteTransaction(self.lock, self._read, self._write) + return WriteTransaction( + self.lock, acquire=self._read, release=self._write) def read_transaction(self): """Get a read lock context manager for use in a `with` block.""" - return ReadTransaction(self.lock, self._read) + return ReadTransaction(self.lock, acquire=self._read) def prefix_lock(self, spec): """Get a lock on a particular spec's installation directory. @@ -624,7 +625,7 @@ class Database(object): self._data = {} transaction = WriteTransaction( - self.lock, _read_suppress_error, self._write + self.lock, acquire=_read_suppress_error, release=self._write ) with transaction: diff --git a/lib/spack/spack/test/llnl/util/lock.py b/lib/spack/spack/test/llnl/util/lock.py index d8081d108c..3bf8a236b1 100644 --- a/lib/spack/spack/test/llnl/util/lock.py +++ b/lib/spack/spack/test/llnl/util/lock.py @@ -42,6 +42,7 @@ node-local filesystem, and multi-node tests will fail if the locks aren't actually on a shared filesystem. """ +import collections import os import socket import shutil @@ -776,189 +777,258 @@ def test_complex_acquire_and_release_chain(lock_path): multiproc_test(p1, p2, p3) -def test_transaction(lock_path): +class AssertLock(lk.Lock): + """Test lock class that marks acquire/release events.""" + def __init__(self, lock_path, vals): + super(AssertLock, self).__init__(lock_path) + self.vals = vals + + def acquire_read(self, timeout=None): + self.assert_acquire_read() + result = super(AssertLock, self).acquire_read(timeout) + self.vals['acquired_read'] = True + return result + + def acquire_write(self, timeout=None): + self.assert_acquire_write() + result = super(AssertLock, self).acquire_write(timeout) + self.vals['acquired_write'] = True + return result + + def release_read(self, release_fn=None): + self.assert_release_read() + result = super(AssertLock, self).release_read(release_fn) + self.vals['released_read'] = True + return result + + def release_write(self, release_fn=None): + self.assert_release_write() + result = super(AssertLock, self).release_write(release_fn) + self.vals['released_write'] = True + return result + + +@pytest.mark.parametrize( + "transaction,type", + [(lk.ReadTransaction, "read"), (lk.WriteTransaction, "write")] +) +def test_transaction(lock_path, transaction, type): + class MockLock(AssertLock): + def assert_acquire_read(self): + assert not vals['entered_fn'] + assert not vals['exited_fn'] + + def assert_release_read(self): + assert vals['entered_fn'] + assert not vals['exited_fn'] + + def assert_acquire_write(self): + assert not vals['entered_fn'] + assert not vals['exited_fn'] + + def assert_release_write(self): + assert vals['entered_fn'] + assert not vals['exited_fn'] + def enter_fn(): - vals['entered'] = True + # assert enter_fn is called while lock is held + assert vals['acquired_%s' % type] + vals['entered_fn'] = True def exit_fn(t, v, tb): - vals['exited'] = True + # assert exit_fn is called while lock is held + assert not vals['released_%s' % type] + vals['exited_fn'] = True vals['exception'] = (t or v or tb) - lock = lk.Lock(lock_path) - vals = {'entered': False, 'exited': False, 'exception': False} - with lk.ReadTransaction(lock, enter_fn, exit_fn): - pass + vals = collections.defaultdict(lambda: False) + lock = MockLock(lock_path, vals) + + with transaction(lock, acquire=enter_fn, release=exit_fn): + assert vals['acquired_%s' % type] + assert not vals['released_%s' % type] - assert vals['entered'] - assert vals['exited'] + assert vals['entered_fn'] + assert vals['exited_fn'] + assert vals['acquired_%s' % type] + assert vals['released_%s' % type] assert not vals['exception'] - vals = {'entered': False, 'exited': False, 'exception': False} - with lk.WriteTransaction(lock, enter_fn, exit_fn): - pass - assert vals['entered'] - assert vals['exited'] - assert not vals['exception'] +@pytest.mark.parametrize( + "transaction,type", + [(lk.ReadTransaction, "read"), (lk.WriteTransaction, "write")] +) +def test_transaction_with_exception(lock_path, transaction, type): + class MockLock(AssertLock): + def assert_acquire_read(self): + assert not vals['entered_fn'] + assert not vals['exited_fn'] + + def assert_release_read(self): + assert vals['entered_fn'] + assert not vals['exited_fn'] + def assert_acquire_write(self): + assert not vals['entered_fn'] + assert not vals['exited_fn'] + + def assert_release_write(self): + assert vals['entered_fn'] + assert not vals['exited_fn'] -def test_transaction_with_exception(lock_path): def enter_fn(): - vals['entered'] = True + assert vals['acquired_%s' % type] + vals['entered_fn'] = True def exit_fn(t, v, tb): - vals['exited'] = True + assert not vals['released_%s' % type] + vals['exited_fn'] = True vals['exception'] = (t or v or tb) + return exit_result - lock = lk.Lock(lock_path) - - def do_read_with_exception(): - with lk.ReadTransaction(lock, enter_fn, exit_fn): - raise Exception() - - def do_write_with_exception(): - with lk.WriteTransaction(lock, enter_fn, exit_fn): - raise Exception() + exit_result = False + vals = collections.defaultdict(lambda: False) + lock = MockLock(lock_path, vals) - vals = {'entered': False, 'exited': False, 'exception': False} with pytest.raises(Exception): - do_read_with_exception() - assert vals['entered'] - assert vals['exited'] - assert vals['exception'] + with transaction(lock, acquire=enter_fn, release=exit_fn): + raise Exception() - vals = {'entered': False, 'exited': False, 'exception': False} - with pytest.raises(Exception): - do_write_with_exception() - assert vals['entered'] - assert vals['exited'] + assert vals['entered_fn'] + assert vals['exited_fn'] assert vals['exception'] + # test suppression of exceptions from exit_fn + exit_result = True + vals.clear() -def test_transaction_with_context_manager(lock_path): - class TestContextManager(object): - - def __enter__(self): - vals['entered'] = True - - def __exit__(self, t, v, tb): - vals['exited'] = True - vals['exception'] = (t or v or tb) - - def exit_fn(t, v, tb): - vals['exited_fn'] = True - vals['exception_fn'] = (t or v or tb) - - lock = lk.Lock(lock_path) - - vals = {'entered': False, 'exited': False, 'exited_fn': False, - 'exception': False, 'exception_fn': False} - with lk.ReadTransaction(lock, TestContextManager, exit_fn): - pass + # should not raise now. + with transaction(lock, acquire=enter_fn, release=exit_fn): + raise Exception() - assert vals['entered'] - assert vals['exited'] - assert not vals['exception'] + assert vals['entered_fn'] assert vals['exited_fn'] - assert not vals['exception_fn'] - - vals = {'entered': False, 'exited': False, 'exited_fn': False, - 'exception': False, 'exception_fn': False} - with lk.ReadTransaction(lock, TestContextManager): - pass - - assert vals['entered'] - assert vals['exited'] - assert not vals['exception'] - assert not vals['exited_fn'] - assert not vals['exception_fn'] + assert vals['exception'] - vals = {'entered': False, 'exited': False, 'exited_fn': False, - 'exception': False, 'exception_fn': False} - with lk.WriteTransaction(lock, TestContextManager, exit_fn): - pass - assert vals['entered'] - assert vals['exited'] - assert not vals['exception'] - assert vals['exited_fn'] - assert not vals['exception_fn'] +@pytest.mark.parametrize( + "transaction,type", + [(lk.ReadTransaction, "read"), (lk.WriteTransaction, "write")] +) +def test_transaction_with_context_manager(lock_path, transaction, type): + class MockLock(AssertLock): + def assert_acquire_read(self): + assert not vals['entered_ctx'] + assert not vals['exited_ctx'] - vals = {'entered': False, 'exited': False, 'exited_fn': False, - 'exception': False, 'exception_fn': False} - with lk.WriteTransaction(lock, TestContextManager): - pass + def assert_release_read(self): + assert vals['entered_ctx'] + assert vals['exited_ctx'] - assert vals['entered'] - assert vals['exited'] - assert not vals['exception'] - assert not vals['exited_fn'] - assert not vals['exception_fn'] + def assert_acquire_write(self): + assert not vals['entered_ctx'] + assert not vals['exited_ctx'] + def assert_release_write(self): + assert vals['entered_ctx'] + assert vals['exited_ctx'] -def test_transaction_with_context_manager_and_exception(lock_path): class TestContextManager(object): def __enter__(self): - vals['entered'] = True + vals['entered_ctx'] = True def __exit__(self, t, v, tb): - vals['exited'] = True - vals['exception'] = (t or v or tb) + assert not vals['released_%s' % type] + vals['exited_ctx'] = True + vals['exception_ctx'] = (t or v or tb) + return exit_ctx_result def exit_fn(t, v, tb): + assert not vals['released_%s' % type] vals['exited_fn'] = True vals['exception_fn'] = (t or v or tb) + return exit_fn_result - lock = lk.Lock(lock_path) - - def do_read_with_exception(exit_fn): - with lk.ReadTransaction(lock, TestContextManager, exit_fn): - raise Exception() + exit_fn_result, exit_ctx_result = False, False + vals = collections.defaultdict(lambda: False) + lock = MockLock(lock_path, vals) - def do_write_with_exception(exit_fn): - with lk.WriteTransaction(lock, TestContextManager, exit_fn): - raise Exception() + with transaction(lock, acquire=TestContextManager, release=exit_fn): + pass - vals = {'entered': False, 'exited': False, 'exited_fn': False, - 'exception': False, 'exception_fn': False} - with pytest.raises(Exception): - do_read_with_exception(exit_fn) - assert vals['entered'] - assert vals['exited'] - assert vals['exception'] + assert vals['entered_ctx'] + assert vals['exited_ctx'] assert vals['exited_fn'] - assert vals['exception_fn'] - - vals = {'entered': False, 'exited': False, 'exited_fn': False, - 'exception': False, 'exception_fn': False} - with pytest.raises(Exception): - do_read_with_exception(None) - assert vals['entered'] - assert vals['exited'] - assert vals['exception'] - assert not vals['exited_fn'] + assert not vals['exception_ctx'] assert not vals['exception_fn'] - vals = {'entered': False, 'exited': False, 'exited_fn': False, - 'exception': False, 'exception_fn': False} - with pytest.raises(Exception): - do_write_with_exception(exit_fn) - assert vals['entered'] - assert vals['exited'] - assert vals['exception'] - assert vals['exited_fn'] - assert vals['exception_fn'] + vals.clear() + with transaction(lock, acquire=TestContextManager): + pass - vals = {'entered': False, 'exited': False, 'exited_fn': False, - 'exception': False, 'exception_fn': False} - with pytest.raises(Exception): - do_write_with_exception(None) - assert vals['entered'] - assert vals['exited'] - assert vals['exception'] + assert vals['entered_ctx'] + assert vals['exited_ctx'] assert not vals['exited_fn'] + assert not vals['exception_ctx'] assert not vals['exception_fn'] + # below are tests for exceptions with and without suppression + def assert_ctx_and_fn_exception(raises=True): + vals.clear() + + if raises: + with pytest.raises(Exception): + with transaction( + lock, acquire=TestContextManager, release=exit_fn): + raise Exception() + else: + with transaction( + lock, acquire=TestContextManager, release=exit_fn): + raise Exception() + + assert vals['entered_ctx'] + assert vals['exited_ctx'] + assert vals['exited_fn'] + assert vals['exception_ctx'] + assert vals['exception_fn'] + + def assert_only_ctx_exception(raises=True): + vals.clear() + + if raises: + with pytest.raises(Exception): + with transaction(lock, acquire=TestContextManager): + raise Exception() + else: + with transaction(lock, acquire=TestContextManager): + raise Exception() + + assert vals['entered_ctx'] + assert vals['exited_ctx'] + assert not vals['exited_fn'] + assert vals['exception_ctx'] + assert not vals['exception_fn'] + + # no suppression + assert_ctx_and_fn_exception(raises=True) + assert_only_ctx_exception(raises=True) + + # suppress exception only in function + exit_fn_result, exit_ctx_result = True, False + assert_ctx_and_fn_exception(raises=False) + assert_only_ctx_exception(raises=True) + + # suppress exception only in context + exit_fn_result, exit_ctx_result = False, True + assert_ctx_and_fn_exception(raises=False) + assert_only_ctx_exception(raises=False) + + # suppress exception in function and context + exit_fn_result, exit_ctx_result = True, True + assert_ctx_and_fn_exception(raises=False) + assert_only_ctx_exception(raises=False) + def test_lock_debug_output(lock_path): host = socket.getfqdn() diff --git a/lib/spack/spack/util/file_cache.py b/lib/spack/spack/util/file_cache.py index d56f2b33c5..0227edf155 100644 --- a/lib/spack/spack/util/file_cache.py +++ b/lib/spack/spack/util/file_cache.py @@ -107,7 +107,8 @@ class FileCache(object): """ return ReadTransaction( - self._get_lock(key), lambda: open(self.cache_path(key))) + self._get_lock(key), acquire=lambda: open(self.cache_path(key)) + ) def write_transaction(self, key): """Get a write transaction on a file cache item. @@ -117,6 +118,10 @@ class FileCache(object): moves the file into place on top of the old file atomically. """ + # TODO: this nested context manager adds a lot of complexity and + # TODO: is pretty hard to reason about in llnl.util.lock. At some + # TODO: point we should just replace it with functions and simplify + # TODO: the locking code. class WriteContextManager(object): def __enter__(cm): # noqa @@ -142,7 +147,8 @@ class FileCache(object): else: os.rename(cm.tmp_filename, cm.orig_filename) - return WriteTransaction(self._get_lock(key), WriteContextManager) + return WriteTransaction( + self._get_lock(key), acquire=WriteContextManager) def mtime(self, key): """Return modification time of cache file, or 0 if it does not exist. -- cgit v1.2.3-70-g09d2