diff options
-rw-r--r-- | var/spack/repos/builtin/packages/py-transformer-engine/package.py | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-transformer-engine/package.py b/var/spack/repos/builtin/packages/py-transformer-engine/package.py new file mode 100644 index 0000000000..a09e4c1f40 --- /dev/null +++ b/var/spack/repos/builtin/packages/py-transformer-engine/package.py @@ -0,0 +1,48 @@ +# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other +# Spack Project Developers. See the top-level COPYRIGHT file for details. +# +# SPDX-License-Identifier: (Apache-2.0 OR MIT) + +from spack.package import * + + +class PyTransformerEngine(PythonPackage): + """ + A library for accelerating Transformer models on NVIDIA GPUs, including fp8 precision on Hopper + GPUs. + """ + + homepage = "https://github.com/NVIDIA/TransformerEngine" + url = "https://github.com/NVIDIA/TransformerEngine/archive/refs/tags/v0.0.tar.gz" + git = "https://github.com/NVIDIA/TransformerEngine.git" + maintainers("aurianer") + + license("Apache-2.0") + + version("1.4", tag="v1.4", submodules=True) + version("main", branch="main", submodules=True) + + variant("userbuffers", default=True, description="Enable userbuffers, this option needs MPI.") + + depends_on("py-setuptools", type="build") + depends_on("cmake@3.18:") + depends_on("py-pydantic") + depends_on("py-importlib-metadata") + + with default_args(type=("build", "run")): + depends_on("py-accelerate") + depends_on("py-datasets") + depends_on("py-flash-attn@2.2:2.4.2") + depends_on("py-packaging") + depends_on("py-torchvision") + depends_on("py-transformers") + depends_on("mpi", when="+userbuffers") + + with default_args(type=("build", "link", "run")): + depends_on("py-torch+cuda+cudnn") + + def setup_build_environment(self, env): + env.set("NVTE_FRAMEWORK", "pytorch") + if self.spec.satisfies("+userbuffers"): + env.set("NVTE_WITH_USERBUFFERS", "1") + env.set("MPI_HOME", self.spec["mpi"].prefix) |