From 51a5cc36e3e7e5f0ed99371c996797de7506a549 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Fri, 26 Jul 2019 11:50:02 -0400 Subject: [PATCH] added checkpoint test on cpu --- tests/test_models.py | 53 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index c20c927d6a..2dc05489de 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -43,6 +43,59 @@ def test_dp_output_reduce(): assert reduced['b']['c'] == out['b']['c'] +def test_cpu_slurm_managed(): + """ + SLURM checkpointing works + :return: + """ + hparams = get_hparams() + model = LightningTestModel(hparams) + + trainer_options = dict( + max_nb_epochs=1, + ) + + save_dir = init_save_dir() + + # exp file to get meta + exp = get_exp(False) + exp.argparse(hparams) + exp.save() + + # exp file to get weights + checkpoint = ModelCheckpoint(save_dir) + + # add these to the trainer options + trainer_options['checkpoint_callback'] = checkpoint + trainer_options['experiment'] = exp + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # correct result and ok accuracy + assert result == 1, 'amp + ddp model failed to complete' + + # test model loading with a map_location + pretrained_model = load_model(exp, save_dir, True) + + # test model preds + run_prediction(model.test_dataloader, pretrained_model) + + trainer.model = pretrained_model + trainer.optimizers = pretrained_model.configure_optimizers() + + # test HPC loading / saving + trainer.hpc_save(save_dir, exp) + trainer.hpc_load(save_dir, on_gpu=False) + + # test freeze on gpu + model.freeze() + model.unfreeze() + + clear_save_dir() + + def test_amp_gpu_ddp_slurm_managed(): """ Make sure DDP + AMP work