diff options
author | Auriane R <48684432+aurianer@users.noreply.github.com> | 2024-05-07 23:56:34 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-07 14:56:34 -0700 |
commit | 84ed4cd331b15c83ea3473de028a9a53068397d4 (patch) | |
tree | c704708f5c89cad2f127d9464fc67baa8b064e20 /var | |
parent | f6d50f790ee8b123f7775429f6ca6394170e6de9 (diff) | |
download | spack-84ed4cd331b15c83ea3473de028a9a53068397d4.tar.gz spack-84ed4cd331b15c83ea3473de028a9a53068397d4.tar.bz2 spack-84ed4cd331b15c83ea3473de028a9a53068397d4.tar.xz spack-84ed4cd331b15c83ea3473de028a9a53068397d4.zip |
Add transformer engine package (#43982)
* Add py-flash-attn@2.4.2
* Add py-transfomer-engine package
---------
Co-authored-by: Tamara Dahlgren <35777542+tldahlgren@users.noreply.github.com>
Diffstat (limited to 'var')
-rw-r--r-- | var/spack/repos/builtin/packages/py-transformer-engine/package.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-transformer-engine/package.py b/var/spack/repos/builtin/packages/py-transformer-engine/package.py new file mode 100644 index 0000000000..a09e4c1f40 --- /dev/null +++ b/var/spack/repos/builtin/packages/py-transformer-engine/package.py @@ -0,0 +1,48 @@ +# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other +# Spack Project Developers. See the top-level COPYRIGHT file for details. +# +# SPDX-License-Identifier: (Apache-2.0 OR MIT) + +from spack.package import * + + +class PyTransformerEngine(PythonPackage): + """ + A library for accelerating Transformer models on NVIDIA GPUs, including fp8 precision on Hopper + GPUs. + """ + + homepage = "https://github.com/NVIDIA/TransformerEngine" + url = "https://github.com/NVIDIA/TransformerEngine/archive/refs/tags/v0.0.tar.gz" + git = "https://github.com/NVIDIA/TransformerEngine.git" + maintainers("aurianer") + + license("Apache-2.0") + + version("1.4", tag="v1.4", submodules=True) + version("main", branch="main", submodules=True) + + variant("userbuffers", default=True, description="Enable userbuffers, this option needs MPI.") + + depends_on("py-setuptools", type="build") + depends_on("cmake@3.18:") + depends_on("py-pydantic") + depends_on("py-importlib-metadata") + + with default_args(type=("build", "run")): + depends_on("py-accelerate") + depends_on("py-datasets") + depends_on("py-flash-attn@2.2:2.4.2") + depends_on("py-packaging") + depends_on("py-torchvision") + depends_on("py-transformers") + depends_on("mpi", when="+userbuffers") + + with default_args(type=("build", "link", "run")): + depends_on("py-torch+cuda+cudnn") + + def setup_build_environment(self, env): + env.set("NVTE_FRAMEWORK", "pytorch") + if self.spec.satisfies("+userbuffers"): + env.set("NVTE_WITH_USERBUFFERS", "1") + env.set("MPI_HOME", self.spec["mpi"].prefix) |