diff options
-rw-r--r-- | lib/spack/spack/hooks/sbang.py | 7 | ||||
-rw-r--r-- | lib/spack/spack/test/sbang.py | 118 |
2 files changed, 76 insertions, 49 deletions
diff --git a/lib/spack/spack/hooks/sbang.py b/lib/spack/spack/hooks/sbang.py index 17f6ac2528..09691b3f0f 100644 --- a/lib/spack/spack/hooks/sbang.py +++ b/lib/spack/spack/hooks/sbang.py @@ -38,9 +38,12 @@ shebang_limit = 127 def shebang_too_long(path): """Detects whether a file has a shebang line that is too long.""" - with open(path, 'r') as script: + if not os.path.isfile(path): + return False + + with open(path, 'rb') as script: bytes = script.read(2) - if bytes != '#!': + if bytes != b'#!': return False line = bytes + script.readline() diff --git a/lib/spack/spack/test/sbang.py b/lib/spack/spack/test/sbang.py index 12abce7b35..b1e2da3c3b 100644 --- a/lib/spack/spack/test/sbang.py +++ b/lib/spack/spack/test/sbang.py @@ -27,13 +27,16 @@ Test that Spack's shebang filtering works correctly. """ import os import stat -import unittest +import pytest import tempfile import shutil from llnl.util.filesystem import * -from spack.hooks.sbang import filter_shebangs_in_directory + import spack +from spack.hooks.sbang import * +from spack.util.executable import which + short_line = "#!/this/is/short/bin/bash\n" long_line = "#!/this/" + ('x' * 200) + "/is/long\n" @@ -43,14 +46,13 @@ sbang_line = '#!/bin/bash %s/bin/sbang\n' % spack.spack_root last_line = "last!\n" -class SbangTest(unittest.TestCase): - - def setUp(self): +class ScriptDirectory(object): + """Directory full of test scripts to run sbang instrumentation on.""" + def __init__(self): self.tempdir = tempfile.mkdtemp() - # make sure we can ignore non-files - directory = os.path.join(self.tempdir, 'dir') - mkdirp(directory) + self.directory = os.path.join(self.tempdir, 'dir') + mkdirp(self.directory) # Script with short shebang self.short_shebang = os.path.join(self.tempdir, 'short') @@ -71,48 +73,70 @@ class SbangTest(unittest.TestCase): f.write(last_line) # Script already using sbang. - self.has_shebang = os.path.join(self.tempdir, 'shebang') - with open(self.has_shebang, 'w') as f: + self.has_sbang = os.path.join(self.tempdir, 'shebang') + with open(self.has_sbang, 'w') as f: f.write(sbang_line) f.write(long_line) f.write(last_line) - def tearDown(self): + # Fake binary file. + self.binary = os.path.join(self.tempdir, 'binary') + tar = which('tar', required=True) + tar('czf', self.binary, self.has_sbang) + + def destroy(self): shutil.rmtree(self.tempdir, ignore_errors=True) - def test_shebang_handling(self): - filter_shebangs_in_directory(self.tempdir) - - # Make sure this is untouched - with open(self.short_shebang, 'r') as f: - self.assertEqual(f.readline(), short_line) - self.assertEqual(f.readline(), last_line) - - # Make sure this got patched. - with open(self.long_shebang, 'r') as f: - self.assertEqual(f.readline(), sbang_line) - self.assertEqual(f.readline(), long_line) - self.assertEqual(f.readline(), last_line) - - # Make sure this got patched. - with open(self.lua_shebang, 'r') as f: - self.assertEqual(f.readline(), sbang_line) - self.assertEqual(f.readline(), lua_line_patched) - self.assertEqual(f.readline(), last_line) - - # Make sure this is untouched - with open(self.has_shebang, 'r') as f: - self.assertEqual(f.readline(), sbang_line) - self.assertEqual(f.readline(), long_line) - self.assertEqual(f.readline(), last_line) - - def test_shebang_handles_non_writable_files(self): - # make a file non-writable - st = os.stat(self.long_shebang) - not_writable_mode = st.st_mode & ~stat.S_IWRITE - os.chmod(self.long_shebang, not_writable_mode) - - self.test_shebang_handling() - - st = os.stat(self.long_shebang) - self.assertEqual(oct(not_writable_mode), oct(st.st_mode)) + +@pytest.fixture +def script_dir(): + sdir = ScriptDirectory() + yield sdir + sdir.destroy() + + +def test_shebang_handling(script_dir): + assert shebang_too_long(script_dir.lua_shebang) + assert shebang_too_long(script_dir.long_shebang) + + assert not shebang_too_long(script_dir.short_shebang) + assert not shebang_too_long(script_dir.has_sbang) + assert not shebang_too_long(script_dir.binary) + assert not shebang_too_long(script_dir.directory) + + filter_shebangs_in_directory(script_dir.tempdir) + + # Make sure this is untouched + with open(script_dir.short_shebang, 'r') as f: + assert f.readline() == short_line + assert f.readline() == last_line + + # Make sure this got patched. + with open(script_dir.long_shebang, 'r') as f: + assert f.readline() == sbang_line + assert f.readline() == long_line + assert f.readline() == last_line + + # Make sure this got patched. + with open(script_dir.lua_shebang, 'r') as f: + assert f.readline() == sbang_line + assert f.readline() == lua_line_patched + assert f.readline() == last_line + + # Make sure this is untouched + with open(script_dir.has_sbang, 'r') as f: + assert f.readline() == sbang_line + assert f.readline() == long_line + assert f.readline() == last_line + + +def test_shebang_handles_non_writable_files(script_dir): + # make a file non-writable + st = os.stat(script_dir.long_shebang) + not_writable_mode = st.st_mode & ~stat.S_IWRITE + os.chmod(script_dir.long_shebang, not_writable_mode) + + test_shebang_handling(script_dir) + + st = os.stat(script_dir.long_shebang) + assert oct(not_writable_mode) == oct(st.st_mode) |