summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py
index 45f9b279cb..00a357603f 100644
--- a/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py
+++ b/var/spack/repos/builtin/packages/py-torch-nvidia-apex/package.py
@@ -21,6 +21,7 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage):
depends_on("python@3:", type=("build", "run"))
depends_on("py-setuptools", type="build")
+ depends_on("py-packaging", type="build")
depends_on("py-torch@0.4:", type=("build", "run"))
depends_on("cuda@9:", when="+cuda")
depends_on("py-pybind11", type=("build", "link", "run"))
@@ -43,6 +44,7 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage):
else:
env.unset("CUDA_HOME")
+ @when("^python@:3.10")
def global_options(self, spec, prefix):
args = []
if spec.satisfies("^py-torch@1.0:"):
@@ -50,3 +52,11 @@ class PyTorchNvidiaApex(PythonPackage, CudaPackage):
if "+cuda" in spec:
args.append("--cuda_ext")
return args
+
+ @when("^python@3.11:")
+ def config_settings(self, spec, prefix):
+ return {
+ "builddir": "build",
+ "compile-args": f"-j{make_jobs}",
+ "--global-option": "--cpp_ext --cuda_ext",
+ }