summaryrefslogtreecommitdiff
path: root/lib/spack/spack/util/s3.py
blob: 5457abdca584537b5db73c93051426bea95b13bc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import os
import urllib.error
import urllib.parse
import urllib.request
import urllib.response
from io import BufferedReader, BytesIO, IOBase
from typing import Any, Dict, Tuple

import spack.config

#: Map (mirror name, method) tuples to s3 client instances.
s3_client_cache: Dict[Tuple[str, str], Any] = dict()


def get_s3_session(url, method="fetch"):
    # import boto and friends as late as possible.  We don't want to require boto as a
    # dependency unless the user actually wants to access S3 mirrors.
    from boto3 import Session
    from botocore import UNSIGNED
    from botocore.client import Config
    from botocore.exceptions import ClientError

    # Circular dependency
    from spack.mirror import MirrorCollection

    global s3_client_cache

    # Parse the URL if not already done.
    if not isinstance(url, urllib.parse.ParseResult):
        url = urllib.parse.urlparse(url)
    url_str = url.geturl()

    def get_mirror_url(mirror):
        return mirror.fetch_url if method == "fetch" else mirror.push_url

    # Get all configured mirrors that could match.
    all_mirrors = MirrorCollection()
    mirrors = [
        (name, mirror)
        for name, mirror in all_mirrors.items()
        if url_str.startswith(get_mirror_url(mirror))
    ]

    if not mirrors:
        name, mirror = None, {}
    else:
        # In case we have more than one mirror, we pick the longest matching url.
        # The heuristic being that it's more specific, and you can have different
        # credentials for a sub-bucket (if that is a thing).
        name, mirror = max(
            mirrors, key=lambda name_and_mirror: len(get_mirror_url(name_and_mirror[1]))
        )

    key = (name, method)

    # Did we already create a client for this? Then return it.
    if key in s3_client_cache:
        return s3_client_cache[key]

    # Otherwise, create it.
    s3_connection, s3_client_args = get_mirror_s3_connection_info(mirror, method)

    session = Session(**s3_connection)
    # if no access credentials provided above, then access anonymously
    if not session.get_credentials():
        s3_client_args["config"] = Config(signature_version=UNSIGNED)

    client = session.client("s3", **s3_client_args)
    client.ClientError = ClientError

    # Cache the client.
    s3_client_cache[key] = client
    return client


def _parse_s3_endpoint_url(endpoint_url):
    if not urllib.parse.urlparse(endpoint_url, scheme="").scheme:
        endpoint_url = "://".join(("https", endpoint_url))

    return endpoint_url


def get_mirror_s3_connection_info(mirror, method):
    """Create s3 config for session/client from a Mirror instance (or just set defaults
    when no mirror is given.)"""
    from spack.mirror import Mirror

    s3_connection = {}
    s3_client_args = {"use_ssl": spack.config.get("config:verify_ssl")}

    # access token
    if isinstance(mirror, Mirror):
        access_token = mirror.get_access_token(method)
        if access_token:
            s3_connection["aws_session_token"] = access_token

        # access pair
        access_pair = mirror.get_access_pair(method)
        if access_pair and access_pair[0] and access_pair[1]:
            s3_connection["aws_access_key_id"] = access_pair[0]
            s3_connection["aws_secret_access_key"] = access_pair[1]

        # profile
        profile = mirror.get_profile(method)
        if profile:
            s3_connection["profile_name"] = profile

        # endpoint url
        endpoint_url = mirror.get_endpoint_url(method) or os.environ.get("S3_ENDPOINT_URL")
    else:
        endpoint_url = os.environ.get("S3_ENDPOINT_URL")

    if endpoint_url:
        s3_client_args["endpoint_url"] = _parse_s3_endpoint_url(endpoint_url)

    return s3_connection, s3_client_args


# NOTE(opadron): Workaround issue in boto where its StreamingBody
# implementation is missing several APIs expected from IOBase.  These missing
# APIs prevent the streams returned by boto from being passed as-are along to
# urllib.
#
# https://github.com/boto/botocore/issues/879
# https://github.com/python/cpython/pull/3249
class WrapStream(BufferedReader):
    def __init__(self, raw):
        # In botocore >=1.23.47, StreamingBody inherits from IOBase, so we
        # only add missing attributes in older versions.
        # https://github.com/boto/botocore/commit/a624815eabac50442ed7404f3c4f2664cd0aa784
        if not isinstance(raw, IOBase):
            raw.readable = lambda: True
            raw.writable = lambda: False
            raw.seekable = lambda: False
            raw.closed = False
            raw.flush = lambda: None
        super().__init__(raw)

    def detach(self):
        self.raw = None

    def read(self, *args, **kwargs):
        return self.raw.read(*args, **kwargs)

    def __getattr__(self, key):
        return getattr(self.raw, key)


def _s3_open(url, method="GET"):
    parsed = urllib.parse.urlparse(url)
    s3 = get_s3_session(url, method="fetch")

    bucket = parsed.netloc
    key = parsed.path

    if key.startswith("/"):
        key = key[1:]

    if method not in ("GET", "HEAD"):
        raise urllib.error.URLError(
            "Only GET and HEAD verbs are currently supported for the s3:// scheme"
        )

    try:
        if method == "GET":
            obj = s3.get_object(Bucket=bucket, Key=key)
            # NOTE(opadron): Apply workaround here (see above)
            stream = WrapStream(obj["Body"])
        elif method == "HEAD":
            obj = s3.head_object(Bucket=bucket, Key=key)
            stream = BytesIO()
    except s3.ClientError as e:
        raise urllib.error.URLError(e) from e

    headers = obj["ResponseMetadata"]["HTTPHeaders"]

    return url, headers, stream


class UrllibS3Handler(urllib.request.BaseHandler):
    def s3_open(self, req):
        orig_url = req.get_full_url()
        url, headers, stream = _s3_open(orig_url, method=req.get_method())
        return urllib.response.addinfourl(stream, headers, url)