summaryrefslogtreecommitdiff
path: root/var/spack/repos/builtin/packages/py-jaxlib/package.py
blob: fbd3a13422009d3b7295fa5b3f218a254f89d789 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

import tempfile

from spack.package import *


class PyJaxlib(PythonPackage, CudaPackage):
    """XLA library for Jax"""

    homepage = "https://github.com/google/jax"
    url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz"

    # Per-build temporary directories; assigned in patch() and removed at the
    # end of install().
    tmp_path = ""
    buildtmp = ""

    license("Apache-2.0")

    version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
    version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d")
    version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35")

    variant("cuda", default=True, description="Build with CUDA")

    # jaxlib/setup.py
    depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
    depends_on("py-setuptools", type="build")
    depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
    depends_on("py-numpy@1.18:", type=("build", "run"))
    depends_on("py-scipy@1.5:", type=("build", "run"))

    # .bazelversion
    depends_on("bazel@5.1.1:5.9", when="@0.3:", type="build")
    # https://github.com/google/jax/issues/8440
    depends_on("bazel@4.1:4", when="@0.1", type="build")

    # README.md
    depends_on("cuda@11.4:", when="@0.4:+cuda")
    depends_on("cuda@11.1:", when="@0.3+cuda")
    # https://github.com/google/jax/issues/12614
    depends_on("cuda@11.1:11.7.0", when="@0.1+cuda")
    depends_on("cudnn@8.2:", when="@0.4:+cuda")
    depends_on("cudnn@8.0.5:", when="+cuda")

    # Historical dependencies
    depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
    depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run"))

    conflicts(
        "cuda_arch=none",
        when="+cuda",
        msg="Must specify CUDA compute capabilities of your GPU, see "
        "https://developer.nvidia.com/cuda-gpus",
    )

    def patch(self):
        """Create the build's temp directories and adapt jax's build scripts.

        - ``.bazelrc``: add ``--local_cpu_resources={make_jobs}`` next to the
          standalone spawn strategy so Bazel does not use every host CPU.
        - ``build/build.py``: append ``--sources_path=<self.tmp_path>``,
          ``--nohome_rc``, ``--nosystem_rc`` and ``--jobs={make_jobs}`` after
          the ``--output_path`` argument it already passes.
        - ``build/build_wheel.py``: switch to ``parse_known_args`` so unknown
          arguments are ignored instead of rejected (presumably because the
          extras appended in build.py reach this script — confirm upstream).

        NOTE(review): ``make_jobs`` is not defined in this file; presumably it
        is injected into the package module by Spack's build environment.
        """
        self.tmp_path = tempfile.mkdtemp(prefix="spack")
        self.buildtmp = tempfile.mkdtemp(prefix="spack")
        filter_file(
            "build --spawn_strategy=standalone",
            f"""
# Limit CPU workers to spack jobs instead of using all HOST_CPUS.
build --spawn_strategy=standalone
build --local_cpu_resources={make_jobs}
""".strip(),
            ".bazelrc",
            string=True,
        )
        filter_file(
            # The first fragment is a plain string: ``{output_path}`` must stay
            # literal because it is an f-string placeholder inside build.py.
            'f"--output_path={output_path}",',
            'f"--output_path={output_path}",'
            f' "--sources_path={self.tmp_path}",'
            ' "--nohome_rc",'
            ' "--nosystem_rc",'
            f' "--jobs={make_jobs}",',
            "build/build.py",
            string=True,
        )
        filter_file(
            "args = parser.parse_args()",
            "args, junk = parser.parse_known_args()",
            "build/build_wheel.py",
            string=True,
        )

    def install(self, spec, prefix):
        """Run jax's ``build/build.py`` and pip-install the result into ``prefix``.

        With ``+cuda``, the CUDA/cuDNN prefixes and the requested compute
        capabilities are forwarded to build.py. The temporary directories
        created in :meth:`patch` are removed whether or not the build succeeds;
        they live outside the Spack stage, so nothing else cleans them up.
        """
        # NOTE(review): ``self`` here appears to be a wrapper around the package
        # object, so the temp paths assigned in patch() are read back through
        # ``wrapped_package_object`` — confirm against Spack's builder docs.
        pkg = self.wrapped_package_object

        args = ["build/build.py"]
        if "+cuda" in spec:
            # cuda_arch values such as "80" are passed to build.py as "8.0".
            capabilities = ",".join(
                "{0:.1f}".format(float(i) / 10.0) for i in spec.variants["cuda_arch"].value
            )
            args.extend(
                [
                    "--enable_cuda",
                    "--cuda_path={0}".format(spec["cuda"].prefix),
                    "--cudnn_path={0}".format(spec["cudnn"].prefix),
                    "--cuda_compute_capabilities={0}".format(capabilities),
                ]
            )
        # Put Bazel's output root in the per-build temp dir created by patch().
        args.append("--bazel_startup_options=--output_user_root={0}".format(pkg.buildtmp))

        try:
            python(*args)
            # patch() makes build.py pass --sources_path=<tmp_path>; that
            # directory is then installed as the package source.
            with working_dir(pkg.tmp_path):
                pip(*(std_pip_args + ["--prefix=" + prefix, "."]))
        finally:
            # Previously these were only removed on success, leaking the
            # (potentially very large) Bazel output root on any build failure.
            remove_linked_tree(pkg.tmp_path)
            remove_linked_tree(pkg.buildtmp)