summaryrefslogtreecommitdiff
path: root/var
diff options
context:
space:
mode:
authora-saitoh-fj <63334055+a-saitoh-fj@users.noreply.github.com>2021-03-12 20:40:25 +0900
committerGitHub <noreply@github.com>2021-03-12 11:40:25 +0000
commitc6440eb23cc95781e7a4e96036bc0d6769697591 (patch)
tree1632a6ffb7460df057275ecd8930331aec6f686b /var
parent40147d99553e2b7267a96c21499a0b803f58a4bd (diff)
downloadspack-c6440eb23cc95781e7a4e96036bc0d6769697591.tar.gz
spack-c6440eb23cc95781e7a4e96036bc0d6769697591.tar.bz2
spack-c6440eb23cc95781e7a4e96036bc0d6769697591.tar.xz
spack-c6440eb23cc95781e7a4e96036bc0d6769697591.zip
py-chainer: Add test method for ChainerMN (continued #21848, #21940) (#22189)
* py-chainer: Add test method for ChainerMN (continued #21848, #21940) * py-chainer: Fixed the word in the message * py-chainer: Delete unnecessary imports * py-chainer: Incorporation of the measures pointed out in #21940 was insufficient.
Diffstat (limited to 'var')
-rw-r--r--var/spack/repos/builtin/packages/py-chainer/package.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/var/spack/repos/builtin/packages/py-chainer/package.py b/var/spack/repos/builtin/packages/py-chainer/package.py
index 5f640cb3e5..7b2d7011b2 100644
--- a/var/spack/repos/builtin/packages/py-chainer/package.py
+++ b/var/spack/repos/builtin/packages/py-chainer/package.py
@@ -4,6 +4,7 @@
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from spack import *
+import json
class PyChainer(PythonPackage):
@@ -25,6 +26,8 @@ class PyChainer(PythonPackage):
version('7.2.0', sha256='6e2fba648cc5b8a5421e494385b76fe5ec154f1028a1c5908557f5d16c04f0b3')
version('6.7.0', sha256='87cb3378a35e7c5c695028ec91d58dc062356bc91412384ea939d71374610389')
+ variant("mn", default=False, description="run with ChainerMN")
+
depends_on('python@3.5.1:', when='@7:', type=('build', 'run'))
depends_on('py-setuptools', type=('build', 'run'))
depends_on('py-numpy@1.9:', type=('build', 'run'))
@@ -34,3 +37,49 @@ class PyChainer(PythonPackage):
depends_on('py-filelock', type=('build', 'run'))
depends_on('py-protobuf@3:', type=('build', 'run'))
depends_on('py-typing@:3.6.6', when='@:6', type=('build', 'run'))
+
+ # Dependencies only required for test of ChainerMN
+ depends_on('py-matplotlib', type=('build', 'run'), when='+mn')
+ depends_on('py-mpi4py', type=('build', 'run'), when='+mn')
+ depends_on("mpi", type=("build", "run"), when='+mn')
+
+ @run_after('install')
+ def cache_test_sources(self):
+ if '+mn' in self.spec:
+ self.cache_extra_test_sources("examples")
+
+ def test(self):
+ if "+mn" in self.spec:
+ # Run test of ChainerMN
+ test_dir = self.test_suite.current_test_data_dir
+
+ mnist_dir = join_path(
+ self.install_test_root, "examples", "chainermn", "mnist"
+ )
+ mnist_file = join_path(mnist_dir, "train_mnist.py")
+ mpi_name = self.spec["mpi"].prefix.bin.mpirun
+ python_exe = self.spec["python"].command.path
+ opts = [
+ "-n",
+ "4",
+ python_exe,
+ mnist_file,
+ "-o",
+ test_dir,
+ ]
+ env["OMP_NUM_THREADS"] = "4"
+
+ self.run_test(
+ mpi_name,
+ options=opts,
+ work_dir=test_dir,
+ )
+
+ # check results
+ json_open = open(join_path(test_dir, 'log'), 'r')
+ json_load = json.load(json_open)
+ v = dict([(d.get('epoch'), d.get('main/accuracy')) for d in json_load])
+ if 1 not in v or 20 not in v:
+ raise RuntimeError('Cannot find epoch 1 or epoch 20')
+ if abs(1.0 - v[1]) < abs(1.0 - v[20]):
+ raise RuntimeError('ChainerMN Test Failed !')