From c6440eb23cc95781e7a4e96036bc0d6769697591 Mon Sep 17 00:00:00 2001 From: a-saitoh-fj <63334055+a-saitoh-fj@users.noreply.github.com> Date: Fri, 12 Mar 2021 20:40:25 +0900 Subject: py-chainer: Add test method for ChainerMN (continued #21848, #21940) (#22189) * py-chainer: Add test method for ChainerMN (continued #21848, #21940) * py-chainer: Fixed the word in the message * py-chainer: Delete unnecessary imports * py-chainer: Incorporation of the measures pointed out in #21940 was insufficient. --- .../repos/builtin/packages/py-chainer/package.py | 49 ++++++++++++++++++++++ 1 file changed, 49 insertions(+) (limited to 'var') diff --git a/var/spack/repos/builtin/packages/py-chainer/package.py b/var/spack/repos/builtin/packages/py-chainer/package.py index 5f640cb3e5..7b2d7011b2 100644 --- a/var/spack/repos/builtin/packages/py-chainer/package.py +++ b/var/spack/repos/builtin/packages/py-chainer/package.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: (Apache-2.0 OR MIT) from spack import * +import json class PyChainer(PythonPackage): @@ -25,6 +26,8 @@ class PyChainer(PythonPackage): version('7.2.0', sha256='6e2fba648cc5b8a5421e494385b76fe5ec154f1028a1c5908557f5d16c04f0b3') version('6.7.0', sha256='87cb3378a35e7c5c695028ec91d58dc062356bc91412384ea939d71374610389') + variant("mn", default=False, description="run with ChainerMN") + depends_on('python@3.5.1:', when='@7:', type=('build', 'run')) depends_on('py-setuptools', type=('build', 'run')) depends_on('py-numpy@1.9:', type=('build', 'run')) @@ -34,3 +37,49 @@ class PyChainer(PythonPackage): depends_on('py-filelock', type=('build', 'run')) depends_on('py-protobuf@3:', type=('build', 'run')) depends_on('py-typing@:3.6.6', when='@:6', type=('build', 'run')) + + # Dependencies only required for test of ChainerMN + depends_on('py-matplotlib', type=('build', 'run'), when='+mn') + depends_on('py-mpi4py', type=('build', 'run'), when='+mn') + depends_on("mpi", type=("build", "run"), when='+mn') + + @run_after('install') + def cache_test_sources(self): + if '+mn' in self.spec: + self.cache_extra_test_sources("examples") + + def test(self): + if "+mn" in self.spec: + # Run test of ChainerMN + test_dir = self.test_suite.current_test_data_dir + + mnist_dir = join_path( + self.install_test_root, "examples", "chainermn", "mnist" + ) + mnist_file = join_path(mnist_dir, "train_mnist.py") + mpi_name = self.spec["mpi"].prefix.bin.mpirun + python_exe = self.spec["python"].command.path + opts = [ + "-n", + "4", + python_exe, + mnist_file, + "-o", + test_dir, + ] + env["OMP_NUM_THREADS"] = "4" + + self.run_test( + mpi_name, + options=opts, + work_dir=test_dir, + ) + + # check results + json_open = open(join_path(test_dir, 'log'), 'r') + json_load = json.load(json_open) + v = dict([(d.get('epoch'), d.get('main/accuracy')) for d in json_load]) + if 1 not in v or 20 not in v: + raise RuntimeError('Cannot find epoch 1 or epoch 20') + if abs(1.0 - v[1]) < abs(1.0 - v[20]): + raise RuntimeError('ChainerMN Test Failed !') -- cgit v1.2.3-60-g2f50