summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMassimiliano Culpo <massimiliano.culpo@gmail.com>2021-05-06 11:53:40 +0200
committerGitHub <noreply@github.com>2021-05-06 11:53:40 +0200
commit219eb09e59358b3fb80b1262d778a85bf69dba3c (patch)
tree833160de64732becdd62caa8728fa0ebb7ce81f4
parent8f1b7016603d508c2d1918e2c3df0c120928f17d (diff)
downloadspack-219eb09e59358b3fb80b1262d778a85bf69dba3c.tar.gz
spack-219eb09e59358b3fb80b1262d778a85bf69dba3c.tar.bz2
spack-219eb09e59358b3fb80b1262d778a85bf69dba3c.tar.xz
spack-219eb09e59358b3fb80b1262d778a85bf69dba3c.zip
Put a module object in sys.modules before executing module code (#23269)
The loading protocol mandates that the the module we are going to import needs to be already in sys.modules before its code is executed, so to prevent unbounded recursions and multiple loading. Loading a module from file exits early if the module is already in sys.modules
-rw-r--r--lib/spack/llnl/util/lang.py21
-rw-r--r--lib/spack/spack/test/llnl/util/lang.py20
2 files changed, 39 insertions, 2 deletions
diff --git a/lib/spack/llnl/util/lang.py b/lib/spack/llnl/util/lang.py
index 4097e4bf6f..b20cd91db4 100644
--- a/lib/spack/llnl/util/lang.py
+++ b/lib/spack/llnl/util/lang.py
@@ -804,6 +804,9 @@ class LazyReference(object):
def load_module_from_file(module_name, module_path):
"""Loads a python module from the path of the corresponding file.
+ If the module is already in ``sys.modules`` it will be returned as
+ is and not reloaded.
+
Args:
module_name (str): namespace where the python module will be loaded,
e.g. ``foo.bar``
@@ -816,12 +819,28 @@ def load_module_from_file(module_name, module_path):
ImportError: when the module can't be loaded
FileNotFoundError: when module_path doesn't exist
"""
+ if module_name in sys.modules:
+ return sys.modules[module_name]
+
+ # This recipe is adapted from https://stackoverflow.com/a/67692/771663
if sys.version_info[0] == 3 and sys.version_info[1] >= 5:
import importlib.util
spec = importlib.util.spec_from_file_location( # novm
module_name, module_path)
module = importlib.util.module_from_spec(spec) # novm
- spec.loader.exec_module(module)
+ # The module object needs to exist in sys.modules before the
+ # loader executes the module code.
+ #
+ # See https://docs.python.org/3/reference/import.html#loading
+ sys.modules[spec.name] = module
+ try:
+ spec.loader.exec_module(module)
+ except BaseException:
+ try:
+ del sys.modules[spec.name]
+ except KeyError:
+ pass
+ raise
elif sys.version_info[0] == 3 and sys.version_info[1] < 5:
import importlib.machinery
loader = importlib.machinery.SourceFileLoader( # novm
diff --git a/lib/spack/spack/test/llnl/util/lang.py b/lib/spack/spack/test/llnl/util/lang.py
index 6555a81825..99829f49d9 100644
--- a/lib/spack/spack/test/llnl/util/lang.py
+++ b/lib/spack/spack/test/llnl/util/lang.py
@@ -6,6 +6,7 @@
import pytest
import os.path
+import sys
from datetime import datetime, timedelta
import llnl.util.lang
@@ -27,7 +28,12 @@ value = 1
path = os.path.join('/usr', 'bin')
"""
m.write(content)
- return str(m)
+
+ yield str(m)
+
+ # Don't leave garbage in the module system
+ if 'foo' in sys.modules:
+ del sys.modules['foo']
def test_pretty_date():
@@ -127,10 +133,22 @@ def test_match_predicate():
def test_load_modules_from_file(module_path):
+ # Check prerequisites
+ assert 'foo' not in sys.modules
+
+ # Check that the module is loaded correctly from file
foo = llnl.util.lang.load_module_from_file('foo', module_path)
+ assert 'foo' in sys.modules
assert foo.value == 1
assert foo.path == os.path.join('/usr', 'bin')
+ # Check that the module is not reloaded a second time on subsequent calls
+ foo.value = 2
+ foo = llnl.util.lang.load_module_from_file('foo', module_path)
+ assert 'foo' in sys.modules
+ assert foo.value == 2
+ assert foo.path == os.path.join('/usr', 'bin')
+
def test_uniq():
assert [1, 2, 3] == llnl.util.lang.uniq([1, 2, 3])