summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSreenivasa Murthy Kolam <67086238+srekolam@users.noreply.github.com>2022-08-19 18:50:38 -0700
committerGitHub <noreply@github.com>2022-08-20 01:50:38 +0000
commit11a4f5e25df1ae137a13437d79d2eefdaf11bbfe (patch)
treec424b728f65bc5dd823571b74010536e1a4217a1
parent5590cad1ef588e6368e1cd301f0b3d81107b54e8 (diff)
downloadspack-11a4f5e25df1ae137a13437d79d2eefdaf11bbfe.tar.gz
spack-11a4f5e25df1ae137a13437d79d2eefdaf11bbfe.tar.bz2
spack-11a4f5e25df1ae137a13437d79d2eefdaf11bbfe.tar.xz
spack-11a4f5e25df1ae137a13437d79d2eefdaf11bbfe.zip
Enable Tensorflow for ROCm. Add ROCm dependencies. (#32248)
* Build Tensorflow using the fork for rocm. Initial commit * re-order the versions * fix style errors * address review comments * add conflicts for rocm version * address review comments * remove rocm variant as its added by ROCmPackage
-rw-r--r--var/spack/repos/builtin/packages/py-tensorflow/package.py34
1 files changed, 31 insertions, 3 deletions
diff --git a/var/spack/repos/builtin/packages/py-tensorflow/package.py b/var/spack/repos/builtin/packages/py-tensorflow/package.py
index 6cce91916e..1c8eb63a1e 100644
--- a/var/spack/repos/builtin/packages/py-tensorflow/package.py
+++ b/var/spack/repos/builtin/packages/py-tensorflow/package.py
@@ -10,7 +10,7 @@ from spack.operating_systems.mac_os import macos_version
from spack.package import *
-class PyTensorflow(Package, CudaPackage):
+class PyTensorflow(Package, CudaPackage, ROCmPackage):
"""An Open Source Machine Learning Framework for Everyone.
TensorFlow is an end-to-end open source platform for machine learning. It has a
@@ -35,6 +35,11 @@ class PyTensorflow(Package, CudaPackage):
version("2.8.2", sha256="b3f860c02c22a30e9787e2548ca252ab289a76b7778af6e9fa763d4aafd904c7")
version("2.8.1", sha256="4b487a63d6f0c1ca46a2ac37ba4687eabdc3a260c222616fa414f6df73228cec")
version("2.8.0", sha256="66b953ae7fba61fd78969a2e24e350b26ec116cf2e6a7eb93d02c63939c6f9f7")
+ version(
+ "2.7.4-rocm-enhanced",
+ sha256="45b79c125edfdc008274f1b150d8b5a53b3ff4713fd1ad1ff4738f515aad8191",
+ url="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/archive/refs/tags/v2.7.4-rocm-enhanced.tar.gz",
+ )
version("2.7.3", sha256="b576c2e124cd6d4d04cbfe985430a0d955614e882172b2258217f0ec9b61f39b")
version("2.7.2", sha256="b3c8577f3b7cc82368ff7f9315821d506abd2f716ea6692977d255b7d8bc54c0")
version("2.7.1", sha256="abebe2cf5ca379e18071693ca5f45b88ade941b16258a21cc1f12d77d5387a21")
@@ -128,7 +133,6 @@ class PyTensorflow(Package, CudaPackage):
variant("ngraph", default=False, description="Build with Intel nGraph support")
variant("opencl", default=False, description="Build with OpenCL SYCL support")
variant("computecpp", default=False, description="Build with ComputeCPP support")
- variant("rocm", default=False, description="Build with ROCm support")
variant("tensorrt", default=False, description="Build with TensorRT support")
variant("cuda", default=sys.platform != "darwin", description="Build with CUDA support")
variant(
@@ -279,6 +283,21 @@ class PyTensorflow(Package, CudaPackage):
# type=('build', 'run'), when='@2.8:')
# depends_on('py-tensorflow-io-gcs-filesystem@0.21:',
# type=('build', 'run'), when='@2.7')
+ with when("+rocm"):
+ depends_on("hip")
+ depends_on("rocrand")
+ depends_on("rocblas")
+ depends_on("rocfft")
+ depends_on("hipfft")
+ depends_on("rccl")
+ depends_on("hipsparse")
+ depends_on("hipcub")
+ depends_on("rocsolver")
+ depends_on("rocprim")
+ depends_on("miopen-hip")
+ depends_on("llvm-amdgpu")
+ depends_on("hsa-rocr-dev")
+ depends_on("rocminfo")
if sys.byteorder == "little":
# Only builds correctly on little-endian machines
@@ -357,7 +376,6 @@ class PyTensorflow(Package, CudaPackage):
conflicts("+opencl", when="@:0.11")
conflicts("+computecpp", when="@:0.11")
conflicts("+computecpp", when="~opencl")
- conflicts("+rocm", when="@:1.11")
conflicts("+cuda", when="platform=darwin", msg="There is no GPU support for macOS")
conflicts(
"cuda_arch=none",
@@ -416,6 +434,8 @@ class PyTensorflow(Package, CudaPackage):
conflicts("platform=darwin target=aarch64:", when="@:2.4")
# https://github.com/tensorflow/tensorflow/pull/39225
conflicts("target=aarch64:", when="@:2.2")
+ conflicts("~rocm", when="@2.7.4-rocm-enhanced")
+ conflicts("+rocm", when="@:2.7.4-a,2.7.4.0:")
# TODO: why is this needed?
patch("url-zlib.patch", when="@0.10.0")
@@ -720,6 +740,11 @@ class PyTensorflow(Package, CudaPackage):
env.set("INCLUDEDIR", spec["protobuf"].prefix.include)
def patch(self):
+ filter_file(
+ '"-U_FORTIFY_SOURCE",',
+ '"-U_FORTIFY_SOURCE", "-I%s",' % self.spec["protobuf"].prefix.include,
+ "third_party/gpus/crosstool/BUILD.rocm.tpl",
+ )
if self.spec.satisfies("@2.3.0:"):
filter_file(
"deps = protodeps + well_known_proto_libs(),",
@@ -976,6 +1001,9 @@ def protobuf_deps():
if "+cuda" in spec:
args.append("--config=cuda")
+ if "+rocm" in spec:
+ args.append("--config=rocm")
+
if "~aws" in spec:
args.append("--config=noaws")