diff options
-rw-r--r-- | var/spack/repos/builtin/packages/py-torch/package.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch/package.py b/var/spack/repos/builtin/packages/py-torch/package.py index 9c7abe41e7..c4883b95eb 100644 --- a/var/spack/repos/builtin/packages/py-torch/package.py +++ b/var/spack/repos/builtin/packages/py-torch/package.py @@ -6,6 +6,7 @@ import os import sys +from spack.operating_systems.mac_os import macos_version from spack.package import * @@ -59,6 +60,7 @@ class PyTorch(PythonPackage, CudaPackage): variant('kineto', default=True, description='Use Kineto profiling library', when='@1.8:') variant('magma', default=not is_darwin, description='Use MAGMA', when='+cuda') variant('metal', default=is_darwin, description='Use Metal for Caffe2 iOS build') + variant('mps', default=is_darwin and macos_version() >= Version('12.3'), description='Use MPS for macOS build', when='@1.12: platform=darwin') variant('nccl', default=True, description='Use NCCL', when='+cuda platform=linux') variant('nccl', default=True, description='Use NCCL', when='+cuda platform=cray') variant('nccl', default=True, description='Use NCCL', when='+rocm platform=linux') @@ -373,6 +375,7 @@ class PyTorch(PythonPackage, CudaPackage): enable_or_disable('kineto') enable_or_disable('magma') enable_or_disable('metal') + enable_or_disable('mps') enable_or_disable('breakpad') enable_or_disable('nccl') |