summaryrefslogtreecommitdiff
path: root/lib/spack/spack/patch.py
blob: cbab403f20a01205b5fce024f3a13e0b2cdd12d7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
# Copyright 2013-2018 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

import os
import os.path
import inspect
import hashlib

import spack.error
import spack.fetch_strategy as fs
import spack.stage
from spack.util.crypto import checksum, Checker
from llnl.util.filesystem import working_dir
from spack.util.executable import which
from spack.util.compression import allowed_archive


def absolute_path_for_package(pkg):
    """Returns the absolute path to the ``package.py`` file implementing
    the recipe for the package passed as argument.

    Args:
        pkg: a valid package object, or a Dependency object.

    Returns:
        str: absolute path of the file of the module that defines ``pkg``
    """
    # This module does not import spack.dependency at top level; importing
    # it here means the isinstance check below no longer depends on some
    # other module having loaded spack.dependency first. The import must
    # come before any use of the name ``spack`` in this function, since it
    # binds ``spack`` as a local name.
    import spack.dependency

    if isinstance(pkg, spack.dependency.Dependency):
        # Unwrap the dependency to the package object it refers to
        pkg = pkg.pkg
    m = inspect.getmodule(pkg)
    return os.path.abspath(m.__file__)


class Patch(object):
    """Base class to describe a patch that needs to be applied to some
    expanded source code.

    Derived classes must set ``self.path`` to the location of the patch
    file before :meth:`apply` is called.
    """

    @staticmethod
    def create(pkg, path_or_url, level=1, working_dir=".", **kwargs):
        """
        Factory method that creates an instance of some class derived from
        Patch

        Args:
            pkg: package that needs to be patched
            path_or_url: path or url where the patch is found
            level: patch level (default 1)
            working_dir (str): dir to change to before applying (default '.')
            **kwargs: extra keyword arguments forwarded to ``UrlPatch``
                (e.g. ``sha256``, ``archive_sha256``); not passed on for
                file patches

        Returns:
            instance of some Patch class
        """
        # Anything containing a scheme separator is treated as a URL
        if '://' in path_or_url:
            return UrlPatch(path_or_url, level, working_dir, **kwargs)
        # Otherwise assume the patch is stored in the package repository
        return FilePatch(pkg, path_or_url, level, working_dir)

    def __init__(self, path_or_url, level, working_dir):
        """Store the attributes shared by all patch subclasses.

        Args:
            path_or_url: path or url where the patch is found
            level: patch level passed to ``patch -p``
            working_dir (str): dir, relative to the stage source path, to
                change to before applying

        Raises:
            ValueError: if ``level`` is not a non-negative integer
        """
        # The patch level must be an integer >= 0 (0 is a valid level)
        if not isinstance(level, int) or level < 0:
            raise ValueError("Patch level needs to be a non-negative integer.")
        # Attributes shared by all patch subclasses
        self.path_or_url = path_or_url
        self.level = level
        self.working_dir = working_dir
        # self.path needs to be computed by derived classes
        # before a call to apply
        self.path = None

    def apply(self, stage):
        """Apply the patch at self.path to the source code in the
        supplied stage

        Args:
            stage: stage for the package that needs to be patched
        """
        patch = which("patch", required=True)
        with working_dir(stage.source_path):
            # -s: run silently; -p: number of leading path components to
            # strip; -i: read the patch from self.path; -d: change into
            # self.working_dir before applying.
            patch('-s', '-p', str(self.level), '-i', self.path,
                  "-d", self.working_dir)


class FilePatch(Patch):
    """A patch kept as a file next to a package recipe in the repository.

    The patch location is resolved relative to the directory holding the
    package's ``package.py``. Its sha256 checksum is computed on first
    access and then cached.
    """
    def __init__(self, pkg, path_or_url, level, working_dir):
        super(FilePatch, self).__init__(path_or_url, level, working_dir)

        recipe_dir = os.path.dirname(absolute_path_for_package(pkg))
        self.path = os.path.join(recipe_dir, path_or_url)
        if os.path.isfile(self.path):
            # Filled in lazily by the ``sha256`` property
            self._sha256 = None
            return
        raise NoSuchPatchError(
            "No such patch for package %s: %s" % (pkg.name, self.path))

    @property
    def sha256(self):
        """sha256 checksum of the patch file, computed once and cached."""
        if self._sha256 is not None:
            return self._sha256
        self._sha256 = checksum(hashlib.sha256, self.path)
        return self._sha256


class UrlPatch(Patch):
    """Describes a patch that is retrieved from a URL

    A plain patch requires a ``sha256`` checksum of the patch file. A
    compressed patch (a URL accepted by ``allowed_archive``) additionally
    requires ``archive_sha256``, the checksum of the downloaded archive.
    """
    def __init__(self, path_or_url, level, working_dir, **kwargs):
        """
        Args:
            path_or_url: url where the patch is found
            level: patch level passed to ``patch -p``
            working_dir (str): dir to change to before applying
            **kwargs: must contain ``sha256``; must also contain
                ``archive_sha256`` when the url is a compressed archive

        Raises:
            PatchDirectiveError: if a required checksum is missing
        """
        super(UrlPatch, self).__init__(path_or_url, level, working_dir)
        self.url = path_or_url

        self.archive_sha256 = None
        if allowed_archive(self.url):
            if 'archive_sha256' not in kwargs:
                raise PatchDirectiveError(
                    "Compressed patches require 'archive_sha256' "
                    "and patch 'sha256' attributes: %s" % self.url)
            self.archive_sha256 = kwargs['archive_sha256']

        if 'sha256' not in kwargs:
            raise PatchDirectiveError("URL patches require a sha256 checksum")
        self.sha256 = kwargs['sha256']

    def apply(self, stage):
        """Retrieve the patch in a temporary stage, compute
        ``self.path`` and call ``super().apply(stage)``.

        Args:
            stage: stage for the package that needs to be patched

        Raises:
            NoSuchPatchError: if the download or archive yields no patch file
            fs.ChecksumError: for compressed patches, if the extracted patch
                does not match ``self.sha256`` (checked only when the
                ``config:checksum`` setting is enabled)
        """
        # spack.config is used below but not imported at module level; this
        # import must precede every use of ``spack`` in this method because
        # it binds ``spack`` as a local name.
        import spack.config

        # use archive digest for compressed archives
        fetch_digest = self.archive_sha256 or self.sha256

        fetcher = fs.URLFetchStrategy(self.url, fetch_digest)
        mirror = os.path.join(
            os.path.dirname(stage.mirror_path),
            os.path.basename(self.url))

        with spack.stage.Stage(fetcher, mirror_path=mirror) as patch_stage:
            patch_stage.fetch()
            patch_stage.check()
            patch_stage.cache_local()

            root = patch_stage.path
            if self.archive_sha256:
                patch_stage.expand_archive()
                root = patch_stage.source_path

            files = os.listdir(root)
            if not files:
                if self.archive_sha256:
                    raise NoSuchPatchError(
                        "Archive was empty: %s" % self.url)
                else:
                    raise NoSuchPatchError(
                        "Patch failed to download: %s" % self.url)

            # NOTE(review): if ``root`` holds more than one entry, which one
            # is used depends on os.listdir order, which is arbitrary.
            self.path = os.path.join(root, files.pop())

            if not os.path.isfile(self.path):
                raise NoSuchPatchError(
                    "Archive %s contains no patch file!" % self.url)

            # for a compressed archive, Need to check the patch sha256 again
            # and the patch is in a directory, not in the same place
            if self.archive_sha256 and spack.config.get('config:checksum'):
                checker = Checker(self.sha256)
                if not checker.check(self.path):
                    raise fs.ChecksumError(
                        "sha256 checksum failed for %s" % self.path,
                        "Expected %s but got %s" % (self.sha256, checker.sum))

            super(UrlPatch, self).apply(stage)


class NoSuchPatchError(spack.error.SpackError):
    """Raised when no usable patch file can be found: the file is missing
    from the repository, a download produced nothing, or a downloaded
    archive held no patch file."""


class PatchDirectiveError(spack.error.SpackError):
    """Raised when the wrong arguments are supplied to the patch directive,
    e.g. when a required checksum is missing."""