diff options
-rw-r--r-- | lib/spack/spack/hooks/sbang.py | 142 | ||||
-rw-r--r-- | lib/spack/spack/test/sbang.py | 110 |
2 files changed, 181 insertions, 71 deletions
diff --git a/lib/spack/spack/hooks/sbang.py b/lib/spack/spack/hooks/sbang.py index 8b99ee7873..1c18c46dc6 100644 --- a/lib/spack/spack/hooks/sbang.py +++ b/lib/spack/spack/hooks/sbang.py @@ -6,8 +6,10 @@ import filecmp import os import re +import shutil import stat import sys +import tempfile import llnl.util.filesystem as fs import llnl.util.tty as tty @@ -19,9 +21,14 @@ import spack.store #: Different Linux distributions have different limits, but 127 is the #: smallest among all modern versions. if sys.platform == 'darwin': - shebang_limit = 511 + system_shebang_limit = 511 else: - shebang_limit = 127 + system_shebang_limit = 127 + +#: Spack itself also limits the shebang line to at most 4KB, which should be plenty. +spack_shebang_limit = 4096 + +interpreter_regex = re.compile(b'#![ \t]*?([^ \t\0\n]+)') def sbang_install_path(): @@ -29,10 +36,10 @@ def sbang_install_path(): sbang_root = str(spack.store.unpadded_root) install_path = os.path.join(sbang_root, "bin", "sbang") path_length = len(install_path) - if path_length > shebang_limit: + if path_length > system_shebang_limit: msg = ('Install tree root is too long. Spack cannot patch shebang lines' ' when script path length ({0}) exceeds limit ({1}).\n {2}') - msg = msg.format(path_length, shebang_limit, install_path) + msg = msg.format(path_length, system_shebang_limit, install_path) raise SbangPathError(msg) return install_path @@ -49,71 +56,92 @@ def sbang_shebang_line(): return '#!/bin/sh %s' % sbang_install_path() -def shebang_too_long(path): - """Detects whether a file has a shebang line that is too long.""" - if not os.path.isfile(path): - return False - - with open(path, 'rb') as script: - bytes = script.read(2) - if bytes != b'#!': - return False - - line = bytes + script.readline() - return len(line) > shebang_limit +def get_interpreter(binary_string): + # The interpreter may be preceded with ' ' and \t, is itself any byte that + # follows until the first occurrence of ' ', \t, \0, \n or end of file. 
+ match = interpreter_regex.match(binary_string) + return None if match is None else match.group(1) def filter_shebang(path): - """Adds a second shebang line, using sbang, at the beginning of a file.""" - with open(path, 'rb') as original_file: - original = original_file.read() - if sys.version_info >= (2, 7): - original = original.decode(encoding='UTF-8') - else: - original = original.decode('UTF-8') + """ + Adds a second shebang line, using sbang, at the beginning of a file, if necessary. + Note: Spack imposes a relaxed shebang line limit, meaning that a newline or end of + file must occur before ``spack_shebang_limit`` bytes. If not, the file is not + patched. + """ + with open(path, 'rb') as original: + # If there is no shebang, we shouldn't replace anything. + old_shebang_line = original.read(2) + if old_shebang_line != b'#!': + return False - # This line will be prepended to file - new_sbang_line = '%s\n' % sbang_shebang_line() + # Stop reading after b'\n'. Note that old_shebang_line includes the first b'\n'. + old_shebang_line += original.readline(spack_shebang_limit - 2) - # Skip files that are already using sbang. - if original.startswith(new_sbang_line): - return + # If the shebang line is short, we don't have to do anything. + if len(old_shebang_line) <= system_shebang_limit: + return False - # In the following, newlines have to be excluded in the regular expression - # else any mention of "lua" in the document will lead to spurious matches. + # Whenever we can't find a newline within the maximum number of bytes, we will + # not attempt to rewrite it. In principle we could still get the interpreter if + # only the arguments are truncated, but note that for PHP we need the full line + # since we have to append `?>` to it. Since our shebang limit is already very + # generous, it's unlikely to happen, and it should be fine to ignore. + if ( + len(old_shebang_line) == spack_shebang_limit and + old_shebang_line[-1:] != b'\n' + ): + return False - # Use --! 
instead of #! on second line for lua. - if re.search(r'^#!(/[^/\n]*)*lua\b', original): - original = re.sub(r'^#', '--', original) + # This line will be prepended to file + new_sbang_line = (sbang_shebang_line() + '\n').encode('utf-8') - # Use <?php #! instead of #! on second line for php. - if re.search(r'^#!(/[^/\n]*)*php\b', original): - original = re.sub(r'^#', '<?php #', original) + ' ?>' + # Skip files that are already using sbang. + if old_shebang_line == new_sbang_line: + return False - # Use //! instead of #! on second line for node.js. - if re.search(r'^#!(/[^/\n]*)*node\b', original): - original = re.sub(r'^#', '//', original) + interpreter = get_interpreter(old_shebang_line) - # Change non-writable files to be writable if needed. - saved_mode = None - if not os.access(path, os.W_OK): - st = os.stat(path) - saved_mode = st.st_mode - os.chmod(path, saved_mode | stat.S_IWRITE) + # If there was only whitespace we don't have to do anything. + if not interpreter: + return False - with open(path, 'wb') as new_file: - if sys.version_info >= (2, 7): - new_file.write(new_sbang_line.encode(encoding='UTF-8')) - new_file.write(original.encode(encoding='UTF-8')) + # Store the file permissions, the patched version needs the same. + saved_mode = os.stat(path).st_mode + + # No need to delete since we'll move it and overwrite the original. + patched = tempfile.NamedTemporaryFile('wb', delete=False) + patched.write(new_sbang_line) + + # Note that in Python this does not go out of bounds even if interpreter is a + # short byte array. + # Note: if the interpreter string was encoded with UTF-16, there would have + # been a \0 byte between all characters of lua, node, php; meaning that it would + # lead to truncation of the interpreter. So we don't have to worry about weird + # encodings here, and just looking at bytes is justified. + if interpreter[-4:] == b'/lua' or interpreter[-7:] == b'/luajit': + # Use --! instead of #! on second line for lua. + patched.write(b'--!' 
+ old_shebang_line[2:]) + elif interpreter[-5:] == b'/node': + # Use //! instead of #! on second line for node.js. + patched.write(b'//!' + old_shebang_line[2:]) + elif interpreter[-4:] == b'/php': + # Use <?php #!... ?> instead of #!... on second line for php. + patched.write(b'<?php ' + old_shebang_line + b' ?>') else: - new_file.write(new_sbang_line.encode('UTF-8')) - new_file.write(original.encode('UTF-8')) + patched.write(old_shebang_line) + + # After copying the remainder of the file, we can close the original + shutil.copyfileobj(original, patched) - # Restore original permissions. - if saved_mode is not None: - os.chmod(path, saved_mode) + # And close the temporary file so we can move it. + patched.close() - tty.debug("Patched overlong shebang in %s" % path) + # Overwrite original file with patched file, and keep the original mode + shutil.move(patched.name, path) + os.chmod(path, saved_mode) + return True def filter_shebangs_in_directory(directory, filenames=None): @@ -138,8 +166,8 @@ def filter_shebangs_in_directory(directory, filenames=None): continue # test the file for a long shebang, and filter - if shebang_too_long(path): - filter_shebang(path) + if filter_shebang(path): + tty.debug("Patched overlong shebang in %s" % path) def install_sbang(): diff --git a/lib/spack/spack/test/sbang.py b/lib/spack/spack/test/sbang.py index 4eb3b07a60..e9beebc43d 100644 --- a/lib/spack/spack/test/sbang.py +++ b/lib/spack/spack/test/sbang.py @@ -21,7 +21,7 @@ import spack.paths import spack.store from spack.util.executable import which -too_long = sbang.shebang_limit + 1 +too_long = sbang.system_shebang_limit + 1 short_line = "#!/this/is/short/bin/bash\n" @@ -31,6 +31,10 @@ lua_line = "#!/this/" + ('x' * too_long) + "/is/lua\n" lua_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100) lua_line_patched = "--!/this/" + ('x' * too_long) + "/is/lua\n" +luajit_line = "#!/this/" + ('x' * too_long) + "/is/luajit\n" +luajit_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 
100) +luajit_line_patched = "--!/this/" + ('x' * too_long) + "/is/luajit\n" + node_line = "#!/this/" + ('x' * too_long) + "/is/node\n" node_in_text = ("line\n") * 100 + "lua\n" + ("line\n" * 100) node_line_patched = "//!/this/" + ('x' * too_long) + "/is/node\n" @@ -84,7 +88,7 @@ class ScriptDirectory(object): f.write(last_line) self.make_executable(self.lua_shebang) - # Lua script with long shebang + # Lua occurring in text, not in shebang self.lua_textbang = os.path.join(self.tempdir, 'lua_in_text') with open(self.lua_textbang, 'w') as f: f.write(short_line) @@ -92,6 +96,21 @@ class ScriptDirectory(object): f.write(last_line) self.make_executable(self.lua_textbang) + # Luajit script with long shebang + self.luajit_shebang = os.path.join(self.tempdir, 'luajit') + with open(self.luajit_shebang, 'w') as f: + f.write(luajit_line) + f.write(last_line) + self.make_executable(self.luajit_shebang) + + # Luajit occurring in text, not in shebang + self.luajit_textbang = os.path.join(self.tempdir, 'luajit_in_text') + with open(self.luajit_textbang, 'w') as f: + f.write(short_line) + f.write(luajit_in_text) + f.write(last_line) + self.make_executable(self.luajit_textbang) + # Node script with long shebang self.node_shebang = os.path.join(self.tempdir, 'node') with open(self.node_shebang, 'w') as f: @@ -99,7 +118,7 @@ class ScriptDirectory(object): f.write(last_line) self.make_executable(self.node_shebang) - # Node script with long shebang + # Node occurring in text, not in shebang self.node_textbang = os.path.join(self.tempdir, 'node_in_text') with open(self.node_textbang, 'w') as f: f.write(short_line) @@ -114,7 +133,7 @@ class ScriptDirectory(object): f.write(last_line) self.make_executable(self.php_shebang) - # php script with long shebang + # php occurring in text, not in shebang self.php_textbang = os.path.join(self.tempdir, 'php_in_text') with open(self.php_textbang, 'w') as f: f.write(short_line) @@ -157,16 +176,22 @@ def script_dir(sbang_line): sdir.destroy() -def 
test_shebang_handling(script_dir, sbang_line): - assert sbang.shebang_too_long(script_dir.lua_shebang) - assert sbang.shebang_too_long(script_dir.long_shebang) - assert sbang.shebang_too_long(script_dir.nonexec_long_shebang) +@pytest.mark.parametrize('shebang,interpreter', [ + (b'#!/path/to/interpreter argument\n', b'/path/to/interpreter'), + (b'#! /path/to/interpreter truncated-argum', b'/path/to/interpreter'), + (b'#! \t \t/path/to/interpreter\t \targument', b'/path/to/interpreter'), + (b'#! \t \t /path/to/interpreter', b'/path/to/interpreter'), + (b'#!/path/to/interpreter\0', b'/path/to/interpreter'), + (b'#!/path/to/interpreter multiple args\n', b'/path/to/interpreter'), + (b'#!\0/path/to/interpreter arg\n', None), + (b'#!\n/path/to/interpreter arg\n', None), + (b'#!', None) +]) +def test_shebang_interpreter_regex(shebang, interpreter): + assert sbang.get_interpreter(shebang) == interpreter - assert not sbang.shebang_too_long(script_dir.short_shebang) - assert not sbang.shebang_too_long(script_dir.has_sbang) - assert not sbang.shebang_too_long(script_dir.binary) - assert not sbang.shebang_too_long(script_dir.directory) +def test_shebang_handling(script_dir, sbang_line): sbang.filter_shebangs_in_directory(script_dir.tempdir) # Make sure this is untouched @@ -192,6 +217,12 @@ def test_shebang_handling(script_dir, sbang_line): assert f.readline() == last_line # Make sure this got patched. + with open(script_dir.luajit_shebang, 'r') as f: + assert f.readline() == sbang_line + assert f.readline() == luajit_line_patched + assert f.readline() == last_line + + # Make sure this got patched. 
with open(script_dir.node_shebang, 'r') as f: assert f.readline() == sbang_line assert f.readline() == node_line_patched @@ -199,8 +230,12 @@ def test_shebang_handling(script_dir, sbang_line): assert filecmp.cmp(script_dir.lua_textbang, os.path.join(script_dir.tempdir, 'lua_in_text')) + assert filecmp.cmp(script_dir.luajit_textbang, + os.path.join(script_dir.tempdir, 'luajit_in_text')) assert filecmp.cmp(script_dir.node_textbang, os.path.join(script_dir.tempdir, 'node_in_text')) + assert filecmp.cmp(script_dir.php_textbang, + os.path.join(script_dir.tempdir, 'php_in_text')) # Make sure this is untouched with open(script_dir.has_sbang, 'r') as f: @@ -261,7 +296,7 @@ def test_install_sbang(install_mockery): def test_install_sbang_too_long(tmpdir): root = str(tmpdir) - num_extend = sbang.shebang_limit - len(root) - len('/bin/sbang') + num_extend = sbang.system_shebang_limit - len(root) - len('/bin/sbang') long_path = root while num_extend > 1: add = min(num_extend, 255) @@ -282,7 +317,7 @@ def test_sbang_hook_skips_nonexecutable_blobs(tmpdir): # consisting of invalid UTF-8. The latter is technically not really necessary for # the test, but binary blobs accidentally starting with b'#!' usually do not contain # valid UTF-8, so we also ensure that Spack does not attempt to decode as UTF-8. - contents = b'#!' + b'\x80' * sbang.shebang_limit + contents = b'#!' + b'\x80' * sbang.system_shebang_limit file = str(tmpdir.join('non-executable.sh')) with open(file, 'wb') as f: f.write(contents) @@ -292,3 +327,50 @@ def test_sbang_hook_skips_nonexecutable_blobs(tmpdir): # Make sure there is no sbang shebang. with open(file, 'rb') as f: assert b'sbang' not in f.readline() + + +def test_sbang_handles_non_utf8_files(tmpdir): + # We have an executable with a copyright sign as filename + contents = (b'#!' + b'\xa9' * sbang.system_shebang_limit + + b'\nand another symbol: \xa9') + + # Make sure it's indeed valid latin1 but invalid utf-8. 
+ assert contents.decode('latin1') + with pytest.raises(UnicodeDecodeError): + contents.decode('utf-8') + + # Put it in an executable file + file = str(tmpdir.join('latin1.sh')) + with open(file, 'wb') as f: + f.write(contents) + + # Run sbang + assert sbang.filter_shebang(file) + + with open(file, 'rb') as f: + new_contents = f.read() + + assert contents in new_contents + assert b'sbang' in new_contents + + +@pytest.fixture +def shebang_limits_system_8_spack_16(): + system_limit, sbang.system_shebang_limit = sbang.system_shebang_limit, 8 + spack_limit, sbang.spack_shebang_limit = sbang.spack_shebang_limit, 16 + yield + sbang.system_shebang_limit = system_limit + sbang.spack_shebang_limit = spack_limit + + +def test_shebang_exceeds_spack_shebang_limit(shebang_limits_system_8_spack_16, tmpdir): + """Tests whether shebangs longer than Spack's limit are skipped""" + file = str(tmpdir.join('longer_than_spack_limit.sh')) + with open(file, 'wb') as f: + f.write(b'#!' + b'x' * sbang.spack_shebang_limit) + + # Then Spack shouldn't try to add a shebang + assert not sbang.filter_shebang(file) + + with open(file, 'rb') as f: + assert b'sbang' not in f.read() |