Fix amp tests (#661)

* Run AMP tests in their own process

With opt_level="O1" (the default), AMP patches many
torch functions, which breaks any tests that run afterwards.
This patch introduces a pytest extension that lets
tests be marked with @pytest.mark.spawn so that they
run in their own process via torch.multiprocessing.spawn,
keeping the main Python interpreter un-patched.

Note that tests using DDP already run AMP in a separate
process, so they don't need this marker.
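
A minimal sketch of how a test opts in to the marker (the test name and
body below are hypothetical; only @pytest.mark.spawn itself comes from
this patch):

    import pytest
    import torch


    @pytest.mark.spawn
    def test_runs_in_child_process():
        # Executed in a process started by torch.multiprocessing.spawn,
        # so any monkey-patching done by AMP dies with that process.
        model = torch.nn.Linear(2, 2)
        assert model(torch.zeros(1, 2)).shape == (1, 2)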

* Fix AMP tests

Since AMP defaults to O1 now, DP tests no longer throw exceptions.

Since AMP patches torch functions, CPU inference no longer
works once AMP has been initialized, so the prediction step
is skipped for AMP tests.
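
For context, a sketch (not part of this diff) of why O1 leaks into later
code in the same interpreter; it assumes NVIDIA apex and a CUDA device
are available:

    import torch
    from apex import amp

    model = torch.nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # opt_level="O1" monkey-patches torch functions to insert half-precision
    # casts; the patches persist for the rest of the process, which is why
    # CPU inference in the same interpreter can fail afterwards.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")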

* typo
Nic Eggert 2020-01-05 13:34:25 -06:00 committed by William Falcon
parent c32f2b9116
commit 019f612204
2 changed files with 34 additions and 24 deletions

tests/conftest.py

@@ -0,0 +1,22 @@
import pytest
import torch.multiprocessing as mp


def pytest_configure(config):
    config.addinivalue_line("markers", "spawn: spawn test in a separate process using torch.multiprocessing.spawn")


def wrap(i, fn, args):
    return fn(*args)


@pytest.mark.tryfirst
def pytest_pyfunc_call(pyfuncitem):
    if pyfuncitem.get_closest_marker("spawn"):
        testfunction = pyfuncitem.obj
        funcargs = pyfuncitem.funcargs
        testargs = tuple([funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames])

        mp.spawn(wrap, (testfunction, testargs))
        return True
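
For reference, a standalone sketch of the mp.spawn call used above (the
greet function is illustrative only). torch.multiprocessing.spawn passes
the process index as the first argument to the target, which is why wrap
accepts and discards i; nprocs defaults to 1:

    import torch.multiprocessing as mp


    def greet(i, name):
        # i is the process index injected by mp.spawn (always 0 here).
        print(f"process {i}: hello, {name}")


    if __name__ == '__main__':
        mp.spawn(greet, args=("world",))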

tests/test_amp.py

@@ -32,6 +32,7 @@ def test_amp_single_gpu(tmpdir):
    tutils.run_model_test(trainer_options, model)


@pytest.mark.spawn
def test_no_amp_single_gpu(tmpdir):
    """Make sure DDP + AMP work."""
    tutils.reset_seed()
@@ -51,8 +52,10 @@ def test_no_amp_single_gpu(tmpdir):
        use_amp=True
    )

    with pytest.raises((MisconfigurationException, ModuleNotFoundError)):
        tutils.run_model_test(trainer_options, model)
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1


def test_amp_gpu_ddp(tmpdir):
@@ -78,6 +81,7 @@ def test_amp_gpu_ddp(tmpdir):
    tutils.run_model_test(trainer_options, model)


@pytest.mark.spawn
def test_amp_gpu_ddp_slurm_managed(tmpdir):
    """Make sure DDP + AMP work."""
    if not tutils.can_run_gpu_test():
@@ -124,26 +128,6 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
    assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
    assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'

    # test model loading with a map_location
    pretrained_model = tutils.load_model(logger.experiment, trainer.checkpoint_callback.filepath)

    # test model preds
    for dataloader in trainer.get_test_dataloaders():
        tutils.run_prediction(dataloader, pretrained_model)

    if trainer.use_ddp:
        # on hpc this would work fine... but need to hack it for the purpose of the test
        trainer.model = pretrained_model
        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()

    # test HPC loading / saving
    trainer.hpc_save(tmpdir, logger)
    trainer.hpc_load(tmpdir, on_gpu=True)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()


def test_cpu_model_with_amp(tmpdir):
    """Make sure model trains on CPU."""
@@ -165,6 +149,7 @@ def test_cpu_model_with_amp(tmpdir):
    tutils.run_model_test(trainer_options, model, on_gpu=False)


@pytest.mark.spawn
def test_amp_gpu_dp(tmpdir):
    """Make sure DP + AMP work."""
    tutils.reset_seed()
@@ -180,8 +165,11 @@ def test_amp_gpu_dp(tmpdir):
        distributed_backend='dp',
        use_amp=True
    )

    with pytest.raises(MisconfigurationException):
        tutils.run_model_test(trainer_options, model, hparams)
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    assert result == 1


if __name__ == '__main__':