From 7e054cb7fc55c35d50480f04552536eed48f5e21 Mon Sep 17 00:00:00 2001 From: Harmen Stoppels Date: Fri, 9 Dec 2022 08:50:32 +0100 Subject: s3: cache client instance (#34372) --- lib/spack/spack/s3_handler.py | 2 +- lib/spack/spack/test/web.py | 33 ++++++---- lib/spack/spack/util/s3.py | 148 +++++++++++++++++++++++++----------------- lib/spack/spack/util/web.py | 12 ++-- 4 files changed, 117 insertions(+), 78 deletions(-) diff --git a/lib/spack/spack/s3_handler.py b/lib/spack/spack/s3_handler.py index 93aea8b160..aee5dc8943 100644 --- a/lib/spack/spack/s3_handler.py +++ b/lib/spack/spack/s3_handler.py @@ -44,7 +44,7 @@ class WrapStream(BufferedReader): def _s3_open(url): parsed = url_util.parse(url) - s3 = s3_util.create_s3_session(parsed, connection=s3_util.get_mirror_connection(parsed)) + s3 = s3_util.get_s3_session(url, method="fetch") bucket = parsed.netloc key = parsed.path diff --git a/lib/spack/spack/test/web.py b/lib/spack/spack/test/web.py index 21c00e652c..f4114eb05c 100644 --- a/lib/spack/spack/test/web.py +++ b/lib/spack/spack/test/web.py @@ -12,6 +12,7 @@ import pytest import llnl.util.tty as tty import spack.config +import spack.mirror import spack.paths import spack.util.s3 import spack.util.web @@ -246,14 +247,24 @@ class MockS3Client(object): def test_gather_s3_information(monkeypatch, capfd): - mock_connection_data = { - "access_token": "AAAAAAA", - "profile": "SPacKDeV", - "access_pair": ("SPA", "CK"), - "endpoint_url": "https://127.0.0.1:8888", - } + mirror = spack.mirror.Mirror.from_dict( + { + "fetch": { + "access_token": "AAAAAAA", + "profile": "SPacKDeV", + "access_pair": ("SPA", "CK"), + "endpoint_url": "https://127.0.0.1:8888", + }, + "push": { + "access_token": "AAAAAAA", + "profile": "SPacKDeV", + "access_pair": ("SPA", "CK"), + "endpoint_url": "https://127.0.0.1:8888", + }, + } + ) - session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mock_connection_data) + session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mirror, "push") # Session args are used to create the S3 Session object assert "aws_session_token" in session_args @@ -273,10 +284,10 @@ def test_gather_s3_information(monkeypatch, capfd): def test_remove_s3_url(monkeypatch, capfd): fake_s3_url = "s3://my-bucket/subdirectory/mirror" - def mock_create_s3_session(url, connection={}): + def get_s3_session(url, method="fetch"): return MockS3Client() - monkeypatch.setattr(spack.util.s3, "create_s3_session", mock_create_s3_session) + monkeypatch.setattr(spack.util.s3, "get_s3_session", get_s3_session) current_debug_level = tty.debug_level() tty.set_debug(1) @@ -292,10 +303,10 @@ def test_remove_s3_url(monkeypatch, capfd): def test_s3_url_exists(monkeypatch, capfd): - def mock_create_s3_session(url, connection={}): + def get_s3_session(url, method="fetch"): return MockS3Client() - monkeypatch.setattr(spack.util.s3, "create_s3_session", mock_create_s3_session) + monkeypatch.setattr(spack.util.s3, "get_s3_session", get_s3_session) fake_s3_url_exists = "s3://my-bucket/subdirectory/my-file" assert spack.util.web.url_exists(fake_s3_url_exists) diff --git a/lib/spack/spack/util/s3.py b/lib/spack/spack/util/s3.py index 06eeab3936..462afd05ec 100644 --- a/lib/spack/spack/util/s3.py +++ b/lib/spack/spack/util/s3.py @@ -4,83 +4,115 @@ # SPDX-License-Identifier: (Apache-2.0 OR MIT) import os import urllib.parse +from typing import Any, Dict, Tuple import spack +import spack.config import spack.util.url as url_util +#: Map (mirror name, method) tuples to s3 client instances. +s3_client_cache: Dict[Tuple[str, str], Any] = dict() -def get_mirror_connection(url, url_type="push"): - connection = {} - # Try to find a mirror for potential connection information - # Check to see if desired file starts with any of the mirror URLs - rebuilt_path = url_util.format(url) - # Gather dict of push URLS point to the value of the whole mirror - mirror_dict = {x.push_url: x for x in spack.mirror.MirrorCollection().values()} - # Ensure most specific URLs (longest) are presented first - mirror_url_keys = mirror_dict.keys() - mirror_url_keys = sorted(mirror_url_keys, key=len, reverse=True) - for mURL in mirror_url_keys: - # See if desired URL starts with the mirror's push URL - if rebuilt_path.startswith(mURL): - connection = mirror_dict[mURL].to_dict()[url_type] - break - return connection +def get_s3_session(url, method="fetch"): + # import boto and friends as late as possible. We don't want to require boto as a + # dependency unless the user actually wants to access S3 mirrors. + from boto3 import Session + from botocore import UNSIGNED + from botocore.client import Config + from botocore.exceptions import ClientError -def _parse_s3_endpoint_url(endpoint_url): - if not urllib.parse.urlparse(endpoint_url, scheme="").scheme: - endpoint_url = "://".join(("https", endpoint_url)) - - return endpoint_url + # Circular dependency + from spack.mirror import MirrorCollection + global s3_client_cache -def get_mirror_s3_connection_info(connection): - s3_connection = {} - - s3_connection_is_dict = connection and isinstance(connection, dict) - if s3_connection_is_dict: - if connection.get("access_token"): - s3_connection["aws_session_token"] = connection["access_token"] - if connection.get("access_pair"): - s3_connection["aws_access_key_id"] = connection["access_pair"][0] - s3_connection["aws_secret_access_key"] = connection["access_pair"][1] - if connection.get("profile"): - s3_connection["profile_name"] = connection["profile"] - - s3_client_args = {"use_ssl": spack.config.get("config:verify_ssl")} - - endpoint_url = os.environ.get("S3_ENDPOINT_URL") - if endpoint_url: - s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(endpoint_url) - elif s3_connection_is_dict and connection.get("endpoint_url"): - s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(connection["endpoint_url"]) - - return (s3_connection, s3_client_args) - - -def create_s3_session(url, connection={}): + # Get a (recycled) s3 session for a particular URL url = url_util.parse(url) - if url.scheme != "s3": - raise ValueError( - "Can not create S3 session from URL with scheme: {SCHEME}".format(SCHEME=url.scheme) + + url_str = url_util.format(url) + + def get_mirror_url(mirror): + return mirror.fetch_url if method == "fetch" else mirror.push_url + + # Get all configured mirrors that could match. + all_mirrors = MirrorCollection() + mirrors = [ + (name, mirror) + for name, mirror in all_mirrors.items() + if url_str.startswith(get_mirror_url(mirror)) + ] + + if not mirrors: + name, mirror = None, {} + else: + # In case we have more than one mirror, we pick the longest matching url. + # The heuristic being that it's more specific, and you can have different + # credentials for a sub-bucket (if that is a thing). + name, mirror = max( + mirrors, key=lambda name_and_mirror: len(get_mirror_url(name_and_mirror[1])) ) - # NOTE(opadron): import boto and friends as late as possible. We don't - # want to require boto as a dependency unless the user actually wants to - # access S3 mirrors. - from boto3 import Session # type: ignore[import] - from botocore.exceptions import ClientError # type: ignore[import] + key = (name, method) + + # Did we already create a client for this? Then return it. + if key in s3_client_cache: + return s3_client_cache[key] - s3_connection, s3_client_args = get_mirror_s3_connection_info(connection) + # Otherwise, create it. + s3_connection, s3_client_args = get_mirror_s3_connection_info(mirror, method) session = Session(**s3_connection) # if no access credentials provided above, then access anonymously if not session.get_credentials(): - from botocore import UNSIGNED # type: ignore[import] - from botocore.client import Config # type: ignore[import] - s3_client_args["config"] = Config(signature_version=UNSIGNED) client = session.client("s3", **s3_client_args) client.ClientError = ClientError + + # Cache the client. + s3_client_cache[key] = client return client + + +def _parse_s3_endpoint_url(endpoint_url): + if not urllib.parse.urlparse(endpoint_url, scheme="").scheme: + endpoint_url = "://".join(("https", endpoint_url)) + + return endpoint_url + + +def get_mirror_s3_connection_info(mirror, method): + """Create s3 config for session/client from a Mirror instance (or just set defaults + when no mirror is given.)""" + from spack.mirror import Mirror + + s3_connection = {} + s3_client_args = {"use_ssl": spack.config.get("config:verify_ssl")} + + # access token + if isinstance(mirror, Mirror): + access_token = mirror.get_access_token(method) + if access_token: + s3_connection["aws_session_token"] = access_token + + # access pair + access_pair = mirror.get_access_pair(method) + if access_pair and access_pair[0] and access_pair[1]: + s3_connection["aws_access_key_id"] = access_pair[0] + s3_connection["aws_secret_access_key"] = access_pair[1] + + # profile + profile = mirror.get_profile(method) + if profile: + s3_connection["profile_name"] = profile + + # endpoint url + endpoint_url = mirror.get_endpoint_url(method) or os.environ.get("S3_ENDPOINT_URL") + else: + endpoint_url = os.environ.get("S3_ENDPOINT_URL") + + if endpoint_url: + s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(endpoint_url) + + return (s3_connection, s3_client_args) diff --git a/lib/spack/spack/util/web.py b/lib/spack/spack/util/web.py index 5aa63c4bb2..1f2c197460 100644 --- a/lib/spack/spack/util/web.py +++ b/lib/spack/spack/util/web.py @@ -175,9 +175,7 @@ def push_to_url(local_file_path, remote_path, keep_original=True, extra_args=Non while remote_path.startswith("/"): remote_path = remote_path[1:] - s3 = s3_util.create_s3_session( - remote_url, connection=s3_util.get_mirror_connection(remote_url) - ) + s3 = s3_util.get_s3_session(remote_url, method="push") s3.upload_file(local_file_path, remote_url.netloc, remote_path, ExtraArgs=extra_args) if not keep_original: @@ -377,9 +375,7 @@ def url_exists(url, curl=None): # Check if Amazon Simple Storage Service (S3) .. urllib-based fetch if url_result.scheme == "s3": # Check for URL-specific connection information - s3 = s3_util.create_s3_session( - url_result, connection=s3_util.get_mirror_connection(url_result) - ) # noqa: E501 + s3 = s3_util.get_s3_session(url_result, method="fetch") try: s3.get_object(Bucket=url_result.netloc, Key=url_result.path.lstrip("/")) @@ -441,7 +437,7 @@ def remove_url(url, recursive=False): if url.scheme == "s3": # Try to find a mirror for potential connection information - s3 = s3_util.create_s3_session(url, connection=s3_util.get_mirror_connection(url)) + s3 = s3_util.get_s3_session(url, method="push") bucket = url.netloc if recursive: # Because list_objects_v2 can only return up to 1000 items @@ -551,7 +547,7 @@ def list_url(url, recursive=False): ] if url.scheme == "s3": - s3 = s3_util.create_s3_session(url, connection=s3_util.get_mirror_connection(url)) + s3 = s3_util.get_s3_session(url, method="fetch") if recursive: return list(_iter_s3_prefix(s3, url)) -- cgit v1.2.3-60-g2f50