summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/spack/llnl/util/lang.py21
-rw-r--r--lib/spack/spack/test/llnl/util/lang.py20
2 files changed, 39 insertions, 2 deletions
diff --git a/lib/spack/llnl/util/lang.py b/lib/spack/llnl/util/lang.py
index 4097e4bf6f..b20cd91db4 100644
--- a/lib/spack/llnl/util/lang.py
+++ b/lib/spack/llnl/util/lang.py
@@ -804,6 +804,9 @@ class LazyReference(object):
def load_module_from_file(module_name, module_path):
"""Loads a python module from the path of the corresponding file.
+ If the module is already in ``sys.modules`` it will be returned as
+ is and not reloaded.
+
Args:
module_name (str): namespace where the python module will be loaded,
e.g. ``foo.bar``
@@ -816,12 +819,28 @@ def load_module_from_file(module_name, module_path):
ImportError: when the module can't be loaded
FileNotFoundError: when module_path doesn't exist
"""
+ if module_name in sys.modules:
+ return sys.modules[module_name]
+
+ # This recipe is adapted from https://stackoverflow.com/a/67692/771663
if sys.version_info[0] == 3 and sys.version_info[1] >= 5:
import importlib.util
spec = importlib.util.spec_from_file_location( # novm
module_name, module_path)
module = importlib.util.module_from_spec(spec) # novm
- spec.loader.exec_module(module)
+ # The module object needs to exist in sys.modules before the
+ # loader executes the module code.
+ #
+ # See https://docs.python.org/3/reference/import.html#loading
+ sys.modules[spec.name] = module
+ try:
+ spec.loader.exec_module(module)
+ except BaseException:
+ try:
+ del sys.modules[spec.name]
+ except KeyError:
+ pass
+ raise
elif sys.version_info[0] == 3 and sys.version_info[1] < 5:
import importlib.machinery
loader = importlib.machinery.SourceFileLoader( # novm
diff --git a/lib/spack/spack/test/llnl/util/lang.py b/lib/spack/spack/test/llnl/util/lang.py
index 6555a81825..99829f49d9 100644
--- a/lib/spack/spack/test/llnl/util/lang.py
+++ b/lib/spack/spack/test/llnl/util/lang.py
@@ -6,6 +6,7 @@
import pytest
import os.path
+import sys
from datetime import datetime, timedelta
import llnl.util.lang
@@ -27,7 +28,12 @@ value = 1
path = os.path.join('/usr', 'bin')
"""
m.write(content)
- return str(m)
+
+ yield str(m)
+
+ # Don't leave garbage in the module system
+ if 'foo' in sys.modules:
+ del sys.modules['foo']
def test_pretty_date():
@@ -127,10 +133,22 @@ def test_match_predicate():
def test_load_modules_from_file(module_path):
+ # Check prerequisites
+ assert 'foo' not in sys.modules
+
+ # Check that the module is loaded correctly from file
foo = llnl.util.lang.load_module_from_file('foo', module_path)
+ assert 'foo' in sys.modules
assert foo.value == 1
assert foo.path == os.path.join('/usr', 'bin')
+ # Check that the module is not reloaded a second time on subsequent calls
+ foo.value = 2
+ foo = llnl.util.lang.load_module_from_file('foo', module_path)
+ assert 'foo' in sys.modules
+ assert foo.value == 2
+ assert foo.path == os.path.join('/usr', 'bin')
+
def test_uniq():
assert [1, 2, 3] == llnl.util.lang.uniq([1, 2, 3])