From 3b4ca0374e605085218a3840bf7a75cc1cc4cac9 Mon Sep 17 00:00:00 2001 From: Massimiliano Culpo Date: Tue, 19 Sep 2023 15:32:59 +0200 Subject: Use process pool executors for web-crawling and retrieving archives (#39888) Fix a race condition when searching urls, and updating a shared set '_visited'. --- lib/spack/spack/binary_distribution.py | 7 +- lib/spack/spack/cmd/buildcache.py | 2 +- lib/spack/spack/cmd/checksum.py | 8 +- lib/spack/spack/cmd/versions.py | 7 +- lib/spack/spack/error.py | 4 + lib/spack/spack/fetch_strategy.py | 28 ++-- lib/spack/spack/package_base.py | 9 +- lib/spack/spack/patch.py | 4 +- lib/spack/spack/stage.py | 176 ++++++++++++--------- lib/spack/spack/test/cmd/spec.py | 3 +- lib/spack/spack/test/conftest.py | 4 +- lib/spack/spack/test/gcs_fetch.py | 4 +- lib/spack/spack/test/packaging.py | 3 +- lib/spack/spack/test/s3_fetch.py | 4 +- lib/spack/spack/test/stage.py | 6 +- lib/spack/spack/test/url_fetch.py | 15 +- lib/spack/spack/test/web.py | 2 +- lib/spack/spack/util/web.py | 280 ++++++++++++++++----------------- 18 files changed, 288 insertions(+), 278 deletions(-) (limited to 'lib') diff --git a/lib/spack/spack/binary_distribution.py b/lib/spack/spack/binary_distribution.py index b85dec9bf6..50043af762 100644 --- a/lib/spack/spack/binary_distribution.py +++ b/lib/spack/spack/binary_distribution.py @@ -34,6 +34,7 @@ from llnl.util.filesystem import BaseDirectoryVisitor, mkdirp, visit_directory_t import spack.cmd import spack.config as config import spack.database as spack_db +import spack.error import spack.hooks import spack.hooks.sbang import spack.mirror @@ -1417,7 +1418,7 @@ def try_fetch(url_to_fetch): try: stage.fetch() - except web_util.FetchError: + except spack.error.FetchError: stage.destroy() return None @@ -2144,7 +2145,7 @@ def get_keys(install=False, trust=False, force=False, mirrors=None): if not os.path.exists(stage.save_filename): try: stage.fetch() - except web_util.FetchError: + except spack.error.FetchError: continue tty.debug("Found key {0}".format(fingerprint)) @@ -2296,7 +2297,7 @@ def _download_buildcache_entry(mirror_root, descriptions): try: stage.fetch() break - except web_util.FetchError as e: + except spack.error.FetchError as e: tty.debug(e) else: if fail_if_missing: diff --git a/lib/spack/spack/cmd/buildcache.py b/lib/spack/spack/cmd/buildcache.py index 17d3f6728c..f611956c36 100644 --- a/lib/spack/spack/cmd/buildcache.py +++ b/lib/spack/spack/cmd/buildcache.py @@ -527,7 +527,7 @@ def copy_buildcache_file(src_url, dest_url, local_path=None): temp_stage.create() temp_stage.fetch() web_util.push_to_url(local_path, dest_url, keep_original=True) - except web_util.FetchError as e: + except spack.error.FetchError as e: # Expected, since we have to try all the possible extensions tty.debug("no such file: {0}".format(src_url)) tty.debug(e) diff --git a/lib/spack/spack/cmd/checksum.py b/lib/spack/spack/cmd/checksum.py index c94756b0c0..a0d6611d94 100644 --- a/lib/spack/spack/cmd/checksum.py +++ b/lib/spack/spack/cmd/checksum.py @@ -66,7 +66,7 @@ def setup_parser(subparser): modes_parser.add_argument( "--verify", action="store_true", default=False, help="verify known package checksums" ) - arguments.add_common_arguments(subparser, ["package"]) + arguments.add_common_arguments(subparser, ["package", "jobs"]) subparser.add_argument( "versions", nargs=argparse.REMAINDER, help="versions to generate checksums for" ) @@ -96,7 +96,7 @@ def checksum(parser, args): # Add latest version if requested if args.latest: - remote_versions = 
pkg.fetch_remote_versions() + remote_versions = pkg.fetch_remote_versions(args.jobs) if len(remote_versions) > 0: latest_version = sorted(remote_versions.keys(), reverse=True)[0] versions.append(latest_version) @@ -119,13 +119,13 @@ def checksum(parser, args): # if we get here, it's because no valid url was provided by the package # do expensive fallback to try to recover if remote_versions is None: - remote_versions = pkg.fetch_remote_versions() + remote_versions = pkg.fetch_remote_versions(args.jobs) if version in remote_versions: url_dict[version] = remote_versions[version] if len(versions) <= 0: if remote_versions is None: - remote_versions = pkg.fetch_remote_versions() + remote_versions = pkg.fetch_remote_versions(args.jobs) url_dict = remote_versions if not url_dict: diff --git a/lib/spack/spack/cmd/versions.py b/lib/spack/spack/cmd/versions.py index d35a823032..9ac6c9e4da 100644 --- a/lib/spack/spack/cmd/versions.py +++ b/lib/spack/spack/cmd/versions.py @@ -37,10 +37,7 @@ def setup_parser(subparser): action="store_true", help="only list remote versions newer than the latest checksummed version", ) - subparser.add_argument( - "-c", "--concurrency", default=32, type=int, help="number of concurrent requests" - ) - arguments.add_common_arguments(subparser, ["package"]) + arguments.add_common_arguments(subparser, ["package", "jobs"]) def versions(parser, args): @@ -68,7 +65,7 @@ def versions(parser, args): if args.safe: return - fetched_versions = pkg.fetch_remote_versions(args.concurrency) + fetched_versions = pkg.fetch_remote_versions(args.jobs) if args.new: if sys.stdout.isatty(): diff --git a/lib/spack/spack/error.py b/lib/spack/spack/error.py index 33986c9cde..8c9015ed6d 100644 --- a/lib/spack/spack/error.py +++ b/lib/spack/spack/error.py @@ -128,3 +128,7 @@ class UnsatisfiableSpecError(SpecError): self.provided = provided self.required = required self.constraint_type = constraint_type + + +class FetchError(SpackError): + """Superclass for fetch-related errors.""" diff --git a/lib/spack/spack/fetch_strategy.py b/lib/spack/spack/fetch_strategy.py index 87c6e0fc61..90ff8527fd 100644 --- a/lib/spack/spack/fetch_strategy.py +++ b/lib/spack/spack/fetch_strategy.py @@ -401,7 +401,7 @@ class URLFetchStrategy(FetchStrategy): try: web_util.check_curl_code(curl.returncode) - except web_util.FetchError as err: + except spack.error.FetchError as err: raise spack.fetch_strategy.FailedDownloadError(url, str(err)) self._check_headers(headers) @@ -1290,7 +1290,7 @@ class S3FetchStrategy(URLFetchStrategy): parsed_url = urllib.parse.urlparse(self.url) if parsed_url.scheme != "s3": - raise web_util.FetchError("S3FetchStrategy can only fetch from s3:// urls.") + raise spack.error.FetchError("S3FetchStrategy can only fetch from s3:// urls.") tty.debug("Fetching {0}".format(self.url)) @@ -1337,7 +1337,7 @@ class GCSFetchStrategy(URLFetchStrategy): parsed_url = urllib.parse.urlparse(self.url) if parsed_url.scheme != "gs": - raise web_util.FetchError("GCSFetchStrategy can only fetch from gs:// urls.") + raise spack.error.FetchError("GCSFetchStrategy can only fetch from gs:// urls.") tty.debug("Fetching {0}".format(self.url)) @@ -1431,7 +1431,7 @@ def from_kwargs(**kwargs): on attribute names (e.g., ``git``, ``hg``, etc.) Raises: - spack.util.web.FetchError: If no ``fetch_strategy`` matches the args. + spack.error.FetchError: If no ``fetch_strategy`` matches the args. 
""" for fetcher in all_strategies: if fetcher.matches(kwargs): @@ -1538,7 +1538,7 @@ def for_package_version(pkg, version=None): # if it's a commit, we must use a GitFetchStrategy if isinstance(version, spack.version.GitVersion): if not hasattr(pkg, "git"): - raise web_util.FetchError( + raise spack.error.FetchError( f"Cannot fetch git version for {pkg.name}. Package has no 'git' attribute" ) # Populate the version with comparisons to other commits @@ -1688,11 +1688,11 @@ class FsCache: shutil.rmtree(self.root, ignore_errors=True) -class NoCacheError(web_util.FetchError): +class NoCacheError(spack.error.FetchError): """Raised when there is no cached archive for a package.""" -class FailedDownloadError(web_util.FetchError): +class FailedDownloadError(spack.error.FetchError): """Raised when a download fails.""" def __init__(self, url, msg=""): @@ -1700,23 +1700,23 @@ class FailedDownloadError(web_util.FetchError): self.url = url -class NoArchiveFileError(web_util.FetchError): +class NoArchiveFileError(spack.error.FetchError): """Raised when an archive file is expected but none exists.""" -class NoDigestError(web_util.FetchError): +class NoDigestError(spack.error.FetchError): """Raised after attempt to checksum when URL has no digest.""" -class ExtrapolationError(web_util.FetchError): +class ExtrapolationError(spack.error.FetchError): """Raised when we can't extrapolate a version for a package.""" -class FetcherConflict(web_util.FetchError): +class FetcherConflict(spack.error.FetchError): """Raised for packages with invalid fetch attributes.""" -class InvalidArgsError(web_util.FetchError): +class InvalidArgsError(spack.error.FetchError): """Raised when a version can't be deduced from a set of arguments.""" def __init__(self, pkg=None, version=None, **args): @@ -1729,11 +1729,11 @@ class InvalidArgsError(web_util.FetchError): super().__init__(msg, long_msg) -class ChecksumError(web_util.FetchError): +class ChecksumError(spack.error.FetchError): """Raised when archive fails to checksum.""" -class NoStageError(web_util.FetchError): +class NoStageError(spack.error.FetchError): """Raised when fetch operations are called before set_stage().""" def __init__(self, method): diff --git a/lib/spack/spack/package_base.py b/lib/spack/spack/package_base.py index 67cebb3a8f..940c12c11a 100644 --- a/lib/spack/spack/package_base.py +++ b/lib/spack/spack/package_base.py @@ -66,7 +66,6 @@ from spack.installer import InstallError, PackageInstaller from spack.stage import DIYStage, ResourceStage, Stage, StageComposite, compute_stage_name from spack.util.executable import ProcessError, which from spack.util.package_hash import package_hash -from spack.util.web import FetchError from spack.version import GitVersion, StandardVersion, Version FLAG_HANDLER_RETURN_TYPE = Tuple[ @@ -1394,7 +1393,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta): tty.debug("Fetching with no checksum. {0}".format(ck_msg)) if not ignore_checksum: - raise FetchError( + raise spack.error.FetchError( "Will not fetch %s" % self.spec.format("{name}{@version}"), ck_msg ) @@ -1420,7 +1419,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta): tty.debug("Fetching deprecated version. 
{0}".format(dp_msg)) if not ignore_deprecation: - raise FetchError( + raise spack.error.FetchError( "Will not fetch {0}".format(self.spec.format("{name}{@version}")), dp_msg ) @@ -1447,7 +1446,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta): self.stage.expand_archive() if not os.listdir(self.stage.path): - raise FetchError("Archive was empty for %s" % self.name) + raise spack.error.FetchError("Archive was empty for %s" % self.name) else: # Support for post-install hooks requires a stage.source_path fsys.mkdirp(self.stage.source_path) @@ -2365,7 +2364,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta): urls.append(args["url"]) return urls - def fetch_remote_versions(self, concurrency=128): + def fetch_remote_versions(self, concurrency=None): """Find remote versions of this package. Uses ``list_url`` and any other URLs listed in the package file. diff --git a/lib/spack/spack/patch.py b/lib/spack/spack/patch.py index 7bbab326d1..23a5ee20a8 100644 --- a/lib/spack/spack/patch.py +++ b/lib/spack/spack/patch.py @@ -76,7 +76,7 @@ class Patch: self.level = level self.working_dir = working_dir - def apply(self, stage: spack.stage.Stage): + def apply(self, stage: "spack.stage.Stage"): """Apply a patch to source in a stage. Arguments: @@ -190,7 +190,7 @@ class UrlPatch(Patch): if not self.sha256: raise PatchDirectiveError("URL patches require a sha256 checksum") - def apply(self, stage: spack.stage.Stage): + def apply(self, stage: "spack.stage.Stage"): assert self.stage.expanded, "Stage must be expanded before applying patches" # Get the patch file. diff --git a/lib/spack/spack/stage.py b/lib/spack/spack/stage.py index 119a81ad9b..065f74eb8b 100644 --- a/lib/spack/spack/stage.py +++ b/lib/spack/spack/stage.py @@ -2,7 +2,7 @@ # Spack Project Developers. See the top-level COPYRIGHT file for details. # # SPDX-License-Identifier: (Apache-2.0 OR MIT) - +import concurrent.futures import errno import getpass import glob @@ -12,7 +12,7 @@ import shutil import stat import sys import tempfile -from typing import Dict, Iterable +from typing import Callable, Dict, Iterable, Optional import llnl.util.lang import llnl.util.tty as tty @@ -37,9 +37,9 @@ import spack.spec import spack.util.lock import spack.util.path as sup import spack.util.pattern as pattern +import spack.util.string import spack.util.url as url_util from spack.util.crypto import bit_length, prefix_bits -from spack.util.web import FetchError # The well-known stage source subdirectory name. _source_path_subdir = "spack-src" @@ -241,10 +241,7 @@ class Stage: similar, and are intended to persist for only one run of spack. """ - """Shared dict of all stage locks.""" - stage_locks: Dict[str, spack.util.lock.Lock] = {} - - """Most staging is managed by Spack. DIYStage is one exception.""" + #: Most staging is managed by Spack. DIYStage is one exception. managed_by_spack = True def __init__( @@ -330,17 +327,12 @@ class Stage: # details on this approach. 
         self._lock = None
         if lock:
-            if self.name not in Stage.stage_locks:
-                sha1 = hashlib.sha1(self.name.encode("utf-8")).digest()
-                lock_id = prefix_bits(sha1, bit_length(sys.maxsize))
-                stage_lock_path = os.path.join(get_stage_root(), ".lock")
-
-                tty.debug("Creating stage lock {0}".format(self.name))
-                Stage.stage_locks[self.name] = spack.util.lock.Lock(
-                    stage_lock_path, start=lock_id, length=1, desc=self.name
-                )
-
-            self._lock = Stage.stage_locks[self.name]
+            sha1 = hashlib.sha1(self.name.encode("utf-8")).digest()
+            lock_id = prefix_bits(sha1, bit_length(sys.maxsize))
+            stage_lock_path = os.path.join(get_stage_root(), ".lock")
+            self._lock = spack.util.lock.Lock(
+                stage_lock_path, start=lock_id, length=1, desc=self.name
+            )
 
         # When stages are reused, we need to know whether to re-create
         # it. This marks whether it has been created/destroyed.
@@ -522,7 +514,7 @@ class Stage:
                     self.fetcher = self.default_fetcher
                     default_msg = "All fetchers failed for {0}".format(self.name)
-                    raise FetchError(err_msg or default_msg, None)
+                    raise spack.error.FetchError(err_msg or default_msg, None)
 
         print_errors(errors)
 
@@ -868,45 +860,47 @@ def purge():
             os.remove(stage_path)
 
 
-def get_checksums_for_versions(url_dict, name, **kwargs):
-    """Fetches and checksums archives from URLs.
+def get_checksums_for_versions(
+    url_by_version: Dict[str, str],
+    package_name: str,
+    *,
+    batch: bool = False,
+    first_stage_function: Optional[Callable[[Stage, str], None]] = None,
+    keep_stage: bool = False,
+    concurrency: Optional[int] = None,
+    fetch_options: Optional[Dict[str, str]] = None,
+) -> Dict[str, str]:
+    """Computes the checksums for each version passed as input, and returns the results.
 
-    This function is called by both ``spack checksum`` and ``spack
-    create``. The ``first_stage_function`` argument allows the caller to
-    inspect the first downloaded archive, e.g., to determine the build
-    system.
+    Archives are fetched according to the url dictionary passed as input.
+
+    The ``first_stage_function`` argument allows the caller to inspect the first downloaded
+    archive, e.g., to determine the build system.
 
Args: - url_dict (dict): A dictionary of the form: version -> URL - name (str): The name of the package - first_stage_function (typing.Callable): function that takes a Stage and a URL; - this is run on the stage of the first URL downloaded - keep_stage (bool): whether to keep staging area when command completes - batch (bool): whether to ask user how many versions to fetch (false) - or fetch all versions (true) - fetch_options (dict): Options used for the fetcher (such as timeout - or cookies) + url_by_version: URL keyed by version + package_name: name of the package + first_stage_function: function that takes a Stage and a URL; this is run on the stage + of the first URL downloaded + keep_stage: whether to keep staging area when command completes + batch: whether to ask user how many versions to fetch (false) or fetch all versions (true) + fetch_options: options used for the fetcher (such as timeout or cookies) + concurrency: maximum number of workers to use for retrieving archives Returns: - (dict): A dictionary of the form: version -> checksum - + A dictionary mapping each version to the corresponding checksum """ - batch = kwargs.get("batch", False) - fetch_options = kwargs.get("fetch_options", None) - first_stage_function = kwargs.get("first_stage_function", None) - keep_stage = kwargs.get("keep_stage", False) - - sorted_versions = sorted(url_dict.keys(), reverse=True) + sorted_versions = sorted(url_by_version.keys(), reverse=True) # Find length of longest string in the list for padding max_len = max(len(str(v)) for v in sorted_versions) num_ver = len(sorted_versions) tty.msg( - "Found {0} version{1} of {2}:".format(num_ver, "" if num_ver == 1 else "s", name), + f"Found {spack.util.string.plural(num_ver, 'version')} of {package_name}:", "", *llnl.util.lang.elide_list( - ["{0:{1}} {2}".format(str(v), max_len, url_dict[v]) for v in sorted_versions] + ["{0:{1}} {2}".format(str(v), max_len, url_by_version[v]) for v in sorted_versions] ), ) print() @@ -922,50 +916,76 @@ def get_checksums_for_versions(url_dict, name, **kwargs): tty.die("Aborted.") versions = sorted_versions[:archives_to_fetch] - urls = [url_dict[v] for v in versions] + search_arguments = [(url_by_version[v], v) for v in versions] - tty.debug("Downloading...") - version_hashes = {} - i = 0 - errors = [] - for url, version in zip(urls, versions): - try: - if fetch_options: - url_or_fs = fs.URLFetchStrategy(url, fetch_options=fetch_options) - else: - url_or_fs = url - with Stage(url_or_fs, keep=keep_stage) as stage: - # Fetch the archive - stage.fetch() - if i == 0 and first_stage_function: - # Only run first_stage_function the first time, - # no need to run it every time - first_stage_function(stage, url) - - # Checksum the archive and add it to the list - version_hashes[version] = spack.util.crypto.checksum( - hashlib.sha256, stage.archive_file - ) - i += 1 - except FailedDownloadError: - errors.append("Failed to fetch {0}".format(url)) - except Exception as e: - tty.msg("Something failed on {0}, skipping. ({1})".format(url, e)) + version_hashes, errors = {}, [] + + # Don't spawn 16 processes when we need to fetch 2 urls + if concurrency is not None: + concurrency = min(concurrency, len(search_arguments)) + else: + concurrency = min(os.cpu_count() or 1, len(search_arguments)) + + # The function might have side effects in memory, that would not be reflected in the + # parent process, if run in a child process. 
If this pattern happens frequently, we + # can move this function call *after* having distributed the work to executors. + if first_stage_function is not None: + (url, version), search_arguments = search_arguments[0], search_arguments[1:] + checksum, error = _fetch_and_checksum(url, fetch_options, keep_stage, first_stage_function) + if error is not None: + errors.append(error) + + if checksum is not None: + version_hashes[version] = checksum + + with concurrent.futures.ProcessPoolExecutor(max_workers=concurrency) as executor: + results = [] + for url, version in search_arguments: + future = executor.submit(_fetch_and_checksum, url, fetch_options, keep_stage) + results.append((version, future)) + + for version, future in results: + checksum, error = future.result() + if error is not None: + errors.append(error) + continue + version_hashes[version] = checksum - for msg in errors: - tty.debug(msg) + for msg in errors: + tty.debug(msg) if not version_hashes: - tty.die("Could not fetch any versions for {0}".format(name)) + tty.die(f"Could not fetch any versions for {package_name}") num_hash = len(version_hashes) - tty.debug( - "Checksummed {0} version{1} of {2}:".format(num_hash, "" if num_hash == 1 else "s", name) - ) + tty.debug(f"Checksummed {num_hash} version{'' if num_hash == 1 else 's'} of {package_name}:") return version_hashes +def _fetch_and_checksum(url, options, keep_stage, action_fn=None): + try: + url_or_fs = url + if options: + url_or_fs = fs.URLFetchStrategy(url, fetch_options=options) + + with Stage(url_or_fs, keep=keep_stage) as stage: + # Fetch the archive + stage.fetch() + if action_fn is not None: + # Only run first_stage_function the first time, + # no need to run it every time + action_fn(stage, url) + + # Checksum the archive and add it to the list + checksum = spack.util.crypto.checksum(hashlib.sha256, stage.archive_file) + return checksum, None + except FailedDownloadError: + return None, f"[WORKER] Failed to fetch {url}" + except Exception as e: + return None, f"[WORKER] Something failed on {url}, skipping. ({e})" + + class StageError(spack.error.SpackError): """ "Superclass for all errors encountered during staging.""" diff --git a/lib/spack/spack/test/cmd/spec.py b/lib/spack/spack/test/cmd/spec.py index fd8fe1beef..66dfce9308 100644 --- a/lib/spack/spack/test/cmd/spec.py +++ b/lib/spack/spack/test/cmd/spec.py @@ -14,7 +14,6 @@ import spack.parser import spack.spec import spack.store from spack.main import SpackCommand, SpackCommandError -from spack.util.web import FetchError pytestmark = pytest.mark.usefixtures("config", "mutable_mock_repo") @@ -208,7 +207,7 @@ def test_env_aware_spec(mutable_mock_env_path): [ ("develop-branch-version", "f3c7206350ac8ee364af687deaae5c574dcfca2c=develop", None), ("develop-branch-version", "git." 
+ "a" * 40 + "=develop", None), - ("callpath", "f3c7206350ac8ee364af687deaae5c574dcfca2c=1.0", FetchError), + ("callpath", "f3c7206350ac8ee364af687deaae5c574dcfca2c=1.0", spack.error.FetchError), ("develop-branch-version", "git.foo=0.2.15", None), ], ) diff --git a/lib/spack/spack/test/conftest.py b/lib/spack/spack/test/conftest.py index bfc03f216f..25417de6f4 100644 --- a/lib/spack/spack/test/conftest.py +++ b/lib/spack/spack/test/conftest.py @@ -36,6 +36,7 @@ import spack.config import spack.database import spack.directory_layout import spack.environment as ev +import spack.error import spack.package_base import spack.package_prefs import spack.paths @@ -52,7 +53,6 @@ import spack.util.spack_yaml as syaml import spack.util.url as url_util from spack.fetch_strategy import URLFetchStrategy from spack.util.pattern import Bunch -from spack.util.web import FetchError def ensure_configuration_fixture_run_before(request): @@ -472,7 +472,7 @@ class MockCache: class MockCacheFetcher: def fetch(self): - raise FetchError("Mock cache always fails for tests") + raise spack.error.FetchError("Mock cache always fails for tests") def __str__(self): return "[mock fetch cache]" diff --git a/lib/spack/spack/test/gcs_fetch.py b/lib/spack/spack/test/gcs_fetch.py index 2122eeb36d..76b9971471 100644 --- a/lib/spack/spack/test/gcs_fetch.py +++ b/lib/spack/spack/test/gcs_fetch.py @@ -8,9 +8,9 @@ import os import pytest import spack.config +import spack.error import spack.fetch_strategy import spack.stage -from spack.util.web import FetchError @pytest.mark.parametrize("_fetch_method", ["curl", "urllib"]) @@ -33,7 +33,7 @@ def test_gcsfetchstrategy_bad_url(tmpdir, _fetch_method): with spack.stage.Stage(fetcher, path=testpath) as stage: assert stage is not None assert fetcher.archive_file is None - with pytest.raises(FetchError): + with pytest.raises(spack.error.FetchError): fetcher.fetch() diff --git a/lib/spack/spack/test/packaging.py b/lib/spack/spack/test/packaging.py index 3ee992128b..bd49cf94ae 100644 --- a/lib/spack/spack/test/packaging.py +++ b/lib/spack/spack/test/packaging.py @@ -20,6 +20,7 @@ from llnl.util.symlink import symlink import spack.binary_distribution as bindist import spack.cmd.buildcache as buildcache +import spack.error import spack.package_base import spack.repo import spack.store @@ -522,7 +523,7 @@ def test_manual_download( monkeypatch.setattr(spack.package_base.PackageBase, "download_instr", _instr) expected = spec.package.download_instr if manual else "All fetchers failed" - with pytest.raises(spack.util.web.FetchError, match=expected): + with pytest.raises(spack.error.FetchError, match=expected): spec.package.do_fetch() diff --git a/lib/spack/spack/test/s3_fetch.py b/lib/spack/spack/test/s3_fetch.py index a495e4d5e8..241d2648b5 100644 --- a/lib/spack/spack/test/s3_fetch.py +++ b/lib/spack/spack/test/s3_fetch.py @@ -8,9 +8,9 @@ import os import pytest import spack.config as spack_config +import spack.error import spack.fetch_strategy as spack_fs import spack.stage as spack_stage -from spack.util.web import FetchError @pytest.mark.parametrize("_fetch_method", ["curl", "urllib"]) @@ -33,7 +33,7 @@ def test_s3fetchstrategy_bad_url(tmpdir, _fetch_method): with spack_stage.Stage(fetcher, path=testpath) as stage: assert stage is not None assert fetcher.archive_file is None - with pytest.raises(FetchError): + with pytest.raises(spack.error.FetchError): fetcher.fetch() diff --git a/lib/spack/spack/test/stage.py b/lib/spack/spack/test/stage.py index dd89edf415..8b2b53dd05 100644 --- 
a/lib/spack/spack/test/stage.py +++ b/lib/spack/spack/test/stage.py @@ -16,6 +16,7 @@ import pytest from llnl.util.filesystem import getuid, mkdirp, partition_path, touch, working_dir +import spack.error import spack.paths import spack.stage import spack.util.executable @@ -23,7 +24,6 @@ import spack.util.url as url_util from spack.resource import Resource from spack.stage import DIYStage, ResourceStage, Stage, StageComposite from spack.util.path import canonicalize_path -from spack.util.web import FetchError # The following values are used for common fetch and stage mocking fixtures: _archive_base = "test-files" @@ -522,7 +522,7 @@ class TestStage: with stage: try: stage.fetch(mirror_only=True) - except FetchError: + except spack.error.FetchError: pass check_destroy(stage, self.stage_name) @@ -537,7 +537,7 @@ class TestStage: stage = Stage(failing_fetch_strategy, name=self.stage_name, search_fn=search_fn) with stage: - with pytest.raises(FetchError, match=expected): + with pytest.raises(spack.error.FetchError, match=expected): stage.fetch(mirror_only=False, err_msg=err_msg) check_destroy(stage, self.stage_name) diff --git a/lib/spack/spack/test/url_fetch.py b/lib/spack/spack/test/url_fetch.py index fac2bcb16d..a3c0f7c10b 100644 --- a/lib/spack/spack/test/url_fetch.py +++ b/lib/spack/spack/test/url_fetch.py @@ -13,6 +13,7 @@ import llnl.util.tty as tty from llnl.util.filesystem import is_exe, working_dir import spack.config +import spack.error import spack.fetch_strategy as fs import spack.repo import spack.util.crypto as crypto @@ -349,7 +350,7 @@ def test_missing_curl(tmpdir, monkeypatch): def test_url_fetch_text_without_url(tmpdir): - with pytest.raises(web_util.FetchError, match="URL is required"): + with pytest.raises(spack.error.FetchError, match="URL is required"): web_util.fetch_url_text(None) @@ -366,18 +367,18 @@ def test_url_fetch_text_curl_failures(tmpdir, monkeypatch): monkeypatch.setattr(spack.util.web, "which", _which) with spack.config.override("config:url_fetch_method", "curl"): - with pytest.raises(web_util.FetchError, match="Missing required curl"): + with pytest.raises(spack.error.FetchError, match="Missing required curl"): web_util.fetch_url_text("https://github.com/") def test_url_check_curl_errors(): """Check that standard curl error returncodes raise expected errors.""" # Check returncode 22 (i.e., 404) - with pytest.raises(web_util.FetchError, match="not found"): + with pytest.raises(spack.error.FetchError, match="not found"): web_util.check_curl_code(22) # Check returncode 60 (certificate error) - with pytest.raises(web_util.FetchError, match="invalid certificate"): + with pytest.raises(spack.error.FetchError, match="invalid certificate"): web_util.check_curl_code(60) @@ -394,7 +395,7 @@ def test_url_missing_curl(tmpdir, monkeypatch): monkeypatch.setattr(spack.util.web, "which", _which) with spack.config.override("config:url_fetch_method", "curl"): - with pytest.raises(web_util.FetchError, match="Missing required curl"): + with pytest.raises(spack.error.FetchError, match="Missing required curl"): web_util.url_exists("https://github.com/") @@ -409,7 +410,7 @@ def test_url_fetch_text_urllib_bad_returncode(tmpdir, monkeypatch): monkeypatch.setattr(spack.util.web, "read_from_url", _read_from_url) with spack.config.override("config:url_fetch_method", "urllib"): - with pytest.raises(web_util.FetchError, match="failed with error code"): + with pytest.raises(spack.error.FetchError, match="failed with error code"): web_util.fetch_url_text("https://github.com/") @@ -420,5 
+421,5 @@ def test_url_fetch_text_urllib_web_error(tmpdir, monkeypatch): monkeypatch.setattr(spack.util.web, "read_from_url", _raise_web_error) with spack.config.override("config:url_fetch_method", "urllib"): - with pytest.raises(web_util.FetchError, match="fetch failed to verify"): + with pytest.raises(spack.error.FetchError, match="fetch failed to verify"): web_util.fetch_url_text("https://github.com/") diff --git a/lib/spack/spack/test/web.py b/lib/spack/spack/test/web.py index a012e7524e..ed4f693c5a 100644 --- a/lib/spack/spack/test/web.py +++ b/lib/spack/spack/test/web.py @@ -98,7 +98,7 @@ def test_spider(depth, expected_found, expected_not_found, expected_text): def test_spider_no_response(monkeypatch): # Mock the absence of a response monkeypatch.setattr(spack.util.web, "read_from_url", lambda x, y: (None, None, None)) - pages, links = spack.util.web.spider(root, depth=0) + pages, links, _, _ = spack.util.web._spider(root, collect_nested=False, _visited=set()) assert not pages and not links diff --git a/lib/spack/spack/util/web.py b/lib/spack/spack/util/web.py index 79ad39ebd7..eca7bd72a2 100644 --- a/lib/spack/spack/util/web.py +++ b/lib/spack/spack/util/web.py @@ -4,9 +4,9 @@ # SPDX-License-Identifier: (Apache-2.0 OR MIT) import codecs +import concurrent.futures import email.message import errno -import multiprocessing.pool import os import os.path import re @@ -17,7 +17,7 @@ import traceback import urllib.parse from html.parser import HTMLParser from pathlib import Path, PurePosixPath -from typing import IO, Optional +from typing import IO, Dict, List, Optional, Set, Union from urllib.error import HTTPError, URLError from urllib.request import HTTPSHandler, Request, build_opener @@ -257,11 +257,11 @@ def check_curl_code(returncode): if returncode != 0: if returncode == 22: # This is a 404. Curl will print the error. - raise FetchError("URL was not found!") + raise spack.error.FetchError("URL was not found!") if returncode == 60: # This is a certificate error. Suggest spack -k - raise FetchError( + raise spack.error.FetchError( "Curl was unable to fetch due to invalid certificate. " "This is either an attack, or your cluster's SSL " "configuration is bad. If you believe your SSL " @@ -270,7 +270,7 @@ def check_curl_code(returncode): "Use this at your own risk." 
) - raise FetchError("Curl failed with error {0}".format(returncode)) + raise spack.error.FetchError("Curl failed with error {0}".format(returncode)) def _curl(curl=None): @@ -279,7 +279,7 @@ def _curl(curl=None): curl = which("curl", required=True) except CommandNotFoundError as exc: tty.error(str(exc)) - raise FetchError("Missing required curl fetch method") + raise spack.error.FetchError("Missing required curl fetch method") return curl @@ -307,7 +307,7 @@ def fetch_url_text(url, curl=None, dest_dir="."): Raises FetchError if the curl returncode indicates failure """ if not url: - raise FetchError("A URL is required to fetch its text") + raise spack.error.FetchError("A URL is required to fetch its text") tty.debug("Fetching text at {0}".format(url)) @@ -319,7 +319,7 @@ def fetch_url_text(url, curl=None, dest_dir="."): if fetch_method == "curl": curl_exe = _curl(curl) if not curl_exe: - raise FetchError("Missing required fetch method (curl)") + raise spack.error.FetchError("Missing required fetch method (curl)") curl_args = ["-O"] curl_args.extend(base_curl_fetch_args(url)) @@ -337,7 +337,9 @@ def fetch_url_text(url, curl=None, dest_dir="."): returncode = response.getcode() if returncode and returncode != 200: - raise FetchError("Urllib failed with error code {0}".format(returncode)) + raise spack.error.FetchError( + "Urllib failed with error code {0}".format(returncode) + ) output = codecs.getreader("utf-8")(response).read() if output: @@ -348,7 +350,7 @@ def fetch_url_text(url, curl=None, dest_dir="."): return path except SpackWebError as err: - raise FetchError("Urllib fetch failed to verify url: {0}".format(str(err))) + raise spack.error.FetchError("Urllib fetch failed to verify url: {0}".format(str(err))) return None @@ -543,168 +545,158 @@ def list_url(url, recursive=False): return gcs.get_all_blobs(recursive=recursive) -def spider(root_urls, depth=0, concurrency=32): +def spider(root_urls: Union[str, List[str]], depth: int = 0, concurrency: Optional[int] = None): """Get web pages from root URLs. - If depth is specified (e.g., depth=2), then this will also follow - up to levels of links from each root. + If depth is specified (e.g., depth=2), then this will also follow up to levels + of links from each root. Args: - root_urls (str or list): root urls used as a starting point - for spidering - depth (int): level of recursion into links - concurrency (int): number of simultaneous requests that can be sent + root_urls: root urls used as a starting point for spidering + depth: level of recursion into links + concurrency: number of simultaneous requests that can be sent Returns: - A dict of pages visited (URL) mapped to their full text and the - set of visited links. + A dict of pages visited (URL) mapped to their full text and the set of visited links. """ - # Cache of visited links, meant to be captured by the closure below - _visited = set() - - def _spider(url, collect_nested): - """Fetches URL and any pages it links to. - - Prints out a warning only if the root can't be fetched; it ignores - errors with pages that the root links to. - - Args: - url (str): url being fetched and searched for links - collect_nested (bool): whether we want to collect arguments - for nested spidering on the links found in this url - - Returns: - A tuple of: - - pages: dict of pages visited (URL) mapped to their full text. - - links: set of links encountered while visiting the pages. - - spider_args: argument for subsequent call to spider - """ - pages = {} # dict from page URL -> text content. 
- links = set() # set of all links seen on visited pages. - subcalls = [] + if isinstance(root_urls, str): + root_urls = [root_urls] - try: - response_url, _, response = read_from_url(url, "text/html") - if not response_url or not response: - return pages, links, subcalls - - page = codecs.getreader("utf-8")(response).read() - pages[response_url] = page - - # Parse out the include-fragments in the page - # https://github.github.io/include-fragment-element - include_fragment_parser = IncludeFragmentParser() - include_fragment_parser.feed(page) - - fragments = set() - while include_fragment_parser.links: - raw_link = include_fragment_parser.links.pop() - abs_link = url_util.join(response_url, raw_link.strip(), resolve_href=True) - - try: - # This seems to be text/html, though text/fragment+html is also used - fragment_response_url, _, fragment_response = read_from_url( - abs_link, "text/html" - ) - except Exception as e: - msg = f"Error reading fragment: {(type(e), str(e))}:{traceback.format_exc()}" - tty.debug(msg) - - if not fragment_response_url or not fragment_response: - continue + current_depth = 0 + pages, links, spider_args = {}, set(), [] - fragment = codecs.getreader("utf-8")(fragment_response).read() - fragments.add(fragment) + _visited: Set[str] = set() + go_deeper = current_depth < depth + for root_str in root_urls: + root = urllib.parse.urlparse(root_str) + spider_args.append((root, go_deeper, _visited)) - pages[fragment_response_url] = fragment + with concurrent.futures.ProcessPoolExecutor(max_workers=concurrency) as tp: + while current_depth <= depth: + tty.debug( + f"SPIDER: [depth={current_depth}, max_depth={depth}, urls={len(spider_args)}]" + ) + results = [tp.submit(_spider, *one_search_args) for one_search_args in spider_args] + spider_args = [] + go_deeper = current_depth < depth + for future in results: + sub_pages, sub_links, sub_spider_args, sub_visited = future.result() + _visited.update(sub_visited) + sub_spider_args = [(x, go_deeper, _visited) for x in sub_spider_args] + pages.update(sub_pages) + links.update(sub_links) + spider_args.extend(sub_spider_args) - # Parse out the links in the page and all fragments - link_parser = LinkParser() - link_parser.feed(page) - for fragment in fragments: - link_parser.feed(fragment) + current_depth += 1 - while link_parser.links: - raw_link = link_parser.links.pop() - abs_link = url_util.join(response_url, raw_link.strip(), resolve_href=True) - links.add(abs_link) + return pages, links - # Skip stuff that looks like an archive - if any(raw_link.endswith(s) for s in llnl.url.ALLOWED_ARCHIVE_TYPES): - continue - # Skip already-visited links - if abs_link in _visited: - continue +def _spider(url: urllib.parse.ParseResult, collect_nested: bool, _visited: Set[str]): + """Fetches URL and any pages it links to. - # If we're not at max depth, follow links. - if collect_nested: - subcalls.append((abs_link,)) - _visited.add(abs_link) + Prints out a warning only if the root can't be fetched; it ignores errors with pages + that the root links to. - except URLError as e: - tty.debug(str(e)) + Args: + url: url being fetched and searched for links + collect_nested: whether we want to collect arguments for nested spidering on the + links found in this url + _visited: links already visited - if hasattr(e, "reason") and isinstance(e.reason, ssl.SSLError): - tty.warn( - "Spack was unable to fetch url list due to a " - "certificate verification problem. You can try " - "running spack -k, which will not check SSL " - "certificates. 
Use this at your own risk." - ) + Returns: + A tuple of: + - pages: dict of pages visited (URL) mapped to their full text. + - links: set of links encountered while visiting the pages. + - spider_args: argument for subsequent call to spider + - visited: updated set of visited urls + """ + pages: Dict[str, str] = {} # dict from page URL -> text content. + links: Set[str] = set() # set of all links seen on visited pages. + subcalls: List[str] = [] - except HTMLParseError as e: - # This error indicates that Python's HTML parser sucks. - msg = "Got an error parsing HTML." - tty.warn(msg, url, "HTMLParseError: " + str(e)) + try: + response_url, _, response = read_from_url(url, "text/html") + if not response_url or not response: + return pages, links, subcalls, _visited - except Exception as e: - # Other types of errors are completely ignored, - # except in debug mode - tty.debug("Error in _spider: %s:%s" % (type(e), str(e)), traceback.format_exc()) + page = codecs.getreader("utf-8")(response).read() + pages[response_url] = page - finally: - tty.debug("SPIDER: [url={0}]".format(url)) + # Parse out the include-fragments in the page + # https://github.github.io/include-fragment-element + include_fragment_parser = IncludeFragmentParser() + include_fragment_parser.feed(page) - return pages, links, subcalls + fragments = set() + while include_fragment_parser.links: + raw_link = include_fragment_parser.links.pop() + abs_link = url_util.join(response_url, raw_link.strip(), resolve_href=True) - if isinstance(root_urls, str): - root_urls = [root_urls] + try: + # This seems to be text/html, though text/fragment+html is also used + fragment_response_url, _, fragment_response = read_from_url(abs_link, "text/html") + except Exception as e: + msg = f"Error reading fragment: {(type(e), str(e))}:{traceback.format_exc()}" + tty.debug(msg) - # Clear the local cache of visited pages before starting the search - _visited.clear() + if not fragment_response_url or not fragment_response: + continue - current_depth = 0 - pages, links, spider_args = {}, set(), [] + fragment = codecs.getreader("utf-8")(fragment_response).read() + fragments.add(fragment) - collect = current_depth < depth - for root in root_urls: - root = urllib.parse.urlparse(root) - spider_args.append((root, collect)) + pages[fragment_response_url] = fragment - tp = multiprocessing.pool.ThreadPool(processes=concurrency) - try: - while current_depth <= depth: - tty.debug( - "SPIDER: [depth={0}, max_depth={1}, urls={2}]".format( - current_depth, depth, len(spider_args) - ) + # Parse out the links in the page and all fragments + link_parser = LinkParser() + link_parser.feed(page) + for fragment in fragments: + link_parser.feed(fragment) + + while link_parser.links: + raw_link = link_parser.links.pop() + abs_link = url_util.join(response_url, raw_link.strip(), resolve_href=True) + links.add(abs_link) + + # Skip stuff that looks like an archive + if any(raw_link.endswith(s) for s in llnl.url.ALLOWED_ARCHIVE_TYPES): + continue + + # Skip already-visited links + if abs_link in _visited: + continue + + # If we're not at max depth, follow links. + if collect_nested: + subcalls.append(abs_link) + _visited.add(abs_link) + + except URLError as e: + tty.debug(f"[SPIDER] Unable to read: {url}") + tty.debug(str(e), level=2) + if hasattr(e, "reason") and isinstance(e.reason, ssl.SSLError): + tty.warn( + "Spack was unable to fetch url list due to a " + "certificate verification problem. You can try " + "running spack -k, which will not check SSL " + "certificates. 
Use this at your own risk." ) - results = tp.map(lang.star(_spider), spider_args) - spider_args = [] - collect = current_depth < depth - for sub_pages, sub_links, sub_spider_args in results: - sub_spider_args = [x + (collect,) for x in sub_spider_args] - pages.update(sub_pages) - links.update(sub_links) - spider_args.extend(sub_spider_args) - current_depth += 1 + except HTMLParseError as e: + # This error indicates that Python's HTML parser sucks. + msg = "Got an error parsing HTML." + tty.warn(msg, url, "HTMLParseError: " + str(e)) + + except Exception as e: + # Other types of errors are completely ignored, + # except in debug mode + tty.debug(f"Error in _spider: {type(e)}:{str(e)}", traceback.format_exc()) + finally: - tp.terminate() - tp.join() + tty.debug(f"SPIDER: [url={url}]") - return pages, links + return pages, links, subcalls, _visited def get_header(headers, header_name): @@ -767,10 +759,6 @@ def parse_etag(header_value): return valid.group(1) if valid else None -class FetchError(spack.error.SpackError): - """Superclass for fetch-related errors.""" - - class SpackWebError(spack.error.SpackError): """Superclass for Spack web spidering errors.""" -- cgit v1.2.3-60-g2f50
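
The rewritten spider() drives the crawl breadth first: every URL at the current depth is submitted to a ProcessPoolExecutor, each worker returns the links it found together with its own copy of the visited set, and the parent merges those copies before scheduling the next depth. That is what removes the race on the shared '_visited' set mentioned in the commit message, since no two processes ever mutate the same set. A minimal standalone sketch of that driver follows; crawl_page is a hypothetical stand-in for spack.util.web._spider and its link "extraction" is only a placeholder, not the real fetching code.

    import concurrent.futures
    from typing import Iterable, List, Set, Tuple

    def crawl_page(url: str, visited: Set[str]) -> Tuple[Set[str], List[str], Set[str]]:
        """Runs in a child process; returns (links_found, next_work, visited_copy)."""
        links = {url + "/child"}  # placeholder for real HTML link extraction
        visited = set(visited)    # the child only ever touches its own copy
        next_work = []
        for link in links:
            if link not in visited:  # skip pages an earlier round already covered
                next_work.append(link)
                visited.add(link)
        return links, next_work, visited

    def crawl(roots: Iterable[str], depth: int = 1, concurrency: int = 4) -> Set[str]:
        visited: Set[str] = set()
        links: Set[str] = set()
        work: List[str] = list(roots)
        with concurrent.futures.ProcessPoolExecutor(max_workers=concurrency) as executor:
            for _ in range(depth + 1):
                futures = [executor.submit(crawl_page, url, visited) for url in work]
                work = []
                for future in futures:
                    sub_links, sub_work, sub_visited = future.result()
                    visited.update(sub_visited)  # merged only in the parent: no shared-state race
                    links.update(sub_links)
                    work.extend(sub_work)
        return links

    if __name__ == "__main__":  # guard required for process pools on spawn-based platforms
        print(crawl(["https://example.com"], depth=1))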
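The checksum side uses the same executor but a different worker contract: _fetch_and_checksum is a module-level function that never raises and instead returns a (result, error) tuple, so a single failing URL is reported without tearing down the pool, and first_stage_function runs in the parent because in-memory side effects made inside a child process would not be visible to the caller. A small sketch of that contract, assuming a hypothetical checksum_url helper rather than Spack's own staging code:

    import concurrent.futures
    import hashlib
    import urllib.request
    from typing import Optional, Tuple

    def checksum_url(url: str, timeout: int = 10) -> Tuple[Optional[str], Optional[str]]:
        """Runs in a child process; returns (sha256, None) on success or (None, error) on failure."""
        try:
            with urllib.request.urlopen(url, timeout=timeout) as response:
                return hashlib.sha256(response.read()).hexdigest(), None
        except Exception as e:  # never let an exception escape the worker
            return None, f"[WORKER] failed to fetch {url}: {e}"

    def checksum_all(urls, max_workers=4):
        hashes, errors = {}, {}
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(checksum_url, url): url for url in urls}
            for future in concurrent.futures.as_completed(futures):
                digest, error = future.result()
                if error is not None:
                    errors[futures[future]] = error
                else:
                    hashes[futures[future]] = digest
        return hashes, errors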
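With the keyword-only signature introduced in stage.py, callers such as spack checksum and spack create pass everything except the URL dictionary and the package name by keyword and receive a plain version-to-sha256 mapping. A hypothetical call site based on that signature; the package name, URLs, and the printed version-directive format are illustrative only.

    from spack.stage import get_checksums_for_versions

    url_by_version = {
        "2.1.0": "https://example.com/archive/pkg-2.1.0.tar.gz",
        "2.0.0": "https://example.com/archive/pkg-2.0.0.tar.gz",
    }

    version_hashes = get_checksums_for_versions(
        url_by_version,
        "pkg",
        batch=True,        # fetch every listed version without prompting
        keep_stage=False,  # clean up the temporary stage directories afterwards
        concurrency=4,     # cap the number of worker processes
    )

    for version, sha256 in version_hashes.items():
        print(f'    version("{version}", sha256="{sha256}")')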