summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDanny McClanahan <1305167+cosmicexplorer@users.noreply.github.com>2022-03-02 19:12:15 +0000
committerGitHub <noreply@github.com>2022-03-02 11:12:15 -0800
commit2c331a1d7f5e011781dcd258dfdb0e7f14841cab (patch)
treef5025db3c3e55d01910774a2af80636604cee283 /lib
parent916c94fd65b5d57914a25466b47efd05092fc88b (diff)
downloadspack-2c331a1d7f5e011781dcd258dfdb0e7f14841cab.tar.gz
spack-2c331a1d7f5e011781dcd258dfdb0e7f14841cab.tar.bz2
spack-2c331a1d7f5e011781dcd258dfdb0e7f14841cab.tar.xz
spack-2c331a1d7f5e011781dcd258dfdb0e7f14841cab.zip
make @llnl.util.lang.memoized support kwargs (#21722)
* make memoized() support kwargs * add testing for @memoized
Diffstat (limited to 'lib')
-rw-r--r--lib/spack/llnl/util/lang.py44
-rw-r--r--lib/spack/spack/test/llnl/util/lang.py62
2 files changed, 96 insertions, 10 deletions
diff --git a/lib/spack/llnl/util/lang.py b/lib/spack/llnl/util/lang.py
index c0b8c863d4..e474924507 100644
--- a/lib/spack/llnl/util/lang.py
+++ b/lib/spack/llnl/util/lang.py
@@ -13,9 +13,10 @@ import re
import sys
from datetime import datetime, timedelta
+import six
from six import string_types
-from llnl.util.compat import Hashable, MutableMapping, zip_longest
+from llnl.util.compat import MutableMapping, zip_longest
# Ignore emacs backups when listing modules
ignore_modules = [r'^\.#', '~$']
@@ -165,6 +166,19 @@ def union_dicts(*dicts):
return result
+# Used as a sentinel that disambiguates tuples passed in *args from coincidentally
+# matching tuples formed from kwargs item pairs.
+_kwargs_separator = (object(),)
+
+
+def stable_args(*args, **kwargs):
+ """A key factory that performs a stable sort of the parameters."""
+ key = args
+ if kwargs:
+ key += _kwargs_separator + tuple(sorted(kwargs.items()))
+ return key
+
+
def memoized(func):
"""Decorator that caches the results of a function, storing them in
an attribute of that function.
@@ -172,15 +186,23 @@ def memoized(func):
func.cache = {}
@functools.wraps(func)
- def _memoized_function(*args):
- if not isinstance(args, Hashable):
- # Not hashable, so just call the function.
- return func(*args)
+ def _memoized_function(*args, **kwargs):
+ key = stable_args(*args, **kwargs)
- if args not in func.cache:
- func.cache[args] = func(*args)
-
- return func.cache[args]
+ try:
+ return func.cache[key]
+ except KeyError:
+ ret = func(*args, **kwargs)
+ func.cache[key] = ret
+ return ret
+ except TypeError as e:
+ # TypeError is raised when indexing into a dict if the key is unhashable.
+ raise six.raise_from(
+ UnhashableArguments(
+ "args + kwargs '{}' was not hashable for function '{}'"
+ .format(key, func.__name__),
+ ),
+ e)
return _memoized_function
@@ -930,3 +952,7 @@ def nullcontext(*args, **kwargs):
TODO: replace with contextlib.nullcontext() if we ever require python 3.7.
"""
yield
+
+
+class UnhashableArguments(TypeError):
+ """Raise when an @memoized function receives unhashable arg or kwarg values."""
diff --git a/lib/spack/spack/test/llnl/util/lang.py b/lib/spack/spack/test/llnl/util/lang.py
index 8a2a03ee52..3fb6196c3a 100644
--- a/lib/spack/spack/test/llnl/util/lang.py
+++ b/lib/spack/spack/test/llnl/util/lang.py
@@ -10,7 +10,7 @@ from datetime import datetime, timedelta
import pytest
import llnl.util.lang
-from llnl.util.lang import match_predicate, pretty_date
+from llnl.util.lang import match_predicate, memoized, pretty_date, stable_args
@pytest.fixture()
@@ -205,3 +205,63 @@ def test_key_ordering():
assert hash(a) == hash(a2)
assert hash(b) == hash(b)
assert hash(b) == hash(b2)
+
+
+@pytest.mark.parametrize(
+ "args1,kwargs1,args2,kwargs2",
+ [
+ # Ensure tuples passed in args are disambiguated from equivalent kwarg items.
+ (('a', 3), {}, (), {'a': 3})
+ ],
+)
+def test_unequal_args(args1, kwargs1, args2, kwargs2):
+ assert stable_args(*args1, **kwargs1) != stable_args(*args2, **kwargs2)
+
+
+@pytest.mark.parametrize(
+ "args1,kwargs1,args2,kwargs2",
+ [
+ # Ensure that kwargs are stably sorted.
+ ((), {'a': 3, 'b': 4}, (), {'b': 4, 'a': 3}),
+ ],
+)
+def test_equal_args(args1, kwargs1, args2, kwargs2):
+ assert stable_args(*args1, **kwargs1) == stable_args(*args2, **kwargs2)
+
+
+@pytest.mark.parametrize(
+ "args, kwargs",
+ [
+ ((1,), {}),
+ ((), {'a': 3}),
+ ((1,), {'a': 3}),
+ ],
+)
+def test_memoized(args, kwargs):
+ @memoized
+ def f(*args, **kwargs):
+ return 'return-value'
+ assert f(*args, **kwargs) == 'return-value'
+ key = stable_args(*args, **kwargs)
+ assert list(f.cache.keys()) == [key]
+ assert f.cache[key] == 'return-value'
+
+
+@pytest.mark.parametrize(
+ "args, kwargs",
+ [
+ (([1],), {}),
+ ((), {'a': [1]})
+ ],
+)
+def test_memoized_unhashable(args, kwargs):
+ """Check that an exception is raised clearly"""
+ @memoized
+ def f(*args, **kwargs):
+ return None
+ with pytest.raises(llnl.util.lang.UnhashableArguments) as exc_info:
+ f(*args, **kwargs)
+ exc_msg = str(exc_info.value)
+ key = stable_args(*args, **kwargs)
+ assert str(key) in exc_msg
+ assert "function 'f'" in exc_msg