diff options
author | Adam J. Stewart <ajstewart426@gmail.com> | 2022-06-27 09:21:49 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-27 18:21:49 +0200 |
commit | a6b0de3beb49afbf87b2b81dc4e601ee3a64ecc4 (patch) | |
tree | 82bb96e2e8874e5166dcdd7b79980cd0d565d1a7 | |
parent | 11d71ca85e1724f2bac1e9d7d366e3b8097add0d (diff) | |
download | spack-a6b0de3beb49afbf87b2b81dc4e601ee3a64ecc4.tar.gz spack-a6b0de3beb49afbf87b2b81dc4e601ee3a64ecc4.tar.bz2 spack-a6b0de3beb49afbf87b2b81dc4e601ee3a64ecc4.tar.xz spack-a6b0de3beb49afbf87b2b81dc4e601ee3a64ecc4.zip |
py-torch: add M1 GPU support (#31283)
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index 9c7abe41e7..c4883b95eb 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -6,6 +6,7 @@ import os import sys +from spack.operating_systems.mac_os import macos_version from spack.package import * @@ -59,6 +60,7 @@ class PyTorch(PythonPackage, CudaPackage): variant('kineto', default=True, description='Use Kineto profiling library', when='@1.8:') variant('magma', default=not is_darwin, description='Use MAGMA', when='+cuda') variant('metal', default=is_darwin, description='Use Metal for Caffe2 iOS build') + variant('mps', default=is_darwin and macos_version() >= Version('12.3'), description='Use MPS for macOS build', when='@1.12: platform=darwin') variant('nccl', default=True, description='Use NCCL', when='+cuda platform=linux') variant('nccl', default=True, description='Use NCCL', when='+cuda platform=cray') variant('nccl', default=True, description='Use NCCL', when='+rocm platform=linux') @@ -373,6 +375,7 @@ class PyTorch(PythonPackage, CudaPackage): enable_or_disable('kineto') enable_or_disable('magma') enable_or_disable('metal') + enable_or_disable('mps') enable_or_disable('breakpad') enable_or_disable('nccl') |