summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
authorAuriane R <48684432+aurianer@users.noreply.github.com>2024-05-07 23:56:34 +0200
committerGitHub <noreply@github.com>2024-05-07 14:56:34 -0700
commit84ed4cd331b15c83ea3473de028a9a53068397d4 (patch)
treec704708f5c89cad2f127d9464fc67baa8b064e20 /var
parentf6d50f790ee8b123f7775429f6ca6394170e6de9 (diff)
downloadspack-84ed4cd331b15c83ea3473de028a9a53068397d4.tar.gz
spack-84ed4cd331b15c83ea3473de028a9a53068397d4.tar.bz2
spack-84ed4cd331b15c83ea3473de028a9a53068397d4.tar.xz
spack-84ed4cd331b15c83ea3473de028a9a53068397d4.zip
Add transformer engine package (#43982)
* Add py-flash-attn@2.4.2 * Add py-transfomer-engine package --------- Co-authored-by: Tamara Dahlgren <35777542+tldahlgren@users.noreply.github.com>
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-transformer-engine/package.py48
1 files changed, 48 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-transformer-engine/package.py b/var/spack/repos/builtin/packages/py-transformer-engine/package.py
new file mode 100644
index 0000000000..a09e4c1f40
--- /dev/null
+++ b/var/spack/repos/builtin/packages/py-transformer-engine/package.py
@@ -0,0 +1,48 @@
+# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+from spack.package import *
+
+
+class PyTransformerEngine(PythonPackage):
+ """
+ A library for accelerating Transformer models on NVIDIA GPUs, including fp8 precision on Hopper
+ GPUs.
+ """
+
+ homepage = "https://github.com/NVIDIA/TransformerEngine"
+ url = "https://github.com/NVIDIA/TransformerEngine/archive/refs/tags/v0.0.tar.gz"
+ git = "https://github.com/NVIDIA/TransformerEngine.git"
+ maintainers("aurianer")
+
+ license("Apache-2.0")
+
+ version("1.4", tag="v1.4", submodules=True)
+ version("main", branch="main", submodules=True)
+
+ variant("userbuffers", default=True, description="Enable userbuffers, this option needs MPI.")
+
+ depends_on("py-setuptools", type="build")
+ depends_on("cmake@3.18:")
+ depends_on("py-pydantic")
+ depends_on("py-importlib-metadata")
+
+ with default_args(type=("build", "run")):
+ depends_on("py-accelerate")
+ depends_on("py-datasets")
+ depends_on("py-flash-attn@2.2:2.4.2")
+ depends_on("py-packaging")
+ depends_on("py-torchvision")
+ depends_on("py-transformers")
+ depends_on("mpi", when="+userbuffers")
+
+ with default_args(type=("build", "link", "run")):
+ depends_on("py-torch+cuda+cudnn")
+
+ def setup_build_environment(self, env):
+ env.set("NVTE_FRAMEWORK", "pytorch")
+ if self.spec.satisfies("+userbuffers"):
+ env.set("NVTE_WITH_USERBUFFERS", "1")
+ env.set("MPI_HOME", self.spec["mpi"].prefix)