summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorHarmen Stoppels <harmenstoppels@gmail.com>2021-10-27 11:59:10 +0200
committerGitHub <noreply@github.com>2021-10-27 02:59:10 -0700
commite04b172eb00c72bdadcdba20d45169738828943a (patch)
treead75d7dce7b1c211affec45dad74ad847f428832 /lib
parent2fd87046cd49a75d60aabaf3e502373e88dbc187 (diff)
downloadspack-e04b172eb00c72bdadcdba20d45169738828943a.tar.gz
spack-e04b172eb00c72bdadcdba20d45169738828943a.tar.bz2
spack-e04b172eb00c72bdadcdba20d45169738828943a.tar.xz
spack-e04b172eb00c72bdadcdba20d45169738828943a.zip
Allow non-UTF-8 encoding in sbang hook (#26793)
Currently Spack reads full files containing shebangs to memory as strings, meaning Spack would have to guess their encoding. Currently Spack has a fixed guess of UTF-8. This is unnecessary, since e.g. the Linux kernel does not assume an encoding on paths at all, it's just bytes and some delimiters on the byte level. This commit does the following: 1. Shebangs are treated as bytes, so that e.g. latin1 encoded files do not throw UnicodeEncoding errors, and adds a test for this. 2. No more bytes than necessary are read to memory, we only have to read until the first newline, and from there on we an copy the file byte by bytes instead of decoding and re-encoding text. 3. We cap the number of bytes read to 4096, if no newline is found before that, we don't attempt to patch it. 4. Add support for luajit too. This should make Spack both more efficient and usable for non-UTF8 files.
Diffstat (limited to 'lib')
-rw-r--r--lib/spack/spack/hooks/sbang.py142
-rw-r--r--lib/spack/spack/test/sbang.py110
2 files changed, 181 insertions, 71 deletions
diff --git a/lib/spack/spack/hooks/sbang.py b/lib/spack/spack/hooks/sbang.py
index 8b99ee7873..1c18c46dc6 100644
--- a/lib/spack/spack/hooks/sbang.py
+++ b/lib/spack/spack/hooks/sbang.py
@@ -6,8 +6,10 @@
import filecmp
import os
import re
+import shutil
import stat
import sys
+import tempfile
import llnl.util.filesystem as fs
import llnl.util.tty as tty
@@ -19,9 +21,14 @@ import spack.store
#: Different Linux distributions have different limits, but 127 is the
#: smallest among all modern versions.
if sys.platform == 'darwin':
- shebang_limit = 511
+ system_shebang_limit = 511
else:
- shebang_limit = 127
+ system_shebang_limit = 127
+
+#: Spack itself also limits the shebang line to at most 4KB, which should be plenty.
+spack_shebang_limit = 4096
+
+interpreter_regex = re.compile(b'#![ \t]*?([^ \t\0\n]+)')
def sbang_install_path():
@@ -29,10 +36,10 @@ def sbang_install_path():
sbang_root = str(spack.store.unpadded_root)
install_path = os.path.join(sbang_root, "bin", "sbang")
path_length = len(install_path)
- if path_length > shebang_limit:
+ if path_length > system_shebang_limit:
msg = ('Install tree root is too long. Spack cannot patch shebang lines'
' when script path length ({0}) exceeds limit ({1}).\n {2}')
- msg = msg.format(path_length, shebang_limit, install_path)
+ msg = msg.format(path_length, system_shebang_limit, install_path)
raise SbangPathError(msg)
return install_path
@@ -49,71 +56,92 @@ def sbang_shebang_line():
return '#!/bin/sh %s' % sbang_install_path()
-def shebang_too_long(path):
- """Detects whether a file has a shebang line that is too long."""
- if not os.path.isfile(path):
- return False
-
- with open(path, 'rb') as script:
- bytes = script.read(2)
- if bytes != b'#!':
- return False
-
- line = bytes + script.readline()
- return len(line) > shebang_limit
+def get_interpreter(binary_string):
+ # The interpreter may be preceded with ' ' and \t, is itself any byte that
+ # follows until the first occurrence of ' ', \t, \0, \n or end of file.
+ match = interpreter_regex.match(binary_string)
+ return None if match is None else match.group(1)
def filter_shebang(path):
- """Adds a second shebang line, using sbang, at the beginning of a file."""
- with open(path, 'rb') as original_file:
- original = original_file.read()
- if sys.version_info >= (2, 7):
- original = original.decode(encoding='UTF-8')
- else:
- original = original.decode('UTF-8')
+ """
+ Adds a second shebang line, using sbang, at the beginning of a file, if necessary.
+ Note: Spack imposes a relaxed shebang line limit, meaning that a newline or end of
+ file must occur before ``spack_shebang_limit`` bytes. If not, the file is not
+ patched.
+ """
+ with open(path, 'rb') as original:
+ # If there is no shebang, we shouldn't replace anything.
+ old_shebang_line = original.read(2)
+ if old_shebang_line != b'#!':
+ return False
- # This line will be prepended to file
- new_sbang_line = '%s\n' % sbang_shebang_line()
+ # Stop reading after b'\n'. Note that old_shebang_line includes the first b'\n'.
+ old_shebang_line += original.readline(spack_shebang_limit - 2)
- # Skip files that are already using sbang.
- if original.startswith(new_sbang_line):
- return
+ # If the shebang line is short, we don't have to do anything.
+ if len(old_shebang_line) <= system_shebang_limit:
+ return False
- # In the following, newlines have to be excluded in the regular expression
- # else any mention of "lua" in the document will lead to spurious matches.
+ # Whenever we can't find a newline within the maximum number of bytes, we will
+ # not attempt to rewrite it. In principle we could still get the interpreter if
+ # only the arguments are truncated, but note that for PHP we need the full line
+ # since we have to append `?>` to it. Since our shebang limit is already very
+ # generous, it's unlikely to happen, and it should be fine to ignore.
+ if (
+ len(old_shebang_line) == spack_shebang_limit and
+ old_shebang_line[-1] != b'\n'
+ ):
+ return False
- # Use --! instead of #! on second line for lua.
- if re.search(r'^#!(/[^/\n]*)*lua\b', original):
- original = re.sub(r'^#', '--', original)
+ # This line will be prepended to file
+ new_sbang_line = (sbang_shebang_line() + '\n').encode('utf-8')
- # Use <?php #! instead of #! on second line for php.
- if re.search(r'^#!(/[^/\n]*)*php\b', original):
- original = re.sub(r'^#', '<?php #', original) + ' ?>'
+ # Skip files that are already using sbang.
+ if old_shebang_line == new_sbang_line:
+ return
- # Use //! instead of #! on second line for node.js.
- if re.search(r'^#!(/[^/\n]*)*node\b', original):
- original = re.sub(r'^#', '//', original)
+ interpreter = get_interpreter(old_shebang_line)
- # Change non-writable files to be writable if needed.
- saved_mode = None
- if not os.access(path, os.W_OK):
- st = os.stat(path)
- saved_mode = st.st_mode
- os.chmod(path, saved_mode | stat.S_IWRITE)
+ # If there was only whitespace we don't have to do anything.
+ if not interpreter:
+ return False
- with open(path, 'wb') as new_file:
- if sys.version_info >= (2, 7):
- new_file.write(new_sbang_line.encode(encoding='UTF-8'))
- new_file.write(original.encode(encoding='UTF-8'))
+ # Store the file permissions, the patched version needs the same.
+ saved_mode = os.stat(path).st_mode
+
+ # No need to delete since we'll move it and overwrite the original.
+ patched = tempfile.NamedTemporaryFile('wb', delete=False)
+ patched.write(new_sbang_line)
+
+ # Note that in Python this does not go out of bounds even if interpreter is a
+ # short byte array.
+ # Note: if the interpreter string was encoded with UTF-16, there would have
+ # been a \0 byte between all characters of lua, node, php; meaning that it would
+ # lead to truncation of the interpreter. So we don't have to worry about weird
+ # encodings here, and just looking at bytes is justified.
+ if interpreter[-4:] == b'/lua' or interpreter[-7:] == b'/luajit':
+ # Use --! instead of #! on second line for lua.
+ patched.write(b'--!' + old_shebang_line[2:])
+ elif interpreter[-5:] == b'/node':
+ # Use //! instead of #! on second line for node.js.
+ patched.write(b'//!' + old_shebang_line[2:])
+ elif interpreter[-4:] == b'/php':
+ # Use <?php #!... ?> instead of #!... on second line for php.
+ patched.write(b'<?php ' + old_shebang_line + b' ?>')
else:
- new_file.write(new_sbang_line.encode('UTF-8'))
- new_file.write(original.encode('UTF-8'))
+ patched.write(old_shebang_line)
+
+ # After copying the remainder of the file, we can close the original
+ shutil.copyfileobj(original, patched)
- # Restore original permissions.
- if saved_mode is not None:
- os.chmod(path, saved_mode)
+ # And close the temporary file so we can move it.
+ patched.close()
- tty.debug("Patched overlong shebang in %s" % path)
+ # Overwrite original file with patched file, and keep the original mode
+ shutil.move(patched.name, path)
+ os.chmod(path, saved_mode)
+ return True
def filter_shebangs_in_directory(directory, filenames=None):
@@ -138,8 +166,8 @@ def filter_shebangs_in_directory(directory, filenames=None):
continue
# test the file for a long shebang, and filter
- if shebang_too_long(path):
- filter_shebang(path)
+ if filter_shebang(path):
+ tty.debug("Patched overlong shebang in %s" % path)
def install_sbang():
diff --git a/lib/spack/spack/test/sbang.py b/lib/spack/spack/test/sbang.py
index 4eb3b07a60..e9beebc43d 100644
--- a/lib/spack/spack/test/sbang.py
+++ b/lib/spack/spack/test/sbang.py
@@ -21,7 +21,7 @@ import spack.paths
import spack.store
from spack.util.executable import which
-too_long = sbang.shebang_limit + 1
+too_long = sbang.system_shebang_limit + 1
short_line = "#!/this/is/short/bin/bash\n"
@@ -31,6 +31,10 @@ lua_line = "#!/this/" + ('x' * too_long) + "/is/lua\n"
lua_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100)
lua_line_patched = "--!/this/" + ('x' * too_long) + "/is/lua\n"
+luajit_line = "#!/this/" + ('x' * too_long) + "/is/luajit\n"
+luajit_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100)
+luajit_line_patched = "--!/this/" + ('x' * too_long) + "/is/luajit\n"
+
node_line = "#!/this/" + ('x' * too_long) + "/is/node\n"
node_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100)
node_line_patched = "//!/this/" + ('x' * too_long) + "/is/node\n"
@@ -84,7 +88,7 @@ class ScriptDirectory(object):
f.write(last_line)
self.make_executable(self.lua_shebang)
- # Lua script with long shebang
+ # Lua occurring in text, not in shebang
self.lua_textbang = os.path.join(self.tempdir, 'lua_in_text')
with open(self.lua_textbang, 'w') as f:
f.write(short_line)
@@ -92,6 +96,21 @@ class ScriptDirectory(object):
f.write(last_line)
self.make_executable(self.lua_textbang)
+ # Luajit script with long shebang
+ self.luajit_shebang = os.path.join(self.tempdir, 'luajit')
+ with open(self.luajit_shebang, 'w') as f:
+ f.write(luajit_line)
+ f.write(last_line)
+ self.make_executable(self.luajit_shebang)
+
+ # Luajit occuring in text, not in shebang
+ self.luajit_textbang = os.path.join(self.tempdir, 'luajit_in_text')
+ with open(self.luajit_textbang, 'w') as f:
+ f.write(short_line)
+ f.write(luajit_in_text)
+ f.write(last_line)
+ self.make_executable(self.luajit_textbang)
+
# Node script with long shebang
self.node_shebang = os.path.join(self.tempdir, 'node')
with open(self.node_shebang, 'w') as f:
@@ -99,7 +118,7 @@ class ScriptDirectory(object):
f.write(last_line)
self.make_executable(self.node_shebang)
- # Node script with long shebang
+ # Node occuring in text, not in shebang
self.node_textbang = os.path.join(self.tempdir, 'node_in_text')
with open(self.node_textbang, 'w') as f:
f.write(short_line)
@@ -114,7 +133,7 @@ class ScriptDirectory(object):
f.write(last_line)
self.make_executable(self.php_shebang)
- # php script with long shebang
+ # php occuring in text, not in shebang
self.php_textbang = os.path.join(self.tempdir, 'php_in_text')
with open(self.php_textbang, 'w') as f:
f.write(short_line)
@@ -157,16 +176,22 @@ def script_dir(sbang_line):
sdir.destroy()
-def test_shebang_handling(script_dir, sbang_line):
- assert sbang.shebang_too_long(script_dir.lua_shebang)
- assert sbang.shebang_too_long(script_dir.long_shebang)
- assert sbang.shebang_too_long(script_dir.nonexec_long_shebang)
+@pytest.mark.parametrize('shebang,interpreter', [
+ (b'#!/path/to/interpreter argument\n', b'/path/to/interpreter'),
+ (b'#! /path/to/interpreter truncated-argum', b'/path/to/interpreter'),
+ (b'#! \t \t/path/to/interpreter\t \targument', b'/path/to/interpreter'),
+ (b'#! \t \t /path/to/interpreter', b'/path/to/interpreter'),
+ (b'#!/path/to/interpreter\0', b'/path/to/interpreter'),
+ (b'#!/path/to/interpreter multiple args\n', b'/path/to/interpreter'),
+ (b'#!\0/path/to/interpreter arg\n', None),
+ (b'#!\n/path/to/interpreter arg\n', None),
+ (b'#!', None)
+])
+def test_shebang_interpreter_regex(shebang, interpreter):
+ sbang.get_interpreter(shebang) == interpreter
- assert not sbang.shebang_too_long(script_dir.short_shebang)
- assert not sbang.shebang_too_long(script_dir.has_sbang)
- assert not sbang.shebang_too_long(script_dir.binary)
- assert not sbang.shebang_too_long(script_dir.directory)
+def test_shebang_handling(script_dir, sbang_line):
sbang.filter_shebangs_in_directory(script_dir.tempdir)
# Make sure this is untouched
@@ -192,6 +217,12 @@ def test_shebang_handling(script_dir, sbang_line):
assert f.readline() == last_line
# Make sure this got patched.
+ with open(script_dir.luajit_shebang, 'r') as f:
+ assert f.readline() == sbang_line
+ assert f.readline() == luajit_line_patched
+ assert f.readline() == last_line
+
+ # Make sure this got patched.
with open(script_dir.node_shebang, 'r') as f:
assert f.readline() == sbang_line
assert f.readline() == node_line_patched
@@ -199,8 +230,12 @@ def test_shebang_handling(script_dir, sbang_line):
assert filecmp.cmp(script_dir.lua_textbang,
os.path.join(script_dir.tempdir, 'lua_in_text'))
+ assert filecmp.cmp(script_dir.luajit_textbang,
+ os.path.join(script_dir.tempdir, 'luajit_in_text'))
assert filecmp.cmp(script_dir.node_textbang,
os.path.join(script_dir.tempdir, 'node_in_text'))
+ assert filecmp.cmp(script_dir.php_textbang,
+ os.path.join(script_dir.tempdir, 'php_in_text'))
# Make sure this is untouched
with open(script_dir.has_sbang, 'r') as f:
@@ -261,7 +296,7 @@ def test_install_sbang(install_mockery):
def test_install_sbang_too_long(tmpdir):
root = str(tmpdir)
- num_extend = sbang.shebang_limit - len(root) - len('/bin/sbang')
+ num_extend = sbang.system_shebang_limit - len(root) - len('/bin/sbang')
long_path = root
while num_extend > 1:
add = min(num_extend, 255)
@@ -282,7 +317,7 @@ def test_sbang_hook_skips_nonexecutable_blobs(tmpdir):
# consisting of invalid UTF-8. The latter is technically not really necessary for
# the test, but binary blobs accidentally starting with b'#!' usually do not contain
# valid UTF-8, so we also ensure that Spack does not attempt to decode as UTF-8.
- contents = b'#!' + b'\x80' * sbang.shebang_limit
+ contents = b'#!' + b'\x80' * sbang.system_shebang_limit
file = str(tmpdir.join('non-executable.sh'))
with open(file, 'wb') as f:
f.write(contents)
@@ -292,3 +327,50 @@ def test_sbang_hook_skips_nonexecutable_blobs(tmpdir):
# Make sure there is no sbang shebang.
with open(file, 'rb') as f:
assert b'sbang' not in f.readline()
+
+
+def test_sbang_handles_non_utf8_files(tmpdir):
+ # We have an executable with a copyright sign as filename
+ contents = (b'#!' + b'\xa9' * sbang.system_shebang_limit +
+ b'\nand another symbol: \xa9')
+
+ # Make sure it's indeed valid latin1 but invalid utf-8.
+ assert contents.decode('latin1')
+ with pytest.raises(UnicodeDecodeError):
+ contents.decode('utf-8')
+
+ # Put it in an executable file
+ file = str(tmpdir.join('latin1.sh'))
+ with open(file, 'wb') as f:
+ f.write(contents)
+
+ # Run sbang
+ assert sbang.filter_shebang(file)
+
+ with open(file, 'rb') as f:
+ new_contents = f.read()
+
+ assert contents in new_contents
+ assert b'sbang' in new_contents
+
+
+@pytest.fixture
+def shebang_limits_system_8_spack_16():
+ system_limit, sbang.system_shebang_limit = sbang.system_shebang_limit, 8
+ spack_limit, sbang.spack_shebang_limit = sbang.spack_shebang_limit, 16
+ yield
+ sbang.system_shebang_limit = system_limit
+ sbang.spack_shebang_limit = spack_limit
+
+
+def test_shebang_exceeds_spack_shebang_limit(shebang_limits_system_8_spack_16, tmpdir):
+ """Tests whether shebangs longer than Spack's limit are skipped"""
+ file = str(tmpdir.join('longer_than_spack_limit.sh'))
+ with open(file, 'wb') as f:
+ f.write(b'#!' + b'x' * sbang.spack_shebang_limit)
+
+ # Then Spack shouldn't try to add a shebang
+ assert not sbang.filter_shebang(file)
+
+ with open(file, 'rb') as f:
+ assert b'sbang' not in f.read()