diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/spack/llnl/util/lang.py | 21 | ||||
-rw-r--r-- | lib/spack/spack/test/llnl/util/lang.py | 20 |
2 files changed, 39 insertions, 2 deletions
diff --git a/lib/spack/llnl/util/lang.py b/lib/spack/llnl/util/lang.py index 4097e4bf6f..b20cd91db4 100644 --- a/lib/spack/llnl/util/lang.py +++ b/lib/spack/llnl/util/lang.py @@ -804,6 +804,9 @@ class LazyReference(object): def load_module_from_file(module_name, module_path): """Loads a python module from the path of the corresponding file. + If the module is already in ``sys.modules`` it will be returned as + is and not reloaded. + Args: module_name (str): namespace where the python module will be loaded, e.g. ``foo.bar`` @@ -816,12 +819,28 @@ def load_module_from_file(module_name, module_path): ImportError: when the module can't be loaded FileNotFoundError: when module_path doesn't exist """ + if module_name in sys.modules: + return sys.modules[module_name] + + # This recipe is adapted from https://stackoverflow.com/a/67692/771663 if sys.version_info[0] == 3 and sys.version_info[1] >= 5: import importlib.util spec = importlib.util.spec_from_file_location( # novm module_name, module_path) module = importlib.util.module_from_spec(spec) # novm - spec.loader.exec_module(module) + # The module object needs to exist in sys.modules before the + # loader executes the module code. + # + # See https://docs.python.org/3/reference/import.html#loading + sys.modules[spec.name] = module + try: + spec.loader.exec_module(module) + except BaseException: + try: + del sys.modules[spec.name] + except KeyError: + pass + raise elif sys.version_info[0] == 3 and sys.version_info[1] < 5: import importlib.machinery loader = importlib.machinery.SourceFileLoader( # novm diff --git a/lib/spack/spack/test/llnl/util/lang.py b/lib/spack/spack/test/llnl/util/lang.py index 6555a81825..99829f49d9 100644 --- a/lib/spack/spack/test/llnl/util/lang.py +++ b/lib/spack/spack/test/llnl/util/lang.py @@ -6,6 +6,7 @@ import pytest import os.path +import sys from datetime import datetime, timedelta import llnl.util.lang @@ -27,7 +28,12 @@ value = 1 path = os.path.join('/usr', 'bin') """ m.write(content) - return str(m) + + yield str(m) + + # Don't leave garbage in the module system + if 'foo' in sys.modules: + del sys.modules['foo'] def test_pretty_date(): @@ -127,10 +133,22 @@ def test_match_predicate(): def test_load_modules_from_file(module_path): + # Check prerequisites + assert 'foo' not in sys.modules + + # Check that the module is loaded correctly from file foo = llnl.util.lang.load_module_from_file('foo', module_path) + assert 'foo' in sys.modules assert foo.value == 1 assert foo.path == os.path.join('/usr', 'bin') + # Check that the module is not reloaded a second time on subsequent calls + foo.value = 2 + foo = llnl.util.lang.load_module_from_file('foo', module_path) + assert 'foo' in sys.modules + assert foo.value == 2 + assert foo.path == os.path.join('/usr', 'bin') + def test_uniq(): assert [1, 2, 3] == llnl.util.lang.uniq([1, 2, 3]) |