added checkpoint test on cpu
parent c4b37d1efe
commit 51a5cc36e3
@@ -43,6 +43,59 @@ def test_dp_output_reduce():
    assert reduced['b']['c'] == out['b']['c']


def test_cpu_slurm_managed():
    """
    Verify that SLURM-managed checkpointing works on CPU.
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    trainer_options = dict(
        max_nb_epochs=1,
    )

    save_dir = init_save_dir()

    # exp file to get meta
    exp = get_exp(False)
    exp.argparse(hparams)
    exp.save()

    # exp file to get weights
    checkpoint = ModelCheckpoint(save_dir)

    # add these to the trainer options
    trainer_options['checkpoint_callback'] = checkpoint
    trainer_options['experiment'] = exp

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'cpu slurm model failed to complete'

    # test model loading with a map_location
    pretrained_model = load_model(exp, save_dir, True)

    # test model preds
    run_prediction(model.test_dataloader, pretrained_model)

    # point the trainer at the restored model and its optimizers for the HPC round trip below
    trainer.model = pretrained_model
    trainer.optimizers = pretrained_model.configure_optimizers()

    # test HPC loading / saving
    trainer.hpc_save(save_dir, exp)
    trainer.hpc_load(save_dir, on_gpu=False)

    # test freeze / unfreeze
    model.freeze()
    model.unfreeze()

    clear_save_dir()


def test_amp_gpu_ddp_slurm_managed():
    """
    Make sure DDP + AMP work