From 525809632ef3f198c182c9350db2f7c5d3afd832 Mon Sep 17 00:00:00 2001
From: Harmen Stoppels <me@harmenstoppels.nl>
Date: Mon, 11 Dec 2023 10:30:14 +0100
Subject: petsc: improve hipsparse compat (#40311)

Co-authored-by: Satish Balay <balay@mcs.anl.gov>
---
 .../builtin/packages/petsc/hip-5.6.0-for-3.18.diff | 85 ++++++++++++++++++++++
 .../packages/petsc/hip-5.7-plus-for-3.18.diff      | 20 +++++
 var/spack/repos/builtin/packages/petsc/package.py  | 37 ++++++----
 3 files changed, 127 insertions(+), 15 deletions(-)
 create mode 100644 var/spack/repos/builtin/packages/petsc/hip-5.6.0-for-3.18.diff
 create mode 100644 var/spack/repos/builtin/packages/petsc/hip-5.7-plus-for-3.18.diff

diff --git a/var/spack/repos/builtin/packages/petsc/hip-5.6.0-for-3.18.diff b/var/spack/repos/builtin/packages/petsc/hip-5.6.0-for-3.18.diff
new file mode 100644
index 0000000000..c587e3451d
--- /dev/null
+++ b/var/spack/repos/builtin/packages/petsc/hip-5.6.0-for-3.18.diff
@@ -0,0 +1,85 @@
+commit 9b52b1224039b470f0f450943ce503af1df37b00
+Author: Satish Balay <balay@mcs.anl.gov>
+Date:   Fri Oct 6 15:19:34 2023 -0500
+
+    hip-6.0 fix
+
+diff --git a/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cpp b/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cpp
+index e6be2076975..0c388c90ca3 100644
+--- a/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cpp
++++ b/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cpp
+@@ -1259,14 +1259,22 @@ static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x)
+   /* Solve L*y = b */
+   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
+   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
++  #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
++  PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L,                   /* L Y = X */
++                                         fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L)); // hipsparseSpSV_solve() secretly uses the external buffer used in hipsparseSpSV_analysis()!
++  #else
+   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L,                                     /* L Y = X */
+                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()!
+-
++  #endif
+   /* Solve U*x = y */
+   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
++  #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
+   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */
++                                         fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U));
++  #else
++  PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U,                                     /* U X = Y */
+                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U));
+-
++  #endif
+   PetscCall(VecHIPRestoreArrayRead(b, &barray));
+   PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
+ 
+@@ -1309,14 +1317,22 @@ static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Ve
+   /* Solve Ut*y = b */
+   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
+   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
++  #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
++  PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
++                                         fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut));
++  #else
+   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
+                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut));
+-
++  #endif
+   /* Solve Lt*x = y */
+   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
++  #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
++  PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
++                                         fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
++  #else
+   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
+                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
+-
++  #endif
+   PetscCall(VecHIPRestoreArrayRead(b, &barray));
+   PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
+   PetscCall(PetscLogGpuTimeEnd());
+@@ -1544,14 +1560,22 @@ static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ICC0(Mat fact, Vec b, Vec x)
+   /* Solve L*y = b */
+   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
+   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
++  #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
++  PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
++                                         fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L));
++  #else
+   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
+                                          fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
+-
++  #endif
+   /* Solve Lt*x = y */
+   PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
++  #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
++  PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
++                                         fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
++  #else
+   PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
+                                          fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
+-
++  #endif
+   PetscCall(VecHIPRestoreArrayRead(b, &barray));
+   PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
+ 
diff --git a/var/spack/repos/builtin/packages/petsc/hip-5.7-plus-for-3.18.diff b/var/spack/repos/builtin/packages/petsc/hip-5.7-plus-for-3.18.diff
new file mode 100644
index 0000000000..22555963b3
--- /dev/null
+++ b/var/spack/repos/builtin/packages/petsc/hip-5.7-plus-for-3.18.diff
@@ -0,0 +1,20 @@
+diff --git a/src/vec/is/sf/impls/basic/hip/sfhip.hip.cpp b/src/vec/is/sf/impls/basic/hip/sfhip.hip.cpp
+index a39933c6893..6ef9f513bd6 100644
+--- a/src/vec/is/sf/impls/basic/hip/sfhip.hip.cpp
++++ b/src/vec/is/sf/impls/basic/hip/sfhip.hip.cpp
+@@ -471,6 +471,7 @@ __device__ static float atomicMax(float *address, float val)
+ #endif
+ 
+ /* As of ROCm 3.10 llint atomicMin/Max(llint*, llint) is not supported */
++#if PETSC_PKG_HIP_VERSION_LT(5, 7, 0)
+ __device__ static llint atomicMin(llint *address, llint val)
+ {
+   ullint *address_as_ull = (ullint *)(address);
+@@ -492,6 +493,7 @@ __device__ static llint atomicMax(llint *address, llint val)
+   } while (assumed != old);
+   return (llint)old;
+ }
++#endif
+ 
+ template <typename Type>
+ struct AtomicMin {
diff --git a/var/spack/repos/builtin/packages/petsc/package.py b/var/spack/repos/builtin/packages/petsc/package.py
index 403e2c2cb4..51d65f229c 100644
--- a/var/spack/repos/builtin/packages/petsc/package.py
+++ b/var/spack/repos/builtin/packages/petsc/package.py
@@ -161,12 +161,17 @@ class Petsc(Package, CudaPackage, ROCmPackage):
     variant("kokkos", default=False, description="Activates support for kokkos and kokkos-kernels")
     variant("fortran", default=True, description="Activates fortran support")
 
-    # https://github.com/spack/spack/issues/37416
-    conflicts("^rocprim@5.3.0:5.3.2", when="+rocm")
-    # petsc 3.20 has workaround for breaking change in hipsparseSpSV_solve api,
-    # but it seems to misdetect hipsparse@5.6.1 as 5.6.0, so the workaround
-    # only makes things worse
-    conflicts("^hipsparse@5.6", when="+rocm @3.20.0")
+    with when("+rocm"):
+        # https://github.com/spack/spack/issues/37416
+        conflicts("^rocprim@5.3.0:5.3.2")
+        # hipsparse@5.6.0 broke the hipsparseSpSV_solve() API; the change was reverted in 5.6.1.
+        patch(
+            "https://gitlab.com/petsc/petsc/-/commit/ef7140cce45367033b48bbd2624dfd2b6aa4b997.diff",
+            when="@3.20.0",
+            sha256="ba327f8b2a0fa45209dfb7a4278f3e9a323965b5a668be204c1c77c17a963a7f",
+        )
+        patch("hip-5.6.0-for-3.18.diff", when="@3.18:3.19 ^hipsparse@5.6.0")
+        patch("hip-5.7-plus-for-3.18.diff", when="@3.18:3.19 ^hipsparse@5.7:")
 
     # 3.8.0 has a build issue with MKL - so list this conflict explicitly
     conflicts("^intel-mkl", when="@3.8.0")
@@ -225,15 +230,17 @@ class Petsc(Package, CudaPackage, ROCmPackage):
     depends_on("mpi", when="+mpi")
     depends_on("cuda", when="+cuda")
     depends_on("hip", when="+rocm")
-    depends_on("hipblas", when="+rocm")
-    depends_on("hipsparse", when="+rocm")
-    depends_on("hipsolver", when="+rocm")
-    depends_on("rocsparse", when="+rocm")
-    depends_on("rocsolver", when="+rocm")
-    depends_on("rocblas", when="+rocm")
-    depends_on("rocrand", when="+rocm")
-    depends_on("rocthrust", when="+rocm")
-    depends_on("rocprim", when="+rocm")
+
+    with when("+rocm"):
+        depends_on("hipblas")
+        depends_on("hipsparse")
+        depends_on("hipsolver")
+        depends_on("rocsparse")
+        depends_on("rocsolver")
+        depends_on("rocblas")
+        depends_on("rocrand")
+        depends_on("rocthrust")
+        depends_on("rocprim")
 
     # Build dependencies
     depends_on("python@2.6:2.8,3.4:3.8", when="@:3.13", type="build")
-- 
cgit v1.2.3-70-g09d2