From 2c331a1d7f5e011781dcd258dfdb0e7f14841cab Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Wed, 2 Mar 2022 19:12:15 +0000 Subject: make @llnl.util.lang.memoized support kwargs (#21722) * make memoized() support kwargs * add testing for @memoized --- lib/spack/llnl/util/lang.py | 44 +++++++++++++++++++----- lib/spack/spack/test/llnl/util/lang.py | 62 +++++++++++++++++++++++++++++++++- 2 files changed, 96 insertions(+), 10 deletions(-) (limited to 'lib') diff --git a/lib/spack/llnl/util/lang.py b/lib/spack/llnl/util/lang.py index c0b8c863d4..e474924507 100644 --- a/lib/spack/llnl/util/lang.py +++ b/lib/spack/llnl/util/lang.py @@ -13,9 +13,10 @@ import re import sys from datetime import datetime, timedelta +import six from six import string_types -from llnl.util.compat import Hashable, MutableMapping, zip_longest +from llnl.util.compat import MutableMapping, zip_longest # Ignore emacs backups when listing modules ignore_modules = [r'^\.#', '~$'] @@ -165,6 +166,19 @@ def union_dicts(*dicts): return result +# Used as a sentinel that disambiguates tuples passed in *args from coincidentally +# matching tuples formed from kwargs item pairs. +_kwargs_separator = (object(),) + + +def stable_args(*args, **kwargs): + """A key factory that performs a stable sort of the parameters.""" + key = args + if kwargs: + key += _kwargs_separator + tuple(sorted(kwargs.items())) + return key + + def memoized(func): """Decorator that caches the results of a function, storing them in an attribute of that function. @@ -172,15 +186,23 @@ def memoized(func): func.cache = {} @functools.wraps(func) - def _memoized_function(*args): - if not isinstance(args, Hashable): - # Not hashable, so just call the function. - return func(*args) + def _memoized_function(*args, **kwargs): + key = stable_args(*args, **kwargs) - if args not in func.cache: - func.cache[args] = func(*args) - - return func.cache[args] + try: + return func.cache[key] + except KeyError: + ret = func(*args, **kwargs) + func.cache[key] = ret + return ret + except TypeError as e: + # TypeError is raised when indexing into a dict if the key is unhashable. + raise six.raise_from( + UnhashableArguments( + "args + kwargs '{}' was not hashable for function '{}'" + .format(key, func.__name__), + ), + e) return _memoized_function @@ -930,3 +952,7 @@ def nullcontext(*args, **kwargs): TODO: replace with contextlib.nullcontext() if we ever require python 3.7. """ yield + + +class UnhashableArguments(TypeError): + """Raise when an @memoized function receives unhashable arg or kwarg values.""" diff --git a/lib/spack/spack/test/llnl/util/lang.py b/lib/spack/spack/test/llnl/util/lang.py index 8a2a03ee52..3fb6196c3a 100644 --- a/lib/spack/spack/test/llnl/util/lang.py +++ b/lib/spack/spack/test/llnl/util/lang.py @@ -10,7 +10,7 @@ from datetime import datetime, timedelta import pytest import llnl.util.lang -from llnl.util.lang import match_predicate, pretty_date +from llnl.util.lang import match_predicate, memoized, pretty_date, stable_args @pytest.fixture() @@ -205,3 +205,63 @@ def test_key_ordering(): assert hash(a) == hash(a2) assert hash(b) == hash(b) assert hash(b) == hash(b2) + + +@pytest.mark.parametrize( + "args1,kwargs1,args2,kwargs2", + [ + # Ensure tuples passed in args are disambiguated from equivalent kwarg items. + (('a', 3), {}, (), {'a': 3}) + ], +) +def test_unequal_args(args1, kwargs1, args2, kwargs2): + assert stable_args(*args1, **kwargs1) != stable_args(*args2, **kwargs2) + + +@pytest.mark.parametrize( + "args1,kwargs1,args2,kwargs2", + [ + # Ensure that kwargs are stably sorted. + ((), {'a': 3, 'b': 4}, (), {'b': 4, 'a': 3}), + ], +) +def test_equal_args(args1, kwargs1, args2, kwargs2): + assert stable_args(*args1, **kwargs1) == stable_args(*args2, **kwargs2) + + +@pytest.mark.parametrize( + "args, kwargs", + [ + ((1,), {}), + ((), {'a': 3}), + ((1,), {'a': 3}), + ], +) +def test_memoized(args, kwargs): + @memoized + def f(*args, **kwargs): + return 'return-value' + assert f(*args, **kwargs) == 'return-value' + key = stable_args(*args, **kwargs) + assert list(f.cache.keys()) == [key] + assert f.cache[key] == 'return-value' + + +@pytest.mark.parametrize( + "args, kwargs", + [ + (([1],), {}), + ((), {'a': [1]}) + ], +) +def test_memoized_unhashable(args, kwargs): + """Check that an exception is raised clearly""" + @memoized + def f(*args, **kwargs): + return None + with pytest.raises(llnl.util.lang.UnhashableArguments) as exc_info: + f(*args, **kwargs) + exc_msg = str(exc_info.value) + key = stable_args(*args, **kwargs) + assert str(key) in exc_msg + assert "function 'f'" in exc_msg -- cgit v1.2.3-60-g2f50