From ec16c2d7c2d3f8f315d3b0502ae1d7356fa393fc Mon Sep 17 00:00:00 2001 From: Todd Gamblin Date: Mon, 3 Jan 2022 14:44:11 -0800 Subject: unparser: do a better job of roundtripping strings Handle complex f-strings. Backport of: https://github.com/python/cpython/commit/a993e901ebe60c38d46ecb31f771d0b4a206828c# --- lib/spack/spack/util/unparse/unparser.py | 118 ++++++++++++++++++++++++++----- 1 file changed, 99 insertions(+), 19 deletions(-) diff --git a/lib/spack/spack/util/unparse/unparser.py b/lib/spack/spack/util/unparse/unparser.py index 7d3a7b69f7..8a43f3ac35 100644 --- a/lib/spack/spack/util/unparse/unparser.py +++ b/lib/spack/spack/util/unparse/unparser.py @@ -65,6 +65,11 @@ def interleave(inter, f, seq): f(x) +_SINGLE_QUOTES = ("'", '"') +_MULTI_QUOTES = ('"""', "'''") +_ALL_QUOTES = _SINGLE_QUOTES + _MULTI_QUOTES + + def is_simple_tuple(slice_value): # when unparsing a non-empty tuple, the parantheses can be safely # omitted if there aren't any elements that explicitly requires @@ -86,7 +91,7 @@ class Unparser: output source code for the abstract syntax; original formatting is disregarded. """ - def __init__(self, py_ver_consistent=False): + def __init__(self, py_ver_consistent=False, _avoid_backslashes=False): """Traverse an AST and generate its source. Arguments: @@ -118,6 +123,7 @@ class Unparser: self._indent = 0 self._py_ver_consistent = py_ver_consistent self._precedences = {} + self._avoid_backslashes = _avoid_backslashes def items_view(self, traverser, items): """Traverse and separate the given *items* with a comma and append it to @@ -596,6 +602,53 @@ class Unparser: def _AsyncWith(self, t): self._generic_With(t, async_=True) + def _str_literal_helper( + self, string, quote_types=_ALL_QUOTES, escape_special_whitespace=False + ): + """Helper for writing string literals, minimizing escapes. + Returns the tuple (string literal to write, possible quote types). + """ + def escape_char(c): + # \n and \t are non-printable, but we only escape them if + # escape_special_whitespace is True + if not escape_special_whitespace and c in "\n\t": + return c + # Always escape backslashes and other non-printable characters + if c == "\\" or not c.isprintable(): + return c.encode("unicode_escape").decode("ascii") + return c + + escaped_string = "".join(map(escape_char, string)) + possible_quotes = quote_types + if "\n" in escaped_string: + possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES] + possible_quotes = [q for q in possible_quotes if q not in escaped_string] + if not possible_quotes: + # If there aren't any possible_quotes, fallback to using repr + # on the original string. Try to use a quote from quote_types, + # e.g., so that we use triple quotes for docstrings. + string = repr(string) + quote = next((q for q in quote_types if string[0] in q), string[0]) + return string[1:-1], [quote] + if escaped_string: + # Sort so that we prefer '''"''' over """\"""" + possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1]) + # If we're using triple quotes and we'd need to escape a final + # quote, escape it + if possible_quotes[0][0] == escaped_string[-1]: + assert len(possible_quotes[0]) == 3 + escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1] + return escaped_string, possible_quotes + + def _write_str_avoiding_backslashes(self, string, quote_types=_ALL_QUOTES): + """Write string literal value w/a best effort attempt to avoid backslashes.""" + string, quote_types = self._str_literal_helper(string, quote_types=quote_types) + quote_type = quote_types[0] + self.write("{quote_type}{string}{quote_type}".format( + quote_type=quote_type, + string=string, + )) + # expr def _Bytes(self, t): self.write(repr(t.s)) @@ -625,33 +678,53 @@ class Unparser: def _JoinedStr(self, t): # JoinedStr(expr* values) self.write("f") - string = StringIO() - self._fstring_JoinedStr(t, string.write) - # Deviation from `unparse.py`: Try to find an unused quote. - # This change is made to handle _very_ complex f-strings. - v = string.getvalue() - if '\n' in v or '\r' in v: - quote_types = ["'''", '"""'] - else: - quote_types = ["'", '"', '"""', "'''"] - for quote_type in quote_types: - if quote_type not in v: - v = "{quote_type}{v}{quote_type}".format(quote_type=quote_type, v=v) - break - else: - v = repr(v) - self.write(v) + + if self._avoid_backslashes: + string = StringIO() + self._fstring_JoinedStr(t, string.write) + self._write_str_avoiding_backslashes(string.getvalue()) + return + + # If we don't need to avoid backslashes globally (i.e., we only need + # to avoid them inside FormattedValues), it's cosmetically preferred + # to use escaped whitespace. That is, it's preferred to use backslashes + # for cases like: f"{x}\n". To accomplish this, we keep track of what + # in our buffer corresponds to FormattedValues and what corresponds to + # Constant parts of the f-string, and allow escapes accordingly. + buffer = [] + for value in t.values: + meth = getattr(self, "_fstring_" + type(value).__name__) + string = StringIO() + meth(value, string.write) + buffer.append((string.getvalue(), isinstance(value, ast.Constant))) + new_buffer = [] + quote_types = _ALL_QUOTES + for value, is_constant in buffer: + # Repeatedly narrow down the list of possible quote_types + value, quote_types = self._str_literal_helper( + value, quote_types=quote_types, + escape_special_whitespace=is_constant + ) + new_buffer.append(value) + value = "".join(new_buffer) + quote_type = quote_types[0] + self.write("{quote_type}{value}{quote_type}".format( + quote_type=quote_type, + value=value, + )) def _FormattedValue(self, t): # FormattedValue(expr value, int? conversion, expr? format_spec) self.write("f") string = StringIO() self._fstring_JoinedStr(t, string.write) - self.write(repr(string.getvalue())) + self._write_str_avoiding_backslashes(string.getvalue()) def _fstring_JoinedStr(self, t, write): for value in t.values: + print(" ", value) meth = getattr(self, "_fstring_" + type(value).__name__) + print(meth) meth(value, write) def _fstring_Str(self, t, write): @@ -667,13 +740,18 @@ class Unparser: write("{") expr = StringIO() - unparser = type(self)(py_ver_consistent=self._py_ver_consistent) + unparser = type(self)( + py_ver_consistent=self._py_ver_consistent, + _avoid_backslashes=True, + ) unparser.set_precedence(pnext(_Precedence.TEST), t.value) unparser.visit(t.value, expr) expr = expr.getvalue().rstrip("\n") if expr.startswith("{"): write(" ") # Separate pair of opening brackets as "{ {" + if "\\" in expr: + raise ValueError("Unable to avoid backslash in f-string expression part") write(expr) if t.conversion != -1: conversion = chr(t.conversion) @@ -707,6 +785,8 @@ class Unparser: if raw.startswith(r"'\\u"): raw = "'\\" + raw[3:] self.write(raw) + elif self._avoid_backslashes and isinstance(value, str): + self._write_str_avoiding_backslashes(value) else: self.write(repr(value)) -- cgit v1.2.3-70-g09d2