From fa1faa61c415ccab4de88b4f3ed3eb96b0e5ec4c Mon Sep 17 00:00:00 2001 From: Todd Gamblin Date: Tue, 22 Aug 2017 14:01:02 -0700 Subject: SpackCommand uses log_output to capture command output. --- lib/spack/spack/main.py | 40 +++++++++++--------------------- lib/spack/spack/test/cmd/dependencies.py | 8 +++---- lib/spack/spack/test/cmd/dependents.py | 8 +++---- lib/spack/spack/test/cmd/python.py | 2 +- lib/spack/spack/test/cmd/url.py | 14 +++++------ 5 files changed, 29 insertions(+), 43 deletions(-) (limited to 'lib') diff --git a/lib/spack/spack/main.py b/lib/spack/spack/main.py index b17cb3cdc5..2f542604d5 100644 --- a/lib/spack/spack/main.py +++ b/lib/spack/spack/main.py @@ -34,9 +34,10 @@ import os import inspect import pstats import argparse -import tempfile +from six import StringIO import llnl.util.tty as tty +from llnl.util.tty.log import log_output from llnl.util.tty.color import * import spack @@ -367,7 +368,7 @@ class SpackCommand(object): install('-v', 'mpich') Use this to invoke Spack commands directly from Python and check - their stdout and stderr. + their output. """ def __init__(self, command): """Create a new SpackCommand that invokes ``command`` when called.""" @@ -376,9 +377,6 @@ class SpackCommand(object): self.command_name = command self.command = spack.cmd.get_command(command) - self.returncode = None - self.error = None - def __call__(self, *argv, **kwargs): """Invoke this SpackCommand. @@ -389,26 +387,26 @@ class SpackCommand(object): fail_on_error (optional bool): Don't raise an exception on error Returns: - (str, str): output and error as a strings + (str): combined output and error as a string On return, if ``fail_on_error`` is False, return value of comman is set in ``returncode`` property, and the error is set in the ``error`` property. Otherwise, raise an error. """ + # set these before every call to clear them out + self.returncode = None + self.error = None + args, unknown = self.parser.parse_known_args( [self.command_name] + list(argv)) fail_on_error = kwargs.get('fail_on_error', True) - out, err = sys.stdout, sys.stderr - ofd, ofn = tempfile.mkstemp() - efd, efn = tempfile.mkstemp() - + out = StringIO() try: - sys.stdout = open(ofn, 'w') - sys.stderr = open(efn, 'w') - self.returncode = _invoke_spack_command( - self.command, self.parser, args, unknown) + with log_output(out): + self.returncode = _invoke_spack_command( + self.command, self.parser, args, unknown) except SystemExit as e: self.returncode = e.code @@ -418,25 +416,13 @@ class SpackCommand(object): if fail_on_error: raise - finally: - sys.stdout.flush() - sys.stdout.close() - sys.stderr.flush() - sys.stderr.close() - sys.stdout, sys.stderr = out, err - - return_out = open(ofn).read() - return_err = open(efn).read() - os.unlink(ofn) - os.unlink(efn) - if fail_on_error and self.returncode not in (None, 0): raise SpackCommandError( "Command exited with code %d: %s(%s)" % ( self.returncode, self.command_name, ', '.join("'%s'" % a for a in argv))) - return return_out, return_err + return out.getvalue() def _main(command, parser, args, unknown_args): diff --git a/lib/spack/spack/test/cmd/dependencies.py b/lib/spack/spack/test/cmd/dependencies.py index e024fcc2e6..58f778c660 100644 --- a/lib/spack/spack/test/cmd/dependencies.py +++ b/lib/spack/spack/test/cmd/dependencies.py @@ -36,14 +36,14 @@ mpi_deps = ['fake'] def test_immediate_dependencies(builtin_mock): - out, err = dependencies('mpileaks') + out = dependencies('mpileaks') actual = set(re.split(r'\s+', out.strip())) expected = set(['callpath'] + mpis) assert expected == actual def test_transitive_dependencies(builtin_mock): - out, err = dependencies('--transitive', 'mpileaks') + out = dependencies('--transitive', 'mpileaks') actual = set(re.split(r'\s+', out.strip())) expected = set( ['callpath', 'dyninst', 'libdwarf', 'libelf'] + mpis + mpi_deps) @@ -52,7 +52,7 @@ def test_transitive_dependencies(builtin_mock): def test_immediate_installed_dependencies(builtin_mock, database): with color_when(False): - out, err = dependencies('--installed', 'mpileaks^mpich') + out = dependencies('--installed', 'mpileaks^mpich') lines = [l for l in out.strip().split('\n') if not l.startswith('--')] hashes = set([re.split(r'\s+', l)[0] for l in lines]) @@ -65,7 +65,7 @@ def test_immediate_installed_dependencies(builtin_mock, database): def test_transitive_installed_dependencies(builtin_mock, database): with color_when(False): - out, err = dependencies('--installed', '--transitive', 'mpileaks^zmpi') + out = dependencies('--installed', '--transitive', 'mpileaks^zmpi') lines = [l for l in out.strip().split('\n') if not l.startswith('--')] hashes = set([re.split(r'\s+', l)[0] for l in lines]) diff --git a/lib/spack/spack/test/cmd/dependents.py b/lib/spack/spack/test/cmd/dependents.py index 546d6d48c9..c43270a2af 100644 --- a/lib/spack/spack/test/cmd/dependents.py +++ b/lib/spack/spack/test/cmd/dependents.py @@ -33,13 +33,13 @@ dependents = SpackCommand('dependents') def test_immediate_dependents(builtin_mock): - out, err = dependents('libelf') + out = dependents('libelf') actual = set(re.split(r'\s+', out.strip())) assert actual == set(['dyninst', 'libdwarf']) def test_transitive_dependents(builtin_mock): - out, err = dependents('--transitive', 'libelf') + out = dependents('--transitive', 'libelf') actual = set(re.split(r'\s+', out.strip())) assert actual == set( ['callpath', 'dyninst', 'libdwarf', 'mpileaks', 'multivalue_variant', @@ -48,7 +48,7 @@ def test_transitive_dependents(builtin_mock): def test_immediate_installed_dependents(builtin_mock, database): with color_when(False): - out, err = dependents('--installed', 'libelf') + out = dependents('--installed', 'libelf') lines = [l for l in out.strip().split('\n') if not l.startswith('--')] hashes = set([re.split(r'\s+', l)[0] for l in lines]) @@ -64,7 +64,7 @@ def test_immediate_installed_dependents(builtin_mock, database): def test_transitive_installed_dependents(builtin_mock, database): with color_when(False): - out, err = dependents('--installed', '--transitive', 'fake') + out = dependents('--installed', '--transitive', 'fake') lines = [l for l in out.strip().split('\n') if not l.startswith('--')] hashes = set([re.split(r'\s+', l)[0] for l in lines]) diff --git a/lib/spack/spack/test/cmd/python.py b/lib/spack/spack/test/cmd/python.py index 5e3ea83053..db9d9c5e41 100644 --- a/lib/spack/spack/test/cmd/python.py +++ b/lib/spack/spack/test/cmd/python.py @@ -29,5 +29,5 @@ python = SpackCommand('python') def test_python(): - out, err = python('-c', 'import spack; print(spack.spack_version)') + out = python('-c', 'import spack; print(spack.spack_version)') assert out.strip() == str(spack.spack_version) diff --git a/lib/spack/spack/test/cmd/url.py b/lib/spack/spack/test/cmd/url.py index 21f88e928b..ab2d750dee 100644 --- a/lib/spack/spack/test/cmd/url.py +++ b/lib/spack/spack/test/cmd/url.py @@ -83,30 +83,30 @@ def test_url_with_no_version_fails(): def test_url_list(): - out, err = url('list') + out = url('list') total_urls = len(out.split('\n')) # The following two options should not change the number of URLs printed. - out, err = url('list', '--color', '--extrapolation') + out = url('list', '--color', '--extrapolation') colored_urls = len(out.split('\n')) assert colored_urls == total_urls # The following options should print fewer URLs than the default. # If they print the same number of URLs, something is horribly broken. # If they say we missed 0 URLs, something is probably broken too. - out, err = url('list', '--incorrect-name') + out = url('list', '--incorrect-name') incorrect_name_urls = len(out.split('\n')) assert 0 < incorrect_name_urls < total_urls - out, err = url('list', '--incorrect-version') + out = url('list', '--incorrect-version') incorrect_version_urls = len(out.split('\n')) assert 0 < incorrect_version_urls < total_urls - out, err = url('list', '--correct-name') + out = url('list', '--correct-name') correct_name_urls = len(out.split('\n')) assert 0 < correct_name_urls < total_urls - out, err = url('list', '--correct-version') + out = url('list', '--correct-version') correct_version_urls = len(out.split('\n')) assert 0 < correct_version_urls < total_urls @@ -121,7 +121,7 @@ def test_url_summary(): assert 0 < correct_versions <= sum(version_count_dict.values()) <= total_urls # noqa # make sure it agrees with the actual command. - out, err = url('summary') + out = url('summary') out_total_urls = int( re.search(r'Total URLs found:\s*(\d+)', out).group(1)) assert out_total_urls == total_urls -- cgit v1.2.3-60-g2f50