diff options
Diffstat (limited to 'var/spack/repos/builtin/packages/py-torch/package.py')
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index 1ffeb81b23..bebf83d3d6 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: (Apache-2.0 OR MIT) from spack import * +import os class PyTorch(PythonPackage, CudaPackage): @@ -68,6 +69,7 @@ class PyTorch(PythonPackage, CudaPackage): variant('cuda', default=True, description='Build with CUDA') variant('cudnn', default=True, description='Enables the cuDNN build') + variant('rocm', default=False, description='Build with ROCm build') variant('magma', default=False, description='Enables the MAGMA build') variant('fbgemm', default=False, description='Enables the FBGEMM build') variant('test', default=False, description='Enables the test build') @@ -112,6 +114,7 @@ class PyTorch(PythonPackage, CudaPackage): conflicts('cuda_arch=none', when='+cuda', msg='Must specify CUDA compute capabilities of your GPU, see ' 'https://developer.nvidia.com/cuda-gpus') + conflicts('+rocm', when='+cuda') # Required dependencies depends_on('cmake@3.5:', type='build') @@ -173,6 +176,9 @@ class PyTorch(PythonPackage, CudaPackage): # Fixes CMake configuration error when XNNPACK is disabled patch('xnnpack.patch', when='@1.5.0:1.5.999') + # Fixes Build error for when ROCm is enable for pytorch-1.5 release + patch('rocm.patch', when='@1.5.0:1.5.999+rocm') + # https://github.com/pytorch/pytorch/pull/37086 # Fixes compilation with Clang 9.0.0 and Apple Clang 11.0.3 patch('https://github.com/pytorch/pytorch/commit/e921cd222a8fbeabf5a3e74e83e0d8dfb01aa8b5.patch', @@ -244,7 +250,9 @@ class PyTorch(PythonPackage, CudaPackage): enable_or_disable('fbgemm') enable_or_disable('test', keyword='BUILD') - + enable_or_disable('rocm') + if '+rocm' in self.spec: + env.set('USE_MKLDNN', 0) if '+miopen' in self.spec: env.set('MIOPEN_LIB_DIR', self.spec['miopen'].libs.directories[0]) env.set('MIOPEN_INCLUDE_DIR', self.spec['miopen'].prefix.include) @@ -297,6 +305,11 @@ class PyTorch(PythonPackage, CudaPackage): enable_or_disable('zstd', newer=True) enable_or_disable('tbb', newer=True) + @run_before('install') + def build_amd(self): + if '+rocm' in self.spec: + python(os.path.join('tools', 'amd_build', 'build_amd.py')) + def install_test(self): with working_dir('test'): python('run_test.py') |