summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--var/spack/repos/builtin/packages/amrex/package.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/var/spack/repos/builtin/packages/amrex/package.py b/var/spack/repos/builtin/packages/amrex/package.py
index d3414e575e..8c077e7921 100644
--- a/var/spack/repos/builtin/packages/amrex/package.py
+++ b/var/spack/repos/builtin/packages/amrex/package.py
@@ -6,7 +6,7 @@
from spack import *
-class Amrex(CMakePackage, CudaPackage):
+class Amrex(CMakePackage, CudaPackage, ROCmPackage):
"""AMReX is a publicly available software framework designed
for building massively parallel block- structured adaptive
mesh refinement (AMR) applications."""
@@ -81,6 +81,7 @@ class Amrex(CMakePackage, CudaPackage):
depends_on('cmake@3.14:', type='build', when='@19.04:')
# cmake @3.17: is necessary to handle cuda @11: correctly
depends_on('cmake@3.17:', type='build', when='^cuda @11:')
+ depends_on('rocrand', type='build', when='+rocm')
conflicts('%apple-clang')
conflicts('%clang')
@@ -113,6 +114,8 @@ class Amrex(CMakePackage, CudaPackage):
conflicts('cuda_arch=21', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5')
conflicts('cuda_arch=30', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5')
conflicts('cuda_arch=32', when='+cuda', msg='AMReX only supports compute capabilities >= 3.5')
+ conflicts('+rocm', when='@:20.11', msg='AMReX HIP support needs AMReX newer than version 20.11')
+ conflicts('+cuda', when='+rocm', msg='CUDA and HIP support are exclusive')
def url_for_version(self, version):
if version >= Version('20.05'):
@@ -200,4 +203,10 @@ class Amrex(CMakePackage, CudaPackage):
cuda_arch = self.spec.variants['cuda_arch'].value
args.append('-DCUDA_ARCH=' + self.get_cuda_arch_string(cuda_arch))
+ if '+rocm' in self.spec:
+ args.append('-DCMAKE_CXX_COMPILER={0}'.format(self.spec['hip'].hipcc))
+ args.append('-DAMReX_GPU_BACKEND=HIP')
+ targets = self.spec.variants['amdgpu_target'].value
+ args.append('-DAMReX_AMD_ARCH=' + ';'.join(str(x) for x in targets))
+
return args