Fix amp tests (#661)
* Run AMP tests in their own process With opt_level="O1" (the default), AMP patches many torch functions, which breaks any tests that run afterwards. This patch introduces a pytest extension that lets tests be marked with @pytest.mark.spawn so that they are run in their own process using torch.multiprocessing.spawn so that the main python interpreter stays un-patched. Note that tests using DDP already run AMP in its own process, so they don't need this annotation. * Fix AMP tests Since AMP defaults to O1 now, DP tests no longer throw exceptions. Since AMP patches torch functions, CPU inference no longer works. Skip prediction step for AMP tests. * typo
This commit is contained in:
parent
c32f2b9116
commit
019f612204
|
@ -0,0 +1,22 @@
|
|||
import pytest
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "spawn: spawn test in a separate process using torch.multiprocessing.spawn")
|
||||
|
||||
|
||||
def wrap(i, fn, args):
|
||||
return fn(*args)
|
||||
|
||||
|
||||
@pytest.mark.tryfirst
|
||||
def pytest_pyfunc_call(pyfuncitem):
|
||||
if pyfuncitem.get_closest_marker("spawn"):
|
||||
testfunction = pyfuncitem.obj
|
||||
funcargs = pyfuncitem.funcargs
|
||||
testargs = tuple([funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames])
|
||||
|
||||
mp.spawn(wrap, (testfunction, testargs))
|
||||
return True
|
|
@ -32,6 +32,7 @@ def test_amp_single_gpu(tmpdir):
|
|||
tutils.run_model_test(trainer_options, model)
|
||||
|
||||
|
||||
@pytest.mark.spawn
|
||||
def test_no_amp_single_gpu(tmpdir):
|
||||
"""Make sure DDP + AMP work."""
|
||||
tutils.reset_seed()
|
||||
|
@ -51,8 +52,10 @@ def test_no_amp_single_gpu(tmpdir):
|
|||
use_amp=True
|
||||
)
|
||||
|
||||
with pytest.raises((MisconfigurationException, ModuleNotFoundError)):
|
||||
tutils.run_model_test(trainer_options, model)
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1
|
||||
|
||||
|
||||
def test_amp_gpu_ddp(tmpdir):
|
||||
|
@ -78,6 +81,7 @@ def test_amp_gpu_ddp(tmpdir):
|
|||
tutils.run_model_test(trainer_options, model)
|
||||
|
||||
|
||||
@pytest.mark.spawn
|
||||
def test_amp_gpu_ddp_slurm_managed(tmpdir):
|
||||
"""Make sure DDP + AMP work."""
|
||||
if not tutils.can_run_gpu_test():
|
||||
|
@ -124,26 +128,6 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
|
|||
assert trainer.resolve_root_node_address('abc[23-24]') == 'abc23'
|
||||
assert trainer.resolve_root_node_address('abc[23-24, 45-40, 40]') == 'abc23'
|
||||
|
||||
# test model loading with a map_location
|
||||
pretrained_model = tutils.load_model(logger.experiment, trainer.checkpoint_callback.filepath)
|
||||
|
||||
# test model preds
|
||||
for dataloader in trainer.get_test_dataloaders():
|
||||
tutils.run_prediction(dataloader, pretrained_model)
|
||||
|
||||
if trainer.use_ddp:
|
||||
# on hpc this would work fine... but need to hack it for the purpose of the test
|
||||
trainer.model = pretrained_model
|
||||
trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
|
||||
|
||||
# test HPC loading / saving
|
||||
trainer.hpc_save(tmpdir, logger)
|
||||
trainer.hpc_load(tmpdir, on_gpu=True)
|
||||
|
||||
# test freeze on gpu
|
||||
model.freeze()
|
||||
model.unfreeze()
|
||||
|
||||
|
||||
def test_cpu_model_with_amp(tmpdir):
|
||||
"""Make sure model trains on CPU."""
|
||||
|
@ -165,6 +149,7 @@ def test_cpu_model_with_amp(tmpdir):
|
|||
tutils.run_model_test(trainer_options, model, on_gpu=False)
|
||||
|
||||
|
||||
@pytest.mark.spawn
|
||||
def test_amp_gpu_dp(tmpdir):
|
||||
"""Make sure DP + AMP work."""
|
||||
tutils.reset_seed()
|
||||
|
@ -180,8 +165,11 @@ def test_amp_gpu_dp(tmpdir):
|
|||
distributed_backend='dp',
|
||||
use_amp=True
|
||||
)
|
||||
with pytest.raises(MisconfigurationException):
|
||||
tutils.run_model_test(trainer_options, model, hparams)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue