diff options
author | Todd Gamblin <tgamblin@llnl.gov> | 2024-11-29 23:21:07 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-30 08:21:07 +0100 |
commit | 2a2d1989c1a2d8241822c7696ce2926a22949444 (patch) | |
tree | f888836a716309c0fbde3aa0a74c99d8be3e629e | |
parent | c6e292f55fc424e048a61347033c9b8fa533d8f5 (diff) | |
download | spack-2a2d1989c1a2d8241822c7696ce2926a22949444.tar.gz spack-2a2d1989c1a2d8241822c7696ce2926a22949444.tar.bz2 spack-2a2d1989c1a2d8241822c7696ce2926a22949444.tar.xz spack-2a2d1989c1a2d8241822c7696ce2926a22949444.zip |
`version_types`: clean up type hierarchy and add annotations (#47781)
In preparation for adding `when=` to `version()`, I'm cleaning up the types in
`version_types` and making sure the methods here pass `mypy` checks. This started as an
attempt to use `ConcreteVersion` outside of `spack.version` and grew into a larger type
refactor.
The hierarchy now looks like this:
* `VersionType`
* `ConcreteVersion`
* `StandardVersion`
* `GitVersion`
* `ClosedOpenRange`
* `VersionList`
Note that the top-level thing can't easily be `Version` as that is a method and it
returns only `ConcreteVersion` right now. I *could* do something fancy with `__new__` to
make `Version` a synonym for the `ConcreteVersion` constructor, which would allow it to
be used as a type. I could also do something similar with `VersionRange` but not sure if
it's worth it just to make these into types.
There are still some places where I think `GitVersion` might not be handled properly,
but I have not attempted to fix those here.
- [x] Add a top-level `VersionType` class that all version types extend from
- [x] Define and document common methods and rich comparisons on `VersionType`
- [x] Replace complicated `Union` types with `VersionType` and `ConcreteVersion` as needed
- [x] Annotate most methods (skipping `__getitem__` and friends as the typing is a pain)
- [x] Fix up the `VersionList` constructor a bit
- [x] Add cases to methods that weren't handling all `VersionType`s
- [x] Rework some places to clarify typing for `mypy`
- [x] Simplify / optimize _next_version
- [x] Make StandardVersion.string a property to enable lazy comparison
Signed-off-by: Todd Gamblin <tgamblin@llnl.gov>
-rw-r--r-- | lib/spack/spack/test/versions.py | 3 | ||||
-rw-r--r-- | lib/spack/spack/version/__init__.py | 28 | ||||
-rw-r--r-- | lib/spack/spack/version/version_types.py | 498 |
3 files changed, 328 insertions, 201 deletions
diff --git a/lib/spack/spack/test/versions.py b/lib/spack/spack/test/versions.py index 734ba4ca4a..4c5081e8d1 100644 --- a/lib/spack/spack/test/versions.py +++ b/lib/spack/spack/test/versions.py @@ -607,6 +607,9 @@ def test_stringify_version(version_str): v.string = None assert str(v) == version_str + v.string = None + assert v.string == version_str + def test_len(): a = Version("1.2.3.4") diff --git a/lib/spack/spack/version/__init__.py b/lib/spack/spack/version/__init__.py index 18d739ae0c..a94f641cff 100644 --- a/lib/spack/spack/version/__init__.py +++ b/lib/spack/spack/version/__init__.py @@ -25,11 +25,13 @@ from .common import ( ) from .version_types import ( ClosedOpenRange, + ConcreteVersion, GitVersion, StandardVersion, Version, VersionList, VersionRange, + VersionType, _next_version, _prev_version, from_string, @@ -40,21 +42,23 @@ from .version_types import ( any_version: VersionList = VersionList([":"]) __all__ = [ - "Version", - "VersionRange", - "ver", - "from_string", - "is_git_version", - "infinity_versions", - "_prev_version", - "_next_version", - "VersionList", "ClosedOpenRange", - "StandardVersion", + "ConcreteVersion", + "EmptyRangeError", "GitVersion", - "VersionError", + "StandardVersion", + "Version", "VersionChecksumError", + "VersionError", + "VersionList", "VersionLookupError", - "EmptyRangeError", + "VersionRange", + "VersionType", + "_next_version", + "_prev_version", "any_version", + "from_string", + "infinity_versions", + "is_git_version", + "ver", ] diff --git a/lib/spack/spack/version/version_types.py b/lib/spack/spack/version/version_types.py index f35192192d..4c7a9606f4 100644 --- a/lib/spack/spack/version/version_types.py +++ b/lib/spack/spack/version/version_types.py @@ -3,10 +3,9 @@ # # SPDX-License-Identifier: (Apache-2.0 OR MIT) -import numbers import re from bisect import bisect_left -from typing import List, Optional, Tuple, Union +from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union from spack.util.spack_yaml import syaml_dict @@ -32,26 +31,44 @@ SEGMENT_REGEX = re.compile(r"(?:(?P<num>[0-9]+)|(?P<str>[a-zA-Z]+))(?P<sep>[_.-] class VersionStrComponent: + """Internal representation of the string (non-integer) components of Spack versions. + + Versions comprise string and integer components (see ``SEGMENT_REGEX`` above). + + This represents a string component, which is either some component consisting only + of alphabetical characters, *or* a special "infinity version" like ``main``, + ``develop``, ``master``, etc. + + For speed, Spack versions are designed to map to Python tuples, so that we can use + Python's fast lexicographic tuple comparison on them. ``VersionStrComponent`` is + designed to work as a component in these version tuples, and as such must compare + directly with ``int`` or other ``VersionStrComponent`` objects. + + """ + __slots__ = ["data"] - def __init__(self, data): + data: Union[int, str] + + def __init__(self, data: Union[int, str]): # int for infinity index, str for literal. - self.data: Union[int, str] = data + self.data = data @staticmethod - def from_string(string): + def from_string(string: str) -> "VersionStrComponent": + value: Union[int, str] = string if len(string) >= iv_min_len: try: - string = infinity_versions.index(string) + value = infinity_versions.index(string) except ValueError: pass - return VersionStrComponent(string) + return VersionStrComponent(value) - def __hash__(self): + def __hash__(self) -> int: return hash(self.data) - def __str__(self): + def __str__(self) -> str: return ( ("infinity" if self.data >= len(infinity_versions) else infinity_versions[self.data]) if isinstance(self.data, int) @@ -61,38 +78,61 @@ class VersionStrComponent: def __repr__(self) -> str: return f'VersionStrComponent("{self}")' - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, VersionStrComponent) and self.data == other.data - def __lt__(self, other): - lhs_inf = isinstance(self.data, int) + # ignore typing for certain parts of these methods b/c a) they are performance-critical, and + # b) mypy isn't smart enough to figure out that if l_inf and r_inf are the same, comparing + # self.data and other.data is type safe. + def __lt__(self, other: object) -> bool: + l_inf = isinstance(self.data, int) if isinstance(other, int): - return not lhs_inf - rhs_inf = isinstance(other.data, int) - return (not lhs_inf and rhs_inf) if lhs_inf ^ rhs_inf else self.data < other.data + return not l_inf + r_inf = isinstance(other.data, int) # type: ignore + return (not l_inf and r_inf) if l_inf ^ r_inf else self.data < other.data # type: ignore - def __le__(self, other): - return self < other or self == other - - def __gt__(self, other): - lhs_inf = isinstance(self.data, int) + def __gt__(self, other: object) -> bool: + l_inf = isinstance(self.data, int) if isinstance(other, int): - return lhs_inf - rhs_inf = isinstance(other.data, int) - return (lhs_inf and not rhs_inf) if lhs_inf ^ rhs_inf else self.data > other.data + return l_inf + r_inf = isinstance(other.data, int) # type: ignore + return (l_inf and not r_inf) if l_inf ^ r_inf else self.data > other.data # type: ignore - def __ge__(self, other): + def __le__(self, other: object) -> bool: + return self < other or self == other + + def __ge__(self, other: object) -> bool: return self > other or self == other -def parse_string_components(string: str) -> Tuple[tuple, tuple]: +# Tuple types that make up the internal representation of StandardVersion. +# We use Tuples so that Python can quickly compare versions. + +#: Version components are integers for numeric parts, VersionStrComponents for string parts. +VersionComponentTuple = Tuple[Union[int, VersionStrComponent], ...] + +#: A Prerelease identifier is a constant for alpha/beta/rc/final and one optional number. +#: Most versions will have this set to ``(FINAL,)``. Prereleases will have some other +#: initial constant followed by a number, e.g. ``(RC, 1)``. +PrereleaseTuple = Tuple[int, ...] + +#: Actual version tuple, including the split version number itself and the prerelease, +#: all represented as tuples. +VersionTuple = Tuple[VersionComponentTuple, PrereleaseTuple] + +#: Separators from a parsed version. +SeparatorTuple = Tuple[str, ...] + + +def parse_string_components(string: str) -> Tuple[VersionTuple, SeparatorTuple]: + """Parse a string into a ``VersionTuple`` and ``SeparatorTuple``.""" string = string.strip() if string and not VALID_VERSION.match(string): raise ValueError("Bad characters in version string: %s" % string) segments = SEGMENT_REGEX.findall(string) - separators = tuple(m[2] for m in segments) + separators: Tuple[str] = tuple(m[2] for m in segments) prerelease: Tuple[int, ...] # <version>(alpha|beta|rc)<number> @@ -109,63 +149,150 @@ def parse_string_components(string: str) -> Tuple[tuple, tuple]: else: prerelease = (FINAL,) - release = tuple(int(m[0]) if m[0] else VersionStrComponent.from_string(m[1]) for m in segments) + release: VersionComponentTuple = tuple( + int(m[0]) if m[0] else VersionStrComponent.from_string(m[1]) for m in segments + ) return (release, prerelease), separators -class ConcreteVersion: - pass +class VersionType: + """Base type for all versions in Spack (ranges, lists, regular versions, and git versions). + + Versions in Spack behave like sets, and support some basic set operations. There are + four subclasses of ``VersionType``: + + * ``StandardVersion``: a single, concrete version, e.g. 3.4.5 or 5.4b0. + * ``GitVersion``: subclass of ``StandardVersion`` for handling git repositories. + * ``ClosedOpenRange``: an inclusive version range, closed or open, e.g. ``3.0:5.0``, + ``3.0:``, or ``:5.0`` + * ``VersionList``: An ordered list of any of the above types. + + Notably, when Spack parses a version, it's always a range *unless* specified with + ``@=`` to make it concrete. + + """ + def intersection(self, other: "VersionType") -> "VersionType": + """Any versions contained in both self and other, or empty VersionList if no overlap.""" + raise NotImplementedError -def _stringify_version(versions: Tuple[tuple, tuple], separators: tuple) -> str: + def intersects(self, other: "VersionType") -> bool: + """Whether self and other overlap.""" + raise NotImplementedError + + def overlaps(self, other: "VersionType") -> bool: + """Whether self and other overlap (same as ``intersects()``).""" + return self.intersects(other) + + def satisfies(self, other: "VersionType") -> bool: + """Whether self is entirely contained in other.""" + raise NotImplementedError + + def union(self, other: "VersionType") -> "VersionType": + """Return a VersionType containing self and other.""" + raise NotImplementedError + + # We can use SupportsRichComparisonT in Python 3.8 or later, but alas in 3.6 we need + # to write all the operators out + def __eq__(self, other: object) -> bool: + raise NotImplementedError + + def __lt__(self, other: object) -> bool: + raise NotImplementedError + + def __gt__(self, other: object) -> bool: + raise NotImplementedError + + def __ge__(self, other: object) -> bool: + raise NotImplementedError + + def __le__(self, other: object) -> bool: + raise NotImplementedError + + def __hash__(self) -> int: + raise NotImplementedError + + +class ConcreteVersion(VersionType): + """Base type for versions that represents a single (non-range or list) version.""" + + +def _stringify_version(versions: VersionTuple, separators: Tuple[str, ...]) -> str: + """Create a string representation from version components.""" release, prerelease = versions - string = "" - for i in range(len(release)): - string += f"{release[i]}{separators[i]}" + + components = [f"{rel}{sep}" for rel, sep in zip(release, separators)] if prerelease[0] != FINAL: - string += f"{PRERELEASE_TO_STRING[prerelease[0]]}{separators[len(release)]}" - if len(prerelease) > 1: - string += str(prerelease[1]) - return string + components.append(PRERELEASE_TO_STRING[prerelease[0]]) + if len(prerelease) > 1: + components.append(separators[len(release)]) + components.append(str(prerelease[1])) + + return "".join(components) class StandardVersion(ConcreteVersion): """Class to represent versions""" - __slots__ = ["version", "string", "separators"] + __slots__ = ["version", "_string", "separators"] + + _string: str + version: VersionTuple + separators: Tuple[str, ...] + + def __init__(self, string: str, version: VersionTuple, separators: Tuple[str, ...]): + """Create a StandardVersion from a string and parsed version components. - def __init__(self, string: Optional[str], version: Tuple[tuple, tuple], separators: tuple): - self.string = string + Arguments: + string: The original version string, or ``""`` if the it is not available. + version: A tuple as returned by ``parse_string_components()``. Contains two tuples: + one with alpha or numeric components and another with prerelease components. + separators: separators parsed from the original version string. + + If constructed with ``string=""``, the string will be lazily constructed from components + when ``str()`` is called. + """ + self._string = string self.version = version self.separators = separators @staticmethod - def from_string(string: str): + def from_string(string: str) -> "StandardVersion": return StandardVersion(string, *parse_string_components(string)) @staticmethod - def typemin(): + def typemin() -> "StandardVersion": return _STANDARD_VERSION_TYPEMIN @staticmethod - def typemax(): + def typemax() -> "StandardVersion": return _STANDARD_VERSION_TYPEMAX - def __bool__(self): + @property + def string(self) -> str: + if not self._string: + self._string = _stringify_version(self.version, self.separators) + return self._string + + @string.setter + def string(self, string) -> None: + self._string = string + + def __bool__(self) -> bool: return True - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, StandardVersion): return self.version == other.version return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: if isinstance(other, StandardVersion): return self.version != other.version return True - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if isinstance(other, StandardVersion): return self.version < other.version if isinstance(other, ClosedOpenRange): @@ -173,7 +300,7 @@ class StandardVersion(ConcreteVersion): return self <= other.lo return NotImplemented - def __le__(self, other): + def __le__(self, other: object) -> bool: if isinstance(other, StandardVersion): return self.version <= other.version if isinstance(other, ClosedOpenRange): @@ -181,7 +308,7 @@ class StandardVersion(ConcreteVersion): return self <= other.lo return NotImplemented - def __ge__(self, other): + def __ge__(self, other: object) -> bool: if isinstance(other, StandardVersion): return self.version >= other.version if isinstance(other, ClosedOpenRange): @@ -189,25 +316,25 @@ class StandardVersion(ConcreteVersion): return self > other.lo return NotImplemented - def __gt__(self, other): + def __gt__(self, other: object) -> bool: if isinstance(other, StandardVersion): return self.version > other.version if isinstance(other, ClosedOpenRange): return self > other.lo return NotImplemented - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.version[0]) - def __len__(self): + def __len__(self) -> int: return len(self.version[0]) - def __getitem__(self, idx): + def __getitem__(self, idx: Union[int, slice]): cls = type(self) release = self.version[0] - if isinstance(idx, numbers.Integral): + if isinstance(idx, int): return release[idx] elif isinstance(idx, slice): @@ -220,45 +347,38 @@ class StandardVersion(ConcreteVersion): if string_arg: string_arg.pop() # We don't need the last separator - string_arg = "".join(string_arg) - return cls.from_string(string_arg) + return cls.from_string("".join(string_arg)) else: return StandardVersion.from_string("") - message = "{cls.__name__} indices must be integers" - raise TypeError(message.format(cls=cls)) + raise TypeError(f"{cls.__name__} indices must be integers or slices") - def __str__(self): - return self.string or _stringify_version(self.version, self.separators) + def __str__(self) -> str: + return self.string def __repr__(self) -> str: # Print indirect repr through Version(...) return f'Version("{str(self)}")' - def __hash__(self): + def __hash__(self) -> int: # If this is a final release, do not hash the prerelease part for backward compat. return hash(self.version if self.is_prerelease() else self.version[0]) - def __contains__(rhs, lhs): + def __contains__(rhs, lhs) -> bool: # We should probably get rid of `x in y` for versions, since # versions still have a dual interpretation as singleton sets # or elements. x in y should be: is the lhs-element in the # rhs-set. Instead this function also does subset checks. - if isinstance(lhs, (StandardVersion, ClosedOpenRange, VersionList)): + if isinstance(lhs, VersionType): return lhs.satisfies(rhs) - raise ValueError(lhs) + raise TypeError(f"'in' not supported for instances of {type(lhs)}") - def intersects(self, other: Union["StandardVersion", "GitVersion", "ClosedOpenRange"]) -> bool: + def intersects(self, other: VersionType) -> bool: if isinstance(other, StandardVersion): return self == other return other.intersects(self) - def overlaps(self, other) -> bool: - return self.intersects(other) - - def satisfies( - self, other: Union["ClosedOpenRange", "StandardVersion", "GitVersion", "VersionList"] - ) -> bool: + def satisfies(self, other: VersionType) -> bool: if isinstance(other, GitVersion): return False @@ -271,19 +391,19 @@ class StandardVersion(ConcreteVersion): if isinstance(other, VersionList): return other.intersects(self) - return NotImplemented + raise NotImplementedError - def union(self, other: Union["ClosedOpenRange", "StandardVersion"]): + def union(self, other: VersionType) -> VersionType: if isinstance(other, StandardVersion): return self if self == other else VersionList([self, other]) return other.union(self) - def intersection(self, other: Union["ClosedOpenRange", "StandardVersion"]): + def intersection(self, other: VersionType) -> VersionType: if isinstance(other, StandardVersion): return self if self == other else VersionList() return other.intersection(self) - def isdevelop(self): + def isdevelop(self) -> bool: """Triggers on the special case of the `@develop-like` version.""" return any( isinstance(p, VersionStrComponent) and isinstance(p.data, int) for p in self.version[0] @@ -304,7 +424,7 @@ class StandardVersion(ConcreteVersion): return ".".join(str(v) for v in numeric) @property - def dotted(self): + def dotted(self) -> "StandardVersion": """The dotted representation of the version. Example: @@ -318,7 +438,7 @@ class StandardVersion(ConcreteVersion): return type(self).from_string(self.string.replace("-", ".").replace("_", ".")) @property - def underscored(self): + def underscored(self) -> "StandardVersion": """The underscored representation of the version. Example: @@ -333,7 +453,7 @@ class StandardVersion(ConcreteVersion): return type(self).from_string(self.string.replace(".", "_").replace("-", "_")) @property - def dashed(self): + def dashed(self) -> "StandardVersion": """The dashed representation of the version. Example: @@ -347,7 +467,7 @@ class StandardVersion(ConcreteVersion): return type(self).from_string(self.string.replace(".", "-").replace("_", "-")) @property - def joined(self): + def joined(self) -> "StandardVersion": """The joined representation of the version. Example: @@ -362,7 +482,7 @@ class StandardVersion(ConcreteVersion): self.string.replace(".", "").replace("-", "").replace("_", "") ) - def up_to(self, index): + def up_to(self, index: int) -> "StandardVersion": """The version up to the specified component. Examples: @@ -482,7 +602,7 @@ class GitVersion(ConcreteVersion): ) return self._ref_version - def intersects(self, other): + def intersects(self, other: VersionType) -> bool: # For concrete things intersects = satisfies = equality if isinstance(other, GitVersion): return self == other @@ -492,19 +612,14 @@ class GitVersion(ConcreteVersion): return self.ref_version.intersects(other) if isinstance(other, VersionList): return any(self.intersects(rhs) for rhs in other) - raise ValueError(f"Unexpected type {type(other)}") + raise TypeError(f"'intersects()' not supported for instances of {type(other)}") - def intersection(self, other): + def intersection(self, other: VersionType) -> VersionType: if isinstance(other, ConcreteVersion): return self if self == other else VersionList() return other.intersection(self) - def overlaps(self, other) -> bool: - return self.intersects(other) - - def satisfies( - self, other: Union["GitVersion", StandardVersion, "ClosedOpenRange", "VersionList"] - ): + def satisfies(self, other: VersionType) -> bool: # Concrete versions mean we have to do an equality check if isinstance(other, GitVersion): return self == other @@ -514,9 +629,9 @@ class GitVersion(ConcreteVersion): return self.ref_version.satisfies(other) if isinstance(other, VersionList): return any(self.satisfies(rhs) for rhs in other) - raise ValueError(f"Unexpected type {type(other)}") + raise TypeError(f"'satisfies()' not supported for instances of {type(other)}") - def __str__(self): + def __str__(self) -> str: s = f"git.{self.ref}" if self.has_git_prefix else self.ref # Note: the solver actually depends on str(...) to produce the effective version. # So when a lookup is attached, we require the resolved version to be printed. @@ -534,7 +649,7 @@ class GitVersion(ConcreteVersion): def __bool__(self): return True - def __eq__(self, other): + def __eq__(self, other: object) -> bool: # GitVersion cannot be equal to StandardVersion, otherwise == is not transitive return ( isinstance(other, GitVersion) @@ -542,10 +657,10 @@ class GitVersion(ConcreteVersion): and self.ref_version == other.ref_version ) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if isinstance(other, GitVersion): return (self.ref_version, self.ref) < (other.ref_version, other.ref) if isinstance(other, StandardVersion): @@ -553,9 +668,9 @@ class GitVersion(ConcreteVersion): return self.ref_version < other if isinstance(other, ClosedOpenRange): return self.ref_version < other - raise ValueError(f"Unexpected type {type(other)}") + raise TypeError(f"'<' not supported between instances of {type(self)} and {type(other)}") - def __le__(self, other): + def __le__(self, other: object) -> bool: if isinstance(other, GitVersion): return (self.ref_version, self.ref) <= (other.ref_version, other.ref) if isinstance(other, StandardVersion): @@ -564,9 +679,9 @@ class GitVersion(ConcreteVersion): if isinstance(other, ClosedOpenRange): # Equality is not a thing return self.ref_version < other - raise ValueError(f"Unexpected type {type(other)}") + raise TypeError(f"'<=' not supported between instances of {type(self)} and {type(other)}") - def __ge__(self, other): + def __ge__(self, other: object) -> bool: if isinstance(other, GitVersion): return (self.ref_version, self.ref) >= (other.ref_version, other.ref) if isinstance(other, StandardVersion): @@ -574,9 +689,9 @@ class GitVersion(ConcreteVersion): return self.ref_version >= other if isinstance(other, ClosedOpenRange): return self.ref_version > other - raise ValueError(f"Unexpected type {type(other)}") + raise TypeError(f"'>=' not supported between instances of {type(self)} and {type(other)}") - def __gt__(self, other): + def __gt__(self, other: object) -> bool: if isinstance(other, GitVersion): return (self.ref_version, self.ref) > (other.ref_version, other.ref) if isinstance(other, StandardVersion): @@ -584,14 +699,14 @@ class GitVersion(ConcreteVersion): return self.ref_version >= other if isinstance(other, ClosedOpenRange): return self.ref_version > other - raise ValueError(f"Unexpected type {type(other)}") + raise TypeError(f"'>' not supported between instances of {type(self)} and {type(other)}") def __hash__(self): # hashing should not cause version lookup return hash(self.ref) - def __contains__(self, other): - raise Exception("Not implemented yet") + def __contains__(self, other: object) -> bool: + raise NotImplementedError @property def ref_lookup(self): @@ -649,7 +764,7 @@ class GitVersion(ConcreteVersion): return self.ref_version.up_to(index) -class ClosedOpenRange: +class ClosedOpenRange(VersionType): def __init__(self, lo: StandardVersion, hi: StandardVersion): if hi < lo: raise EmptyRangeError(f"{lo}..{hi} is an empty range") @@ -657,14 +772,14 @@ class ClosedOpenRange: self.hi: StandardVersion = hi @classmethod - def from_version_range(cls, lo: StandardVersion, hi: StandardVersion): + def from_version_range(cls, lo: StandardVersion, hi: StandardVersion) -> "ClosedOpenRange": """Construct ClosedOpenRange from lo:hi range.""" try: return ClosedOpenRange(lo, _next_version(hi)) except EmptyRangeError as e: raise EmptyRangeError(f"{lo}:{hi} is an empty range") from e - def __str__(self): + def __str__(self) -> str: # This simplifies 3.1:<3.2 to 3.1:3.1 to 3.1 # 3:3 -> 3 hi_prev = _prev_version(self.hi) @@ -726,9 +841,9 @@ class ClosedOpenRange: def __contains__(rhs, lhs): if isinstance(lhs, (ConcreteVersion, ClosedOpenRange, VersionList)): return lhs.satisfies(rhs) - raise ValueError(f"Unexpected type {type(lhs)}") + raise TypeError(f"'in' not supported between instances of {type(rhs)} and {type(lhs)}") - def intersects(self, other: Union[ConcreteVersion, "ClosedOpenRange", "VersionList"]): + def intersects(self, other: VersionType) -> bool: if isinstance(other, StandardVersion): return self.lo <= other < self.hi if isinstance(other, GitVersion): @@ -737,23 +852,18 @@ class ClosedOpenRange: return (self.lo < other.hi) and (other.lo < self.hi) if isinstance(other, VersionList): return any(self.intersects(rhs) for rhs in other) - raise ValueError(f"Unexpected type {type(other)}") + raise TypeError(f"'intersects' not supported for instances of {type(other)}") - def satisfies(self, other: Union["ClosedOpenRange", ConcreteVersion, "VersionList"]): + def satisfies(self, other: VersionType) -> bool: if isinstance(other, ConcreteVersion): return False if isinstance(other, ClosedOpenRange): return not (self.lo < other.lo or other.hi < self.hi) if isinstance(other, VersionList): return any(self.satisfies(rhs) for rhs in other) - raise ValueError(other) - - def overlaps(self, other: Union["ClosedOpenRange", ConcreteVersion, "VersionList"]) -> bool: - return self.intersects(other) + raise TypeError(f"'satisfies()' not supported for instances of {type(other)}") - def _union_if_not_disjoint( - self, other: Union["ClosedOpenRange", ConcreteVersion] - ) -> Optional["ClosedOpenRange"]: + def _union_if_not_disjoint(self, other: VersionType) -> Optional["ClosedOpenRange"]: """Same as union, but returns None when the union is not connected. This function is not implemented for version lists as right-hand side, as that makes little sense.""" if isinstance(other, StandardVersion): @@ -770,9 +880,9 @@ class ClosedOpenRange: else None ) - raise TypeError(f"Unexpected type {type(other)}") + raise TypeError(f"'union()' not supported for instances of {type(other)}") - def union(self, other: Union["ClosedOpenRange", ConcreteVersion, "VersionList"]): + def union(self, other: VersionType) -> VersionType: if isinstance(other, VersionList): v = other.copy() v.add(self) @@ -781,35 +891,51 @@ class ClosedOpenRange: result = self._union_if_not_disjoint(other) return result if result is not None else VersionList([self, other]) - def intersection(self, other: Union["ClosedOpenRange", ConcreteVersion]): + def intersection(self, other: VersionType) -> VersionType: # range - version -> singleton or nothing. + if isinstance(other, ClosedOpenRange): + # range - range -> range or nothing. + max_lo = max(self.lo, other.lo) + min_hi = min(self.hi, other.hi) + return ClosedOpenRange(max_lo, min_hi) if max_lo < min_hi else VersionList() + if isinstance(other, ConcreteVersion): return other if self.intersects(other) else VersionList() - # range - range -> range or nothing. - max_lo = max(self.lo, other.lo) - min_hi = min(self.hi, other.hi) - return ClosedOpenRange(max_lo, min_hi) if max_lo < min_hi else VersionList() + raise TypeError(f"'intersection()' not supported for instances of {type(other)}") -class VersionList: +class VersionList(VersionType): """Sorted, non-redundant list of Version and ClosedOpenRange elements.""" - def __init__(self, vlist=None): - self.versions: List[Union[StandardVersion, GitVersion, ClosedOpenRange]] = [] + versions: List[VersionType] + + def __init__(self, vlist: Optional[Union[str, VersionType, Iterable]] = None): if vlist is None: - pass + self.versions = [] + elif isinstance(vlist, str): vlist = from_string(vlist) if isinstance(vlist, VersionList): self.versions = vlist.versions else: self.versions = [vlist] - else: + + elif isinstance(vlist, (ConcreteVersion, ClosedOpenRange)): + self.versions = [vlist] + + elif isinstance(vlist, VersionList): + self.versions = vlist[:] + + elif isinstance(vlist, Iterable): + self.versions = [] for v in vlist: self.add(ver(v)) - def add(self, item: Union[StandardVersion, GitVersion, ClosedOpenRange, "VersionList"]): + else: + raise TypeError(f"Cannot construct VersionList from {type(vlist)}") + + def add(self, item: VersionType) -> None: if isinstance(item, (StandardVersion, GitVersion)): i = bisect_left(self, item) # Only insert when prev and next are not intersected. @@ -865,7 +991,7 @@ class VersionList: return v.lo return None - def copy(self): + def copy(self) -> "VersionList": return VersionList(self) def lowest(self) -> Optional[StandardVersion]: @@ -889,7 +1015,7 @@ class VersionList: """Get the preferred (latest) version in the list.""" return self.highest_numeric() or self.highest() - def satisfies(self, other) -> bool: + def satisfies(self, other: VersionType) -> bool: # This exploits the fact that version lists are "reduced" and normalized, so we can # never have a list like [1:3, 2:4] since that would be normalized to [1:4] if isinstance(other, VersionList): @@ -898,9 +1024,9 @@ class VersionList: if isinstance(other, (ConcreteVersion, ClosedOpenRange)): return all(lhs.satisfies(other) for lhs in self) - raise ValueError(f"Unsupported type {type(other)}") + raise TypeError(f"'satisfies()' not supported for instances of {type(other)}") - def intersects(self, other): + def intersects(self, other: VersionType) -> bool: if isinstance(other, VersionList): s = o = 0 while s < len(self) and o < len(other): @@ -915,19 +1041,16 @@ class VersionList: if isinstance(other, (ClosedOpenRange, StandardVersion)): return any(v.intersects(other) for v in self) - raise ValueError(f"Unsupported type {type(other)}") + raise TypeError(f"'intersects()' not supported for instances of {type(other)}") - def overlaps(self, other) -> bool: - return self.intersects(other) - - def to_dict(self): + def to_dict(self) -> Dict: """Generate human-readable dict for YAML.""" if self.concrete: return syaml_dict([("version", str(self[0]))]) return syaml_dict([("versions", [str(v) for v in self])]) @staticmethod - def from_dict(dictionary): + def from_dict(dictionary) -> "VersionList": """Parse dict from to_dict.""" if "versions" in dictionary: return VersionList(dictionary["versions"]) @@ -935,27 +1058,29 @@ class VersionList: return VersionList([Version(dictionary["version"])]) raise ValueError("Dict must have 'version' or 'versions' in it.") - def update(self, other: "VersionList"): - for v in other.versions: - self.add(v) + def update(self, other: "VersionList") -> None: + self.add(other) - def union(self, other: "VersionList"): + def union(self, other: VersionType) -> VersionType: result = self.copy() - result.update(other) + result.add(other) return result - def intersection(self, other: "VersionList") -> "VersionList": + def intersection(self, other: VersionType) -> "VersionList": result = VersionList() - for lhs, rhs in ((self, other), (other, self)): - for x in lhs: - i = bisect_left(rhs.versions, x) - if i > 0: - result.add(rhs[i - 1].intersection(x)) - if i < len(rhs): - result.add(rhs[i].intersection(x)) - return result + if isinstance(other, VersionList): + for lhs, rhs in ((self, other), (other, self)): + for x in lhs: + i = bisect_left(rhs.versions, x) + if i > 0: + result.add(rhs[i - 1].intersection(x)) + if i < len(rhs): + result.add(rhs[i].intersection(x)) + return result + else: + return self.intersection(VersionList(other)) - def intersect(self, other) -> bool: + def intersect(self, other: VersionType) -> bool: """Intersect this spec's list with other. Return True if the spec changed as a result; False otherwise @@ -965,6 +1090,7 @@ class VersionList: self.versions = isection.versions return changed + # typing this and getitem are a pain in Python 3.6 def __contains__(self, other): if isinstance(other, (ClosedOpenRange, StandardVersion)): i = bisect_left(self, other) @@ -978,52 +1104,52 @@ class VersionList: def __getitem__(self, index): return self.versions[index] - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.versions) - def __reversed__(self): + def __reversed__(self) -> Iterator: return reversed(self.versions) - def __len__(self): + def __len__(self) -> int: return len(self.versions) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.versions) - def __eq__(self, other): + def __eq__(self, other) -> bool: if isinstance(other, VersionList): return self.versions == other.versions return False - def __ne__(self, other): + def __ne__(self, other) -> bool: if isinstance(other, VersionList): return self.versions != other.versions return False - def __lt__(self, other): + def __lt__(self, other) -> bool: if isinstance(other, VersionList): return self.versions < other.versions return NotImplemented - def __le__(self, other): + def __le__(self, other) -> bool: if isinstance(other, VersionList): return self.versions <= other.versions return NotImplemented - def __ge__(self, other): + def __ge__(self, other) -> bool: if isinstance(other, VersionList): return self.versions >= other.versions return NotImplemented - def __gt__(self, other): + def __gt__(self, other) -> bool: if isinstance(other, VersionList): return self.versions > other.versions return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(tuple(self.versions)) - def __str__(self): + def __str__(self) -> str: if not self.versions: return "" @@ -1031,7 +1157,7 @@ class VersionList: f"={v}" if isinstance(v, StandardVersion) else str(v) for v in self.versions ) - def __repr__(self): + def __repr__(self) -> str: return str(self.versions) @@ -1106,12 +1232,10 @@ def _next_version(v: StandardVersion) -> StandardVersion: release = release[:-1] + (_next_version_str_component(release[-1]),) else: release = release[:-1] + (release[-1] + 1,) - components = [""] * (2 * len(release)) - components[::2] = release - components[1::2] = separators[: len(release)] - if prerelease_type != FINAL: - components.extend((PRERELEASE_TO_STRING[prerelease_type], prerelease[1])) - return StandardVersion("".join(str(c) for c in components), (release, prerelease), separators) + + # Avoid constructing a string here for performance. Instead, pass "" to + # StandardVersion to lazily stringify. + return StandardVersion("", (release, prerelease), separators) def _prev_version(v: StandardVersion) -> StandardVersion: @@ -1130,19 +1254,15 @@ def _prev_version(v: StandardVersion) -> StandardVersion: release = release[:-1] + (_prev_version_str_component(release[-1]),) else: release = release[:-1] + (release[-1] - 1,) - components = [""] * (2 * len(release)) - components[::2] = release - components[1::2] = separators[: len(release)] - if prerelease_type != FINAL: - components.extend((PRERELEASE_TO_STRING[prerelease_type], *prerelease[1:])) - # this is only used for comparison functions, so don't bother making a string - return StandardVersion(None, (release, prerelease), separators) + # Avoid constructing a string here for performance. Instead, pass "" to + # StandardVersion to lazily stringify. + return StandardVersion("", (release, prerelease), separators) -def Version(string: Union[str, int]) -> Union[GitVersion, StandardVersion]: +def Version(string: Union[str, int]) -> ConcreteVersion: if not isinstance(string, (str, int)): - raise ValueError(f"Cannot construct a version from {type(string)}") + raise TypeError(f"Cannot construct a version from {type(string)}") string = str(string) if is_git_version(string): return GitVersion(string) @@ -1155,7 +1275,7 @@ def VersionRange(lo: Union[str, StandardVersion], hi: Union[str, StandardVersion return ClosedOpenRange.from_version_range(lo, hi) -def from_string(string) -> Union[VersionList, ClosedOpenRange, StandardVersion, GitVersion]: +def from_string(string: str) -> VersionType: """Converts a string to a version object. This is private. Client code should use ver().""" string = string.replace(" ", "") @@ -1184,17 +1304,17 @@ def from_string(string) -> Union[VersionList, ClosedOpenRange, StandardVersion, return VersionRange(v, v) -def ver(obj) -> Union[VersionList, ClosedOpenRange, StandardVersion, GitVersion]: +def ver(obj: Union[VersionType, str, list, tuple, int, float]) -> VersionType: """Parses a Version, VersionRange, or VersionList from a string or list of strings. """ - if isinstance(obj, (list, tuple)): - return VersionList(obj) + if isinstance(obj, VersionType): + return obj elif isinstance(obj, str): return from_string(obj) + elif isinstance(obj, (list, tuple)): + return VersionList(obj) elif isinstance(obj, (int, float)): return from_string(str(obj)) - elif isinstance(obj, (StandardVersion, GitVersion, ClosedOpenRange, VersionList)): - return obj else: raise TypeError("ver() can't convert %s to version!" % type(obj)) |