summaryrefslogtreecommitdiff
path: root/var/spack/repos/builtin/packages/py-torch/package.py
diff options
context:
space:
mode:
Diffstat (limited to 'var/spack/repos/builtin/packages/py-torch/package.py')
-rw-r--r--var/spack/repos/builtin/packages/py-torch/package.py15
1 files changed, 14 insertions, 1 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py
index 1ffeb81b23..bebf83d3d6 100644
--- a/var/spack/repos/builtin/packages/py-torch/package.py
+++ b/var/spack/repos/builtin/packages/py-torch/package.py
@@ -4,6 +4,7 @@
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from spack import *
+import os
class PyTorch(PythonPackage, CudaPackage):
@@ -68,6 +69,7 @@ class PyTorch(PythonPackage, CudaPackage):
variant('cuda', default=True, description='Build with CUDA')
variant('cudnn', default=True, description='Enables the cuDNN build')
+ variant('rocm', default=False, description='Build with ROCm build')
variant('magma', default=False, description='Enables the MAGMA build')
variant('fbgemm', default=False, description='Enables the FBGEMM build')
variant('test', default=False, description='Enables the test build')
@@ -112,6 +114,7 @@ class PyTorch(PythonPackage, CudaPackage):
conflicts('cuda_arch=none', when='+cuda',
msg='Must specify CUDA compute capabilities of your GPU, see '
'https://developer.nvidia.com/cuda-gpus')
+ conflicts('+rocm', when='+cuda')
# Required dependencies
depends_on('cmake@3.5:', type='build')
@@ -173,6 +176,9 @@ class PyTorch(PythonPackage, CudaPackage):
# Fixes CMake configuration error when XNNPACK is disabled
patch('xnnpack.patch', when='@1.5.0:1.5.999')
+ # Fixes Build error for when ROCm is enable for pytorch-1.5 release
+ patch('rocm.patch', when='@1.5.0:1.5.999+rocm')
+
# https://github.com/pytorch/pytorch/pull/37086
# Fixes compilation with Clang 9.0.0 and Apple Clang 11.0.3
patch('https://github.com/pytorch/pytorch/commit/e921cd222a8fbeabf5a3e74e83e0d8dfb01aa8b5.patch',
@@ -244,7 +250,9 @@ class PyTorch(PythonPackage, CudaPackage):
enable_or_disable('fbgemm')
enable_or_disable('test', keyword='BUILD')
-
+ enable_or_disable('rocm')
+ if '+rocm' in self.spec:
+ env.set('USE_MKLDNN', 0)
if '+miopen' in self.spec:
env.set('MIOPEN_LIB_DIR', self.spec['miopen'].libs.directories[0])
env.set('MIOPEN_INCLUDE_DIR', self.spec['miopen'].prefix.include)
@@ -297,6 +305,11 @@ class PyTorch(PythonPackage, CudaPackage):
enable_or_disable('zstd', newer=True)
enable_or_disable('tbb', newer=True)
+ @run_before('install')
+ def build_amd(self):
+ if '+rocm' in self.spec:
+ python(os.path.join('tools', 'amd_build', 'build_amd.py'))
+
def install_test(self):
with working_dir('test'):
python('run_test.py')