added clean slurm save load test
This commit is contained in:
parent
c61e13f0ff
commit
b5419fcd8b
|
@ -90,7 +90,6 @@ def test_model_saving_loading():
|
|||
# fit model
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
real_global_step = trainer.global_step
|
||||
|
||||
# traning complete
|
||||
assert result == 1, 'amp + ddp model failed to complete'
|
||||
|
@ -124,7 +123,7 @@ def test_model_saving_loading():
|
|||
clear_save_dir()
|
||||
|
||||
|
||||
def test_cpu_slurm_saving_loading():
|
||||
def test_cpu_slurm_save_load():
|
||||
"""
|
||||
Verify model save/load/checkpoint on CPU
|
||||
:return:
|
||||
|
@ -154,38 +153,49 @@ def test_cpu_slurm_saving_loading():
|
|||
# traning complete
|
||||
assert result == 1, 'amp + ddp model failed to complete'
|
||||
|
||||
# test saving checkpoint
|
||||
ckpt_test = os.path.join(save_dir, 'test.ckpt')
|
||||
trainer.save_checkpoint(ckpt_test)
|
||||
# predict with trained model before saving
|
||||
# make a prediction
|
||||
for batch in model.test_dataloader:
|
||||
break
|
||||
|
||||
x, y = batch
|
||||
x = x.view(x.size(0), -1)
|
||||
|
||||
model.eval()
|
||||
pred_before_saving = model(x)
|
||||
|
||||
# test registering a save function
|
||||
trainer.enable_auto_hpc_walltime_manager()
|
||||
|
||||
# test model loading with a map_location
|
||||
pretrained_model = load_model(exp, save_dir, True)
|
||||
|
||||
# test model preds
|
||||
run_prediction(model.test_dataloader, pretrained_model)
|
||||
|
||||
trainer.model = pretrained_model
|
||||
trainer.optimizers = pretrained_model.configure_optimizers()
|
||||
|
||||
# test HPC saving
|
||||
# simulate snapshot on slurm
|
||||
saved_filepath = trainer.hpc_save(save_dir, exp)
|
||||
assert os.path.exists(saved_filepath)
|
||||
|
||||
# wipe-out trainer model
|
||||
# we want to see if the weights come back correctly
|
||||
trainer.model = LightningTestModel(hparams)
|
||||
|
||||
# test HPC loading
|
||||
trainer.global_step = 20000000
|
||||
trainer.hpc_load(save_dir, on_gpu=False)
|
||||
assert trainer.global_step == real_global_step and trainer.global_step != 20000000
|
||||
|
||||
# test freeze on gpu
|
||||
model.freeze()
|
||||
model.unfreeze()
|
||||
# predict with loaded model to make sure answers are the same
|
||||
new_pred = trainer.model(x)
|
||||
assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
|
||||
|
||||
clear_save_dir()
|
||||
|
||||
|
||||
def test_model_freeze_unfreeze():
|
||||
hparams = get_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
|
||||
model.freeze()
|
||||
model.unfreeze()
|
||||
|
||||
|
||||
def test_amp_gpu_ddp_slurm_managed():
|
||||
"""
|
||||
Make sure DDP + AMP work
|
||||
|
|
Loading…
Reference in New Issue