added clean slurm save load test

William Falcon 2019-07-26 22:57:49 -04:00
parent f183ac2a1c
commit 53b781709e
2 changed files with 19 additions and 8 deletions


@@ -610,14 +610,18 @@ class Trainer(TrainerIO):
         if self.proc_rank == 0:
             self.experiment.save()
 
+        # track model now.
+        # if cluster resets state, the model will update with the saved weights
+        self.model = model
+
         # enable cluster checkpointing
         # also restores training state
         if self.cluster is not None:  # pragma: no cover
             self.enable_auto_hpc_walltime_manager()
         # ---------------------------
         # CORE TRAINING LOOP
         # ---------------------------
         self.model = model
         self.__train()
 
     def __train(self):
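Note on the hunk above: the trainer now records self.model before enable_auto_hpc_walltime_manager() runs, so that when the cluster resets the job the restored weights land on the object the trainer actually holds. For readers unfamiliar with the mechanism, the sketch below shows the general SLURM walltime pattern (catch the pre-walltime signal, checkpoint, requeue). It is a minimal illustration, not Lightning's actual implementation; the function names and checkpoint path are invented for the example.

import os
import signal
import subprocess

import torch


def save_and_requeue(model, optimizer, ckpt_path="hpc_ckpt.pt"):
    # persist everything needed to resume: weights and optimizer state
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, ckpt_path)

    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id is not None:
        # ask SLURM to put this job back in the queue; on restart the script
        # looks for the checkpoint and continues from the saved state
        subprocess.call(["scontrol", "requeue", job_id])


def install_walltime_handler(model, optimizer):
    # SLURM delivers SIGUSR1 shortly before the walltime limit when the job
    # is submitted with e.g. `#SBATCH --signal=SIGUSR1@90`
    def handler(signum, frame):
        save_and_requeue(model, optimizer)

    signal.signal(signal.SIGUSR1, handler)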


@@ -85,15 +85,22 @@ def test_cpu_slurm_save_load():
         checkpoint_callback=ModelCheckpoint(save_dir)
     )
 
     trainer = Trainer(**trainer_options)
     model = LightningTestModel(hparams)
 
-    # test HPC loading
-    trainer.hpc_load(save_dir, on_gpu=False)
-    assert trainer.global_step == real_global_step and trainer.global_step > 0
-
-    # predict with loaded model to make sure answers are the same
-    trainer.model.eval()
-    new_pred = trainer.model(x)
-    assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
+    # set the epoch start hook so we can predict before the model does the full training
+    def assert_pred_same():
+        assert trainer.global_step == real_global_step and trainer.global_step > 0
+
+        # predict with loaded model to make sure answers are the same
+        trainer.model.eval()
+        new_pred = trainer.model(x)
+        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
+
+    model.on_epoch_start = assert_pred_same
+
+    # by calling fit again, we trigger training, loading weights from the cluster
+    # and our hook to predict using current model before any more weight updates
+    trainer.fit(model)
+
     clear_save_dir()
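For reference, the equivalence that the assert_pred_same hook checks can be reproduced with plain PyTorch, independent of the Trainer's hpc save/load machinery. The snippet below is an illustrative sketch: the linear model, tensor shapes, and checkpoint filename are placeholders, not what the test uses.

import torch
import torch.nn as nn

# stand-in model; the test itself uses LightningTestModel
model = nn.Linear(28 * 28, 10)
x = torch.rand(4, 28 * 28)

model.eval()
with torch.no_grad():
    pred_before_saving = model(x)

# simulate the cluster checkpoint: save, then restore into a fresh instance
torch.save({'state_dict': model.state_dict()}, 'hpc_ckpt_demo.pt')
restored = nn.Linear(28 * 28, 10)
restored.load_state_dict(torch.load('hpc_ckpt_demo.pt')['state_dict'])

restored.eval()
with torch.no_grad():
    new_pred = restored(x)

# same check the test makes: restored weights must reproduce the earlier predictions exactly
assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1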