summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHarmen Stoppels <harmenstoppels@gmail.com>2022-12-09 08:50:32 +0100
committerGitHub <noreply@github.com>2022-12-09 08:50:32 +0100
commit7e054cb7fc55c35d50480f04552536eed48f5e21 (patch)
tree3a2ad1f13b11fc448e83f27d4311c89369552839
parentd29cb87ecca1c9270eefa6e77ec15b0546fa7458 (diff)
downloadspack-7e054cb7fc55c35d50480f04552536eed48f5e21.tar.gz
spack-7e054cb7fc55c35d50480f04552536eed48f5e21.tar.bz2
spack-7e054cb7fc55c35d50480f04552536eed48f5e21.tar.xz
spack-7e054cb7fc55c35d50480f04552536eed48f5e21.zip
s3: cache client instance (#34372)
-rw-r--r--lib/spack/spack/s3_handler.py2
-rw-r--r--lib/spack/spack/test/web.py33
-rw-r--r--lib/spack/spack/util/s3.py148
-rw-r--r--lib/spack/spack/util/web.py12
4 files changed, 117 insertions, 78 deletions
diff --git a/lib/spack/spack/s3_handler.py b/lib/spack/spack/s3_handler.py
index 93aea8b160..aee5dc8943 100644
--- a/lib/spack/spack/s3_handler.py
+++ b/lib/spack/spack/s3_handler.py
@@ -44,7 +44,7 @@ class WrapStream(BufferedReader):
def _s3_open(url):
parsed = url_util.parse(url)
- s3 = s3_util.create_s3_session(parsed, connection=s3_util.get_mirror_connection(parsed))
+ s3 = s3_util.get_s3_session(url, method="fetch")
bucket = parsed.netloc
key = parsed.path
diff --git a/lib/spack/spack/test/web.py b/lib/spack/spack/test/web.py
index 21c00e652c..f4114eb05c 100644
--- a/lib/spack/spack/test/web.py
+++ b/lib/spack/spack/test/web.py
@@ -12,6 +12,7 @@ import pytest
import llnl.util.tty as tty
import spack.config
+import spack.mirror
import spack.paths
import spack.util.s3
import spack.util.web
@@ -246,14 +247,24 @@ class MockS3Client(object):
def test_gather_s3_information(monkeypatch, capfd):
- mock_connection_data = {
- "access_token": "AAAAAAA",
- "profile": "SPacKDeV",
- "access_pair": ("SPA", "CK"),
- "endpoint_url": "https://127.0.0.1:8888",
- }
+ mirror = spack.mirror.Mirror.from_dict(
+ {
+ "fetch": {
+ "access_token": "AAAAAAA",
+ "profile": "SPacKDeV",
+ "access_pair": ("SPA", "CK"),
+ "endpoint_url": "https://127.0.0.1:8888",
+ },
+ "push": {
+ "access_token": "AAAAAAA",
+ "profile": "SPacKDeV",
+ "access_pair": ("SPA", "CK"),
+ "endpoint_url": "https://127.0.0.1:8888",
+ },
+ }
+ )
- session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mock_connection_data)
+ session_args, client_args = spack.util.s3.get_mirror_s3_connection_info(mirror, "push")
# Session args are used to create the S3 Session object
assert "aws_session_token" in session_args
@@ -273,10 +284,10 @@ def test_gather_s3_information(monkeypatch, capfd):
def test_remove_s3_url(monkeypatch, capfd):
fake_s3_url = "s3://my-bucket/subdirectory/mirror"
- def mock_create_s3_session(url, connection={}):
+ def get_s3_session(url, method="fetch"):
return MockS3Client()
- monkeypatch.setattr(spack.util.s3, "create_s3_session", mock_create_s3_session)
+ monkeypatch.setattr(spack.util.s3, "get_s3_session", get_s3_session)
current_debug_level = tty.debug_level()
tty.set_debug(1)
@@ -292,10 +303,10 @@ def test_remove_s3_url(monkeypatch, capfd):
def test_s3_url_exists(monkeypatch, capfd):
- def mock_create_s3_session(url, connection={}):
+ def get_s3_session(url, method="fetch"):
return MockS3Client()
- monkeypatch.setattr(spack.util.s3, "create_s3_session", mock_create_s3_session)
+ monkeypatch.setattr(spack.util.s3, "get_s3_session", get_s3_session)
fake_s3_url_exists = "s3://my-bucket/subdirectory/my-file"
assert spack.util.web.url_exists(fake_s3_url_exists)
diff --git a/lib/spack/spack/util/s3.py b/lib/spack/spack/util/s3.py
index 06eeab3936..462afd05ec 100644
--- a/lib/spack/spack/util/s3.py
+++ b/lib/spack/spack/util/s3.py
@@ -4,83 +4,115 @@
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import os
import urllib.parse
+from typing import Any, Dict, Tuple
import spack
+import spack.config
import spack.util.url as url_util
+#: Map (mirror name, method) tuples to s3 client instances.
+s3_client_cache: Dict[Tuple[str, str], Any] = dict()
-def get_mirror_connection(url, url_type="push"):
- connection = {}
- # Try to find a mirror for potential connection information
- # Check to see if desired file starts with any of the mirror URLs
- rebuilt_path = url_util.format(url)
- # Gather dict of push URLS point to the value of the whole mirror
- mirror_dict = {x.push_url: x for x in spack.mirror.MirrorCollection().values()}
- # Ensure most specific URLs (longest) are presented first
- mirror_url_keys = mirror_dict.keys()
- mirror_url_keys = sorted(mirror_url_keys, key=len, reverse=True)
- for mURL in mirror_url_keys:
- # See if desired URL starts with the mirror's push URL
- if rebuilt_path.startswith(mURL):
- connection = mirror_dict[mURL].to_dict()[url_type]
- break
- return connection
+def get_s3_session(url, method="fetch"):
+ # import boto and friends as late as possible. We don't want to require boto as a
+ # dependency unless the user actually wants to access S3 mirrors.
+ from boto3 import Session
+ from botocore import UNSIGNED
+ from botocore.client import Config
+ from botocore.exceptions import ClientError
-def _parse_s3_endpoint_url(endpoint_url):
- if not urllib.parse.urlparse(endpoint_url, scheme="").scheme:
- endpoint_url = "://".join(("https", endpoint_url))
-
- return endpoint_url
+ # Circular dependency
+ from spack.mirror import MirrorCollection
+ global s3_client_cache
-def get_mirror_s3_connection_info(connection):
- s3_connection = {}
-
- s3_connection_is_dict = connection and isinstance(connection, dict)
- if s3_connection_is_dict:
- if connection.get("access_token"):
- s3_connection["aws_session_token"] = connection["access_token"]
- if connection.get("access_pair"):
- s3_connection["aws_access_key_id"] = connection["access_pair"][0]
- s3_connection["aws_secret_access_key"] = connection["access_pair"][1]
- if connection.get("profile"):
- s3_connection["profile_name"] = connection["profile"]
-
- s3_client_args = {"use_ssl": spack.config.get("config:verify_ssl")}
-
- endpoint_url = os.environ.get("S3_ENDPOINT_URL")
- if endpoint_url:
- s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(endpoint_url)
- elif s3_connection_is_dict and connection.get("endpoint_url"):
- s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(connection["endpoint_url"])
-
- return (s3_connection, s3_client_args)
-
-
-def create_s3_session(url, connection={}):
+ # Get a (recycled) s3 session for a particular URL
url = url_util.parse(url)
- if url.scheme != "s3":
- raise ValueError(
- "Can not create S3 session from URL with scheme: {SCHEME}".format(SCHEME=url.scheme)
+
+ url_str = url_util.format(url)
+
+ def get_mirror_url(mirror):
+ return mirror.fetch_url if method == "fetch" else mirror.push_url
+
+ # Get all configured mirrors that could match.
+ all_mirrors = MirrorCollection()
+ mirrors = [
+ (name, mirror)
+ for name, mirror in all_mirrors.items()
+ if url_str.startswith(get_mirror_url(mirror))
+ ]
+
+ if not mirrors:
+ name, mirror = None, {}
+ else:
+ # In case we have more than one mirror, we pick the longest matching url.
+ # The heuristic being that it's more specific, and you can have different
+ # credentials for a sub-bucket (if that is a thing).
+ name, mirror = max(
+ mirrors, key=lambda name_and_mirror: len(get_mirror_url(name_and_mirror[1]))
)
- # NOTE(opadron): import boto and friends as late as possible. We don't
- # want to require boto as a dependency unless the user actually wants to
- # access S3 mirrors.
- from boto3 import Session # type: ignore[import]
- from botocore.exceptions import ClientError # type: ignore[import]
+ key = (name, method)
+
+ # Did we already create a client for this? Then return it.
+ if key in s3_client_cache:
+ return s3_client_cache[key]
- s3_connection, s3_client_args = get_mirror_s3_connection_info(connection)
+ # Otherwise, create it.
+ s3_connection, s3_client_args = get_mirror_s3_connection_info(mirror, method)
session = Session(**s3_connection)
# if no access credentials provided above, then access anonymously
if not session.get_credentials():
- from botocore import UNSIGNED # type: ignore[import]
- from botocore.client import Config # type: ignore[import]
-
s3_client_args["config"] = Config(signature_version=UNSIGNED)
client = session.client("s3", **s3_client_args)
client.ClientError = ClientError
+
+ # Cache the client.
+ s3_client_cache[key] = client
return client
+
+
+def _parse_s3_endpoint_url(endpoint_url):
+ if not urllib.parse.urlparse(endpoint_url, scheme="").scheme:
+ endpoint_url = "://".join(("https", endpoint_url))
+
+ return endpoint_url
+
+
+def get_mirror_s3_connection_info(mirror, method):
+ """Create s3 config for session/client from a Mirror instance (or just set defaults
+ when no mirror is given.)"""
+ from spack.mirror import Mirror
+
+ s3_connection = {}
+ s3_client_args = {"use_ssl": spack.config.get("config:verify_ssl")}
+
+ # access token
+ if isinstance(mirror, Mirror):
+ access_token = mirror.get_access_token(method)
+ if access_token:
+ s3_connection["aws_session_token"] = access_token
+
+ # access pair
+ access_pair = mirror.get_access_pair(method)
+ if access_pair and access_pair[0] and access_pair[1]:
+ s3_connection["aws_access_key_id"] = access_pair[0]
+ s3_connection["aws_secret_access_key"] = access_pair[1]
+
+ # profile
+ profile = mirror.get_profile(method)
+ if profile:
+ s3_connection["profile_name"] = profile
+
+ # endpoint url
+ endpoint_url = mirror.get_endpoint_url(method) or os.environ.get("S3_ENDPOINT_URL")
+ else:
+ endpoint_url = os.environ.get("S3_ENDPOINT_URL")
+
+ if endpoint_url:
+ s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(endpoint_url)
+
+ return (s3_connection, s3_client_args)
diff --git a/lib/spack/spack/util/web.py b/lib/spack/spack/util/web.py
index 5aa63c4bb2..1f2c197460 100644
--- a/lib/spack/spack/util/web.py
+++ b/lib/spack/spack/util/web.py
@@ -175,9 +175,7 @@ def push_to_url(local_file_path, remote_path, keep_original=True, extra_args=Non
while remote_path.startswith("/"):
remote_path = remote_path[1:]
- s3 = s3_util.create_s3_session(
- remote_url, connection=s3_util.get_mirror_connection(remote_url)
- )
+ s3 = s3_util.get_s3_session(remote_url, method="push")
s3.upload_file(local_file_path, remote_url.netloc, remote_path, ExtraArgs=extra_args)
if not keep_original:
@@ -377,9 +375,7 @@ def url_exists(url, curl=None):
# Check if Amazon Simple Storage Service (S3) .. urllib-based fetch
if url_result.scheme == "s3":
# Check for URL-specific connection information
- s3 = s3_util.create_s3_session(
- url_result, connection=s3_util.get_mirror_connection(url_result)
- ) # noqa: E501
+ s3 = s3_util.get_s3_session(url_result, method="fetch")
try:
s3.get_object(Bucket=url_result.netloc, Key=url_result.path.lstrip("/"))
@@ -441,7 +437,7 @@ def remove_url(url, recursive=False):
if url.scheme == "s3":
# Try to find a mirror for potential connection information
- s3 = s3_util.create_s3_session(url, connection=s3_util.get_mirror_connection(url))
+ s3 = s3_util.get_s3_session(url, method="push")
bucket = url.netloc
if recursive:
# Because list_objects_v2 can only return up to 1000 items
@@ -551,7 +547,7 @@ def list_url(url, recursive=False):
]
if url.scheme == "s3":
- s3 = s3_util.create_s3_session(url, connection=s3_util.get_mirror_connection(url))
+ s3 = s3_util.get_s3_session(url, method="fetch")
if recursive:
return list(_iter_s3_prefix(s3, url))