added checkpoint test on cpu

William Falcon 2019-07-26 11:50:02 -04:00
parent c4b37d1efe
commit 51a5cc36e3
1 changed file with 53 additions and 0 deletions


@@ -43,6 +43,59 @@ def test_dp_output_reduce():
    assert reduced['b']['c'] == out['b']['c']


def test_cpu_slurm_managed():
    """
    Test that SLURM-managed checkpointing works on CPU.
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        max_nb_epochs=1,
    )
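    # a single epoch is enough here: the point of the test is the
    # fit -> checkpoint -> restore path, not model quality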
    save_dir = init_save_dir()

    # exp file to get meta
    exp = get_exp(False)
    exp.argparse(hparams)
    exp.save()

    # exp file to get weights
    checkpoint = ModelCheckpoint(save_dir)
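    # note: ModelCheckpoint writes weights into save_dir as training runs;
    # load_model below reads them back from the same directory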
    # add these to the trainer options
    trainer_options['checkpoint_callback'] = checkpoint
    trainer_options['experiment'] = exp

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    # correct result and ok accuracy
    assert result == 1, 'cpu SLURM-managed model failed to complete'
    # test model loading with a map_location
    pretrained_model = load_model(exp, save_dir, True)
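    # (the boolean flag presumably tells the test helper to use a CPU
    # map_location when loading, per the comment above)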
    # test model preds
    run_prediction(model.test_dataloader, pretrained_model)

    trainer.model = pretrained_model
    trainer.optimizers = pretrained_model.configure_optimizers()
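    # reattach the restored model and fresh optimizers to the trainer so
    # hpc_save below has a complete model + optimizer state to snapshot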
    # test HPC loading / saving
    trainer.hpc_save(save_dir, exp)
    trainer.hpc_load(save_dir, on_gpu=False)
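    # this round trip mimics a SLURM wall-time hit: state is saved,
    # the job is resubmitted, and training state is restored on CPU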
    # test freeze / unfreeze on cpu
    model.freeze()
    model.unfreeze()

    clear_save_dir()


def test_amp_gpu_ddp_slurm_managed():
    """
    Make sure DDP + AMP work