From 019f6122044bc8337056df710f4768e66c3fcf61 Mon Sep 17 00:00:00 2001
From: Nic Eggert
Date: Sun, 5 Jan 2020 13:34:25 -0600
Subject: [PATCH] Fix amp tests (#661)

* Run AMP tests in their own process

With opt_level="O1" (the default), AMP patches many torch functions,
which breaks any tests that run afterwards. This patch introduces a
pytest extension that lets tests be marked with @pytest.mark.spawn so
that they run in their own process via torch.multiprocessing.spawn,
leaving the main Python interpreter un-patched.

Note that tests using DDP already run AMP in its own process, so they
don't need this annotation.

* Fix AMP tests

Since AMP now defaults to O1, the DP tests no longer throw exceptions.
Since AMP patches torch functions, CPU inference no longer works, so the
prediction step is skipped for AMP tests.

* typo
---
 tests/conftest.py | 22 ++++++++++++++++++++++
 tests/test_amp.py | 36 ++++++++++++------------------------
 2 files changed, 34 insertions(+), 24 deletions(-)
 create mode 100644 tests/conftest.py

diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000000..bfb2b0d5fc
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,22 @@
+import pytest
+
+import torch.multiprocessing as mp
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "spawn: spawn test in a separate process using torch.multiprocessing.spawn")
+
+
+def wrap(i, fn, args):
+    return fn(*args)
+
+
+@pytest.mark.tryfirst
+def pytest_pyfunc_call(pyfuncitem):
+    if pyfuncitem.get_closest_marker("spawn"):
+        testfunction = pyfuncitem.obj
+        funcargs = pyfuncitem.funcargs
+        testargs = tuple([funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames])
+
+        mp.spawn(wrap, (testfunction, testargs))
+        return True
diff --git a/tests/test_amp.py b/tests/test_amp.py
index 85025a56ca..4d4ae22d57 100644
--- a/tests/test_amp.py
+++ b/tests/test_amp.py
@@ -32,6 +32,7 @@ def test_amp_single_gpu(tmpdir):
     tutils.run_model_test(trainer_options, model)
 
 
+@pytest.mark.spawn
 def test_no_amp_single_gpu(tmpdir):
     """Make sure DDP + AMP work."""
     tutils.reset_seed()
@@ -51,8 +52,10 @@ def test_no_amp_single_gpu(tmpdir):
         use_amp=True
     )
 
-    with pytest.raises((MisconfigurationException, ModuleNotFoundError)):
-        tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    assert result == 1
 
 
 def test_amp_gpu_ddp(tmpdir):
@@ -78,6 +81,7 @@ def test_amp_gpu_ddp(tmpdir):
     tutils.run_model_test(trainer_options, model)
 
 
+@pytest.mark.spawn
 def test_amp_gpu_ddp_slurm_managed(tmpdir):
     """Make sure DDP + AMP work."""
     if not tutils.can_run_gpu_test():
@@ -124,26 +128,6 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
     assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'
 
-    # test model loading with a map_location
-    pretrained_model = tutils.load_model(logger.experiment, trainer.checkpoint_callback.filepath)
-
-    # test model preds
-    for dataloader in trainer.get_test_dataloaders():
-        tutils.run_prediction(dataloader, pretrained_model)
-
-    if trainer.use_ddp:
-        # on hpc this would work fine... but need to hack it for the purpose of the test
-        trainer.model = pretrained_model
-        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
-
-    # test HPC loading / saving
-    trainer.hpc_save(tmpdir, logger)
-    trainer.hpc_load(tmpdir, on_gpu=True)
-
-    # test freeze on gpu
-    model.freeze()
-    model.unfreeze()
-
 
 def test_cpu_model_with_amp(tmpdir):
     """Make sure model trains on CPU."""
@@ -165,6 +149,7 @@ def test_cpu_model_with_amp(tmpdir):
     tutils.run_model_test(trainer_options, model, on_gpu=False)
 
 
+@pytest.mark.spawn
 def test_amp_gpu_dp(tmpdir):
     """Make sure DP + AMP work."""
    tutils.reset_seed()
@@ -180,8 +165,11 @@ def test_amp_gpu_dp(tmpdir):
        distributed_backend='dp',
        use_amp=True
    )
-    with pytest.raises(MisconfigurationException):
-        tutils.run_model_test(trainer_options, model, hparams)
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    assert result == 1
 
 
 if __name__ == '__main__':
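For reference, a minimal sketch of how the new marker is consumed (the test name and body below are illustrative, not part of the patch): any test decorated with @pytest.mark.spawn is dispatched through torch.multiprocessing.spawn by the pytest_pyfunc_call hook in tests/conftest.py, so whatever the test patches stays confined to the child process.

    import pytest
    import torch


    @pytest.mark.spawn
    def test_amp_stays_in_child_process():
        # Runs in a worker created by torch.multiprocessing.spawn via the
        # pytest_pyfunc_call hook above; any monkey-patching done here (e.g.
        # AMP's O1 patching of torch functions) never reaches the parent
        # pytest interpreter.
        assert torch.ones(1).item() == 1.0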