summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Todd Gamblin <tgamblin@llnl.gov> 2022-01-03 14:44:11 -0800
committer Greg Becker <becker33@llnl.gov> 2022-01-12 06:14:18 -0800
commit ec16c2d7c2d3f8f315d3b0502ae1d7356fa393fc (patch)
tree f3ef5bdf3f2d0c227a93ba1ff87277912b966be7
parent e9612696fdf2af66d63640d4fca2639000afbe54 (diff)
download spack-ec16c2d7c2d3f8f315d3b0502ae1d7356fa393fc.tar.gz
spack-ec16c2d7c2d3f8f315d3b0502ae1d7356fa393fc.tar.bz2
spack-ec16c2d7c2d3f8f315d3b0502ae1d7356fa393fc.tar.xz
spack-ec16c2d7c2d3f8f315d3b0502ae1d7356fa393fc.zip
unparser: do a better job of roundtripping strings
Handle complex f-strings. Backport of: https://github.com/python/cpython/commit/a993e901ebe60c38d46ecb31f771d0b4a206828c#
-rw-r--r-- lib/spack/spack/util/unparse/unparser.py 118
1 files changed, 99 insertions, 19 deletions
diff --git a/lib/spack/spack/util/unparse/unparser.py b/lib/spack/spack/util/unparse/unparser.py
index 7d3a7b69f7..8a43f3ac35 100644
--- a/lib/spack/spack/util/unparse/unparser.py
+++ b/lib/spack/spack/util/unparse/unparser.py
@@ -65,6 +65,11 @@ def interleave(inter, f, seq):
f(x)
+# Quote delimiters considered when writing string literals: the
+# one-character quotes, the triple quotes, and both groups concatenated
+# (one-character quotes first).
+_SINGLE_QUOTES = ("'", '"')
+_MULTI_QUOTES = ('"""', "'''")
+_ALL_QUOTES = _SINGLE_QUOTES + _MULTI_QUOTES
+
+
def is_simple_tuple(slice_value):
# when unparsing a non-empty tuple, the parantheses can be safely
# omitted if there aren't any elements that explicitly requires
@@ -86,7 +91,7 @@ class Unparser:
output source code for the abstract syntax; original formatting
is disregarded. """
- def __init__(self, py_ver_consistent=False):
+ def __init__(self, py_ver_consistent=False, _avoid_backslashes=False):
"""Traverse an AST and generate its source.
Arguments:
@@ -118,6 +123,7 @@ class Unparser:
self._indent = 0
self._py_ver_consistent = py_ver_consistent
self._precedences = {}
+ self._avoid_backslashes = _avoid_backslashes
def items_view(self, traverser, items):
"""Traverse and separate the given *items* with a comma and append it to
@@ -596,6 +602,53 @@ class Unparser:
def _AsyncWith(self, t):
# Delegates to _generic_With, passing async_=True.
self._generic_With(t, async_=True)
+ def _str_literal_helper(
+ self, string, quote_types=_ALL_QUOTES, escape_special_whitespace=False
+ ):
+ """Helper for writing string literals, minimizing escapes.
+
+ Arguments:
+ string: text to render as the body of a string literal.
+ quote_types: candidate quote delimiters; callers use the first
+ entry of the returned candidates.
+ escape_special_whitespace: if True, newlines and tabs are written
+ as escape sequences instead of being kept verbatim.
+
+ Returns the tuple (string literal to write, possible quote types).
+ The literal text does not include the surrounding quotes.
+ """
+ def escape_char(c):
+ # \n and \t are non-printable, but we only escape them if
+ # escape_special_whitespace is True
+ if not escape_special_whitespace and c in "\n\t":
+ return c
+ # Always escape backslashes and other non-printable characters
+ if c == "\\" or not c.isprintable():
+ return c.encode("unicode_escape").decode("ascii")
+ return c
+
+ escaped_string = "".join(map(escape_char, string))
+ possible_quotes = quote_types
+ # A newline kept verbatim restricts the candidates to triple quotes
+ if "\n" in escaped_string:
+ possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES]
+ # Discard every delimiter that also occurs inside the text itself
+ possible_quotes = [q for q in possible_quotes if q not in escaped_string]
+ if not possible_quotes:
+ # If there aren't any possible_quotes, fallback to using repr
+ # on the original string. Try to use a quote from quote_types,
+ # e.g., so that we use triple quotes for docstrings.
+ string = repr(string)
+ quote = next((q for q in quote_types if string[0] in q), string[0])
+ # Strip the quotes repr added; the caller writes its own
+ return string[1:-1], [quote]
+ if escaped_string:
+ # Sort so that we prefer '''"''' over """\""""
+ # (the key is False for delimiters whose quote character differs
+ # from the last character of the text, so those sort first)
+ possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1])
+ # If we're using triple quotes and we'd need to escape a final
+ # quote, escape it
+ if possible_quotes[0][0] == escaped_string[-1]:
+ assert len(possible_quotes[0]) == 3
+ escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1]
+ return escaped_string, possible_quotes
+
+ def _write_str_avoiding_backslashes(self, string, quote_types=_ALL_QUOTES):
+ """Write string literal value w/a best effort attempt to avoid backslashes.
+
+ The literal text and the candidate quotes both come from
+ _str_literal_helper; the first candidate quote is written on each
+ side of the text.
+ """
+ string, quote_types = self._str_literal_helper(string, quote_types=quote_types)
+ quote_type = quote_types[0]
+ self.write("{quote_type}{string}{quote_type}".format(
+ quote_type=quote_type,
+ string=string,
+ ))
+
# expr
def _Bytes(self, t):
self.write(repr(t.s))
@@ -625,33 +678,53 @@ class Unparser:
def _JoinedStr(self, t):
# JoinedStr(expr* values)
# Writes the "f" prefix followed by the quoted f-string body; choosing the
# quote style is delegated to _str_literal_helper.
self.write("f")
- string = StringIO()
- self._fstring_JoinedStr(t, string.write)
- # Deviation from `unparse.py`: Try to find an unused quote.
- # This change is made to handle _very_ complex f-strings.
- v = string.getvalue()
- if '\n' in v or '\r' in v:
- quote_types = ["'''", '"""']
- else:
- quote_types = ["'", '"', '"""', "'''"]
- for quote_type in quote_types:
- if quote_type not in v:
- v = "{quote_type}{v}{quote_type}".format(quote_type=quote_type, v=v)
- break
- else:
- v = repr(v)
- self.write(v)
+
+ if self._avoid_backslashes:
+ string = StringIO()
+ self._fstring_JoinedStr(t, string.write)
+ self._write_str_avoiding_backslashes(string.getvalue())
+ return
+
+ # If we don't need to avoid backslashes globally (i.e., we only need
+ # to avoid them inside FormattedValues), it's cosmetically preferred
+ # to use escaped whitespace. That is, it's preferred to use backslashes
+ # for cases like: f"{x}\n". To accomplish this, we keep track of what
+ # in our buffer corresponds to FormattedValues and what corresponds to
+ # Constant parts of the f-string, and allow escapes accordingly.
+ buffer = []
+ for value in t.values:
+ meth = getattr(self, "_fstring_" + type(value).__name__)
+ string = StringIO()
+ meth(value, string.write)
+ # NOTE(review): only ast.Constant parts are flagged as literal text.
+ # A _fstring_Str writer also exists, so parts that arrive as Str
+ # nodes would presumably not get their \n / \t escaped -- confirm
+ # whether that is intended.
+ buffer.append((string.getvalue(), isinstance(value, ast.Constant)))
+ new_buffer = []
+ quote_types = _ALL_QUOTES
+ for value, is_constant in buffer:
+ # Repeatedly narrow down the list of possible quote_types
+ value, quote_types = self._str_literal_helper(
+ value, quote_types=quote_types,
+ escape_special_whitespace=is_constant
+ )
+ new_buffer.append(value)
+ value = "".join(new_buffer)
+ # The first surviving candidate is used as the delimiter
+ quote_type = quote_types[0]
+ self.write("{quote_type}{value}{quote_type}".format(
+ quote_type=quote_type,
+ value=value,
+ ))
def _FormattedValue(self, t):
# FormattedValue(expr value, int? conversion, expr? format_spec)
# NOTE(review): _fstring_JoinedStr iterates ``t.values``, but the field list
# above shows no ``values`` field on a FormattedValue. This call probably
# should go to the FormattedValue-specific ``_fstring_`` writer instead --
# confirm against the rest of the class.
self.write("f")
string = StringIO()
self._fstring_JoinedStr(t, string.write)
- self.write(repr(string.getvalue()))
+ self._write_str_avoiding_backslashes(string.getvalue())
def _fstring_JoinedStr(self, t, write):
for value in t.values:
+ print(" ", value)
meth = getattr(self, "_fstring_" + type(value).__name__)
+ print(meth)
meth(value, write)
def _fstring_Str(self, t, write):
@@ -667,13 +740,18 @@ class Unparser:
write("{")
expr = StringIO()
- unparser = type(self)(py_ver_consistent=self._py_ver_consistent)
+ unparser = type(self)(
+ py_ver_consistent=self._py_ver_consistent,
+ _avoid_backslashes=True,
+ )
unparser.set_precedence(pnext(_Precedence.TEST), t.value)
unparser.visit(t.value, expr)
expr = expr.getvalue().rstrip("\n")
if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {"
+ if "\\" in expr:
+ raise ValueError("Unable to avoid backslash in f-string expression part")
write(expr)
if t.conversion != -1:
conversion = chr(t.conversion)
@@ -707,6 +785,8 @@ class Unparser:
if raw.startswith(r"'\\u"):
raw = "'\\" + raw[3:]
self.write(raw)
+ elif self._avoid_backslashes and isinstance(value, str):
+ self._write_str_avoiding_backslashes(value)
else:
self.write(repr(value))