From 11a4f5e25df1ae137a13437d79d2eefdaf11bbfe Mon Sep 17 00:00:00 2001 From: Sreenivasa Murthy Kolam <67086238+srekolam@users.noreply.github.com> Date: Fri, 19 Aug 2022 18:50:38 -0700 Subject: Enable Tensorflow for ROCm. Add ROCm dependencies. (#32248) * Build Tensorflow using the fork for rocm. Initial commit * re-order the versions * fix style errors * address review comments * add conflicts for rocm version * address review comments * remove rocm variant as its added by ROCmPackage --- .../builtin/packages/py-tensorflow/package.py | 34 ++++++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) (limited to 'var') diff --git a/var/spack/repos/builtin/packages/py-tensorflow/package.py b/var/spack/repos/builtin/packages/py-tensorflow/package.py index 6cce91916e..1c8eb63a1e 100644 --- a/var/spack/repos/builtin/packages/py-tensorflow/package.py +++ b/var/spack/repos/builtin/packages/py-tensorflow/package.py @@ -10,7 +10,7 @@ from spack.operating_systems.mac_os import macos_version from spack.package import * -class PyTensorflow(Package, CudaPackage): +class PyTensorflow(Package, CudaPackage, ROCmPackage): """An Open Source Machine Learning Framework for Everyone. TensorFlow is an end-to-end open source platform for machine learning. It has a @@ -35,6 +35,11 @@ class PyTensorflow(Package, CudaPackage): version("2.8.2", sha256="b3f860c02c22a30e9787e2548ca252ab289a76b7778af6e9fa763d4aafd904c7") version("2.8.1", sha256="4b487a63d6f0c1ca46a2ac37ba4687eabdc3a260c222616fa414f6df73228cec") version("2.8.0", sha256="66b953ae7fba61fd78969a2e24e350b26ec116cf2e6a7eb93d02c63939c6f9f7") + version( + "2.7.4-rocm-enhanced", + sha256="45b79c125edfdc008274f1b150d8b5a53b3ff4713fd1ad1ff4738f515aad8191", + url="https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/archive/refs/tags/v2.7.4-rocm-enhanced.tar.gz", + ) version("2.7.3", sha256="b576c2e124cd6d4d04cbfe985430a0d955614e882172b2258217f0ec9b61f39b") version("2.7.2", sha256="b3c8577f3b7cc82368ff7f9315821d506abd2f716ea6692977d255b7d8bc54c0") version("2.7.1", sha256="abebe2cf5ca379e18071693ca5f45b88ade941b16258a21cc1f12d77d5387a21") @@ -128,7 +133,6 @@ class PyTensorflow(Package, CudaPackage): variant("ngraph", default=False, description="Build with Intel nGraph support") variant("opencl", default=False, description="Build with OpenCL SYCL support") variant("computecpp", default=False, description="Build with ComputeCPP support") - variant("rocm", default=False, description="Build with ROCm support") variant("tensorrt", default=False, description="Build with TensorRT support") variant("cuda", default=sys.platform != "darwin", description="Build with CUDA support") variant( @@ -279,6 +283,21 @@ class PyTensorflow(Package, CudaPackage): # type=('build', 'run'), when='@2.8:') # depends_on('py-tensorflow-io-gcs-filesystem@0.21:', # type=('build', 'run'), when='@2.7') + with when("+rocm"): + depends_on("hip") + depends_on("rocrand") + depends_on("rocblas") + depends_on("rocfft") + depends_on("hipfft") + depends_on("rccl") + depends_on("hipsparse") + depends_on("hipcub") + depends_on("rocsolver") + depends_on("rocprim") + depends_on("miopen-hip") + depends_on("llvm-amdgpu") + depends_on("hsa-rocr-dev") + depends_on("rocminfo") if sys.byteorder == "little": # Only builds correctly on little-endian machines @@ -357,7 +376,6 @@ class PyTensorflow(Package, CudaPackage): conflicts("+opencl", when="@:0.11") conflicts("+computecpp", when="@:0.11") conflicts("+computecpp", when="~opencl") - conflicts("+rocm", when="@:1.11") conflicts("+cuda", when="platform=darwin", msg="There is no GPU support for macOS") conflicts( "cuda_arch=none", @@ -416,6 +434,8 @@ class PyTensorflow(Package, CudaPackage): conflicts("platform=darwin target=aarch64:", when="@:2.4") # https://github.com/tensorflow/tensorflow/pull/39225 conflicts("target=aarch64:", when="@:2.2") + conflicts("~rocm", when="@2.7.4-rocm-enhanced") + conflicts("+rocm", when="@:2.7.4-a,2.7.4.0:") # TODO: why is this needed? patch("url-zlib.patch", when="@0.10.0") @@ -720,6 +740,11 @@ class PyTensorflow(Package, CudaPackage): env.set("INCLUDEDIR", spec["protobuf"].prefix.include) def patch(self): + filter_file( + '"-U_FORTIFY_SOURCE",', + '"-U_FORTIFY_SOURCE", "-I%s",' % self.spec["protobuf"].prefix.include, + "third_party/gpus/crosstool/BUILD.rocm.tpl", + ) if self.spec.satisfies("@2.3.0:"): filter_file( "deps = protodeps + well_known_proto_libs(),", @@ -976,6 +1001,9 @@ def protobuf_deps(): if "+cuda" in spec: args.append("--config=cuda") + if "+rocm" in spec: + args.append("--config=rocm") + if "~aws" in spec: args.append("--config=noaws") -- cgit v1.2.3-70-g09d2