Merge pull request #55 from williamFalcon/continue
add training restore
commit 35f23bbc82
@ -248,6 +248,7 @@ tensorboard --logdir /some/path

- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)

###### Computing cluster (SLURM)

@ -18,5 +18,22 @@ checkpoint_callback = ModelCheckpoint(
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```

---
### Restoring training session

You might want to not only load a model but also continue training it. Use this approach to restore the trainer state as well. Training continues from the epoch and global step where you last left off; the dataloaders, however, will start from the first batch again (if you shuffled, this shouldn't matter).

Lightning will restore the session if you pass an experiment with the same version and there's a saved checkpoint.

``` {.python}
from test_tube import Experiment

exp = Experiment(version=a_previous_version_with_a_saved_checkpoint)
trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)

# the trainer is now restored
```

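For context, here is a minimal end-to-end sketch of the resume pattern exercised by the new test further down. The version number, checkpoint path, and `model` are placeholders, and `Trainer` / `ModelCheckpoint` are assumed to be imported the same way as in the surrounding docs:

``` {.python}
from test_tube import Experiment

# first run: pick a fixed experiment version so it can be found again later
exp = Experiment(version=7)  # 7 is an arbitrary, user-chosen version number
checkpoint_callback = ModelCheckpoint('/some/checkpoint/path')
trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)
trainer.fit(model)

# later, in a fresh process: re-create the experiment with the SAME version and
# point at the same checkpoint folder; fit() resumes from the last saved epoch
exp = Experiment(version=7)
trainer = Trainer(experiment=exp,
                  checkpoint_callback=ModelCheckpoint('/some/checkpoint/path'))
trainer.fit(model)
```
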
@ -21,6 +21,7 @@ But of course the fun is in all the advanced things it can do:

- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)

**Computing cluster (SLURM)**

@ -28,6 +28,7 @@ one could be a seq-2-seq model, both (optionally) ran by the same trainer file.

- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)

###### Computing cluster (SLURM)

@ -1,5 +1,5 @@

"""
The trainer handles all the logic for running a val loop, training loop, distributing, etc...
"""

import os

@ -247,6 +247,32 @@ class Trainer(TrainerIO):

        """
        raise ModuleNotFoundError(msg)

    def restore_state_if_existing_checkpoint(self):
        # restore trainer state and model if there is a weight for this experiment
        last_epoch = -1
        last_ckpt_name = None

        # find last epoch
        checkpoints = os.listdir(self.checkpoint_callback.filepath)
        for name in checkpoints:
            # ignore hpc ckpts
            if 'hpc_' in name:
                continue

            if '.ckpt' in name:
                epoch = name.split('epoch_')[1]
                epoch = int(re.sub('[^0-9]', '', epoch))

                if epoch > last_epoch:
                    last_epoch = epoch
                    last_ckpt_name = name

        # restore last checkpoint
        if last_ckpt_name is not None:
            last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
            self.restore(last_ckpt_path, self.on_gpu)
            print(f'model and trainer restored from checkpoint: {last_ckpt_path}')

    @property
    def data_parallel(self):
        return self.use_dp or self.use_ddp

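As a quick illustration of the epoch parsing above (the filename `_ckpt_epoch_3.ckpt` is an assumed example in the style the ModelCheckpoint callback writes):

``` {.python}
import re

name = '_ckpt_epoch_3.ckpt'               # assumed example checkpoint filename
epoch = name.split('epoch_')[1]           # -> '3.ckpt'
epoch = int(re.sub('[^0-9]', '', epoch))  # strip non-digits -> 3
print(epoch)  # 3
```
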
@ -609,9 +635,6 @@ We recommend you switch to ddp if you want to use amp

        ref_model.trainer = self
        ref_model.experiment = self.experiment

        # run tiny validation to make sure program won't crash during val
        _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)

        # save exp to get started
        if self.proc_rank == 0:
            self.experiment.save()

@ -620,14 +643,23 @@ We recommend you switch to ddp if you want to use amp

        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc call
        self.restore_state_if_existing_checkpoint()

        # enable cluster checkpointing
        # also restores training state
        # hpc checkpoint overrides any other checkpoints loaded before
        if self.cluster is not None:  # pragma: no cover
            self.enable_auto_hpc_walltime_manager()

        # run tiny validation to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)

        # ---------------------------
        # CORE TRAINING LOOP
        # ---------------------------

        self.__train()

    def __train(self):

@ -2,6 +2,14 @@ import torch


class ModelHooks(torch.nn.Module):

    def on_sanity_check_start(self):
        """
        Called before starting validate
        :return:
        """
        pass

    def on_batch_start(self, data_batch):
        pass

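The new `on_sanity_check_start` hook can be overridden like any other hook, or simply monkey-patched onto a model instance the way the new test below does. A minimal sketch, assuming `model` is an existing LightningModule instance:

``` {.python}
# `model` is assumed to be an existing LightningModule instance
def notify_sanity_check():
    print('sanity-check validation is about to run')

# shadow the no-op hook on this one instance; the Trainer calls it with no arguments
model.on_sanity_check_start = notify_sanity_check
```
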
@ -60,6 +60,22 @@ class TrainerIO(object):

        # do the actual save
        torch.save(checkpoint, filepath)

    def restore(self, checkpoint_path, on_gpu):

        if on_gpu:
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

        # load training state (affects trainer only)
        self.restore_training_state(checkpoint)

        # load model state
        model = self.__get_model()

        # load the state_dict on the model automatically
        model.load_state_dict(checkpoint['state_dict'])

    def dump_checkpoint(self):

        checkpoint = {

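A side note on the `map_location` lambda in `restore` above: it keeps every deserialized tensor on CPU, so a checkpoint written on a GPU machine can be restored on a CPU-only one. A minimal standalone sketch (`'example.ckpt'` is a placeholder path):

``` {.python}
import torch

# returning `storage` unchanged keeps each tensor on CPU instead of
# moving it back to the device it was saved from
checkpoint = torch.load('example.ckpt', map_location=lambda storage, loc: storage)
print(list(checkpoint.keys()))  # e.g. ['state_dict', ...] depending on what was saved
```
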
@ -200,15 +216,15 @@ class TrainerIO(object):

        # call model hook
        model.on_hpc_load(checkpoint)

-    def max_ckpt_in_folder(self, path):
+    def max_ckpt_in_folder(self, path, name_key='ckpt_'):
        files = os.listdir(path)
-        files = [x for x in files if 'ckpt_' in x]
+        files = [x for x in files if name_key in x]
        if len(files) == 0:
            return 0

        ckpt_vs = []
        for name in files:
-            name = name.split('ckpt_')[-1]
+            name = name.split(name_key)[-1]
            name = re.sub('[^0-9]', '', name)
            ckpt_vs.append(int(name))

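To illustrate the new `name_key` argument: assuming the folder holds checkpoints such as `_ckpt_epoch_1.ckpt`, `_ckpt_epoch_3.ckpt`, and `hpc_ckpt_2.ckpt` (hypothetical names), and assuming the method returns the largest parsed number as its name suggests:

``` {.python}
# hypothetical folder contents: ['_ckpt_epoch_1.ckpt', '_ckpt_epoch_3.ckpt', 'hpc_ckpt_2.ckpt']
trainer.max_ckpt_in_folder(path)                        # default 'ckpt_' matches all three -> 3
trainer.max_ckpt_in_folder(path, name_key='hpc_ckpt_')  # only the hpc file matches -> 2
```
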
@ -26,6 +26,73 @@ np.random.seed(SEED)

# ------------------------------------------------------------------------
# TESTS
# ------------------------------------------------------------------------

def test_cpu_restore_training():
    """
    Verify continuing a training session on CPU
    :return:
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    test_exp_version = 10
    exp = get_exp(False, version=test_exp_version)
    exp.argparse(hparams)
    exp.save()

    trainer_options = dict(
        max_nb_epochs=2,
        val_check_interval=0.50,
        val_percent_check=0.2,
        train_percent_check=0.2,
        experiment=exp,
        checkpoint_callback=ModelCheckpoint(save_dir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_epoch = trainer.current_epoch

    # training complete
    assert result == 1, 'cpu restore training failed to complete'

    # wipe-out trainer and model
    # retrain with not much data... this simulates picking training back up after slurm
    # we want to see if the weights come back correctly
    new_exp = get_exp(False, version=test_exp_version)
    trainer_options = dict(
        max_nb_epochs=2,
        val_check_interval=0.50,
        val_percent_check=0.2,
        train_percent_check=0.2,
        experiment=new_exp,
        checkpoint_callback=ModelCheckpoint(save_dir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        trainer.model.eval()
        run_prediction(trainer.val_dataloader, trainer.model)

    model.on_sanity_check_start = assert_good_acc

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)

    clear_save_dir()


def test_amp_gpu_ddp():
    """
    Make sure DDP + AMP work

@ -56,6 +123,8 @@ def test_amp_gpu_ddp():

    run_gpu_model_test(trainer_options, model, hparams)



def test_cpu_slurm_save_load():
    """
    Verify model save/load/checkpoint on CPU

@ -622,10 +691,10 @@ def get_model():

    return model, hparams


-def get_exp(debug=True):
+def get_exp(debug=True, version=None):
    # set up exp object without actually saving logs
    root_dir = os.path.dirname(os.path.realpath(__file__))
-    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir')
+    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir', version=version)
    return exp