added clean slurm save load test

This commit is contained in:
William Falcon 2019-07-26 22:24:01 -04:00
parent c61e13f0ff
commit b5419fcd8b
1 changed files with 27 additions and 17 deletions

View File

@ -90,7 +90,6 @@ def test_model_saving_loading():
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
real_global_step = trainer.global_step
# traning complete
assert result == 1, 'amp + ddp model failed to complete'
@ -124,7 +123,7 @@ def test_model_saving_loading():
clear_save_dir()
def test_cpu_slurm_saving_loading():
def test_cpu_slurm_save_load():
"""
Verify model save/load/checkpoint on CPU
:return:
@ -154,38 +153,49 @@ def test_cpu_slurm_saving_loading():
# traning complete
assert result == 1, 'amp + ddp model failed to complete'
# test saving checkpoint
ckpt_test = os.path.join(save_dir, 'test.ckpt')
trainer.save_checkpoint(ckpt_test)
# predict with trained model before saving
# make a prediction
for batch in model.test_dataloader:
break
x, y = batch
x = x.view(x.size(0), -1)
model.eval()
pred_before_saving = model(x)
# test registering a save function
trainer.enable_auto_hpc_walltime_manager()
# test model loading with a map_location
pretrained_model = load_model(exp, save_dir, True)
# test model preds
run_prediction(model.test_dataloader, pretrained_model)
trainer.model = pretrained_model
trainer.optimizers = pretrained_model.configure_optimizers()
# test HPC saving
# simulate snapshot on slurm
saved_filepath = trainer.hpc_save(save_dir, exp)
assert os.path.exists(saved_filepath)
# wipe-out trainer model
# we want to see if the weights come back correctly
trainer.model = LightningTestModel(hparams)
# test HPC loading
trainer.global_step = 20000000
trainer.hpc_load(save_dir, on_gpu=False)
assert trainer.global_step == real_global_step and trainer.global_step != 20000000
# test freeze on gpu
model.freeze()
model.unfreeze()
# predict with loaded model to make sure answers are the same
new_pred = trainer.model(x)
assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
clear_save_dir()
def test_model_freeze_unfreeze():
hparams = get_hparams()
model = LightningTestModel(hparams)
model.freeze()
model.unfreeze()
def test_amp_gpu_ddp_slurm_managed():
"""
Make sure DDP + AMP work