Merge pull request #55 from williamFalcon/continue

add training restore
This commit is contained in:
William Falcon 2019-08-07 09:02:16 -04:00 committed by GitHub
commit 35f23bbc82
8 changed files with 154 additions and 9 deletions

View File

@ -248,6 +248,7 @@ tensorboard --logdir /some/path
- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
###### Computing cluster (SLURM)

View File

@ -18,5 +18,22 @@ checkpoint_callback = ModelCheckpoint(
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```
---
### Restoring training session
You might want to not only load a model but also continue training it. Use this method to
restore the trainer state as well. Training will resume from the epoch and global step where you last left off.
However, the dataloaders will start from the first batch again (if you shuffled, this shouldn't matter).
Lightning will restore the session if you pass an experiment with the same version and there's a saved checkpoint.
``` {.python}
from test_tube import Experiment

exp = Experiment(version=a_previous_version_with_a_saved_checkpoint)
trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)

# the trainer is now restored
```
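For a fuller picture, here is a hedged end-to-end sketch of the same idea. The save directory, experiment name, and version number below are placeholders, `model` stands for your own LightningModule, and `Trainer` / `ModelCheckpoint` are assumed to be imported as in the earlier snippets. The restore is triggered by pointing the new `Experiment` and `ModelCheckpoint` at the same version and checkpoint directory the earlier run saved to.
``` {.python}
from test_tube import Experiment

# placeholders: reuse the same experiment name/version and checkpoint
# directory that the earlier run saved to
exp = Experiment(save_dir='/some/path', name='my_exp', version=3)
checkpoint_callback = ModelCheckpoint('/some/path/checkpoints')

# fit() first restores the model weights, epoch and global step from the
# latest checkpoint it finds, then continues training from there
trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)
trainer.fit(model)
```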

View File

@ -21,6 +21,7 @@ But of course the fun is in all the advanced things it can do:
- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
**Computing cluster (SLURM)**

View File

@ -28,6 +28,7 @@ one could be a seq-2-seq model, both (optionally) run by the same trainer file.
- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
- [Restoring training session](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#restoring-training-session)
###### Computing cluster (SLURM)

View File

@ -1,5 +1,5 @@
"""
The trainer handles all the logic for running a val loop, training loop, distributing, etc...
"""
import os
@ -247,6 +247,32 @@ class Trainer(TrainerIO):
"""
raise ModuleNotFoundError(msg)
    def restore_state_if_existing_checkpoint(self):
        # restore trainer state and model if there is a weight for this experiment
        last_epoch = -1
        last_ckpt_name = None

        # find last epoch
        checkpoints = os.listdir(self.checkpoint_callback.filepath)
        for name in checkpoints:
            # ignore hpc ckpts
            if 'hpc_' in name:
                continue

            if '.ckpt' in name:
                epoch = name.split('epoch_')[1]
                epoch = int(re.sub('[^0-9]', '', epoch))

                if epoch > last_epoch:
                    last_epoch = epoch
                    last_ckpt_name = name

        # restore last checkpoint
        if last_ckpt_name is not None:
            last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
            self.restore(last_ckpt_path, self.on_gpu)
            print(f'model and trainer restored from checkpoint: {last_ckpt_path}')

    @property
    def data_parallel(self):
        return self.use_dp or self.use_ddp
@ -609,9 +635,6 @@ We recommend you switch to ddp if you want to use amp
        ref_model.trainer = self
        ref_model.experiment = self.experiment

        # run tiny validation to make sure program won't crash during val
        _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)

        # save exp to get started
        if self.proc_rank == 0:
            self.experiment.save()
@ -620,14 +643,23 @@ We recommend you switch to ddp if you want to use amp
        # if cluster resets state, the model will update with the saved weights
        self.model = model

        # restore training and model before hpc call
        self.restore_state_if_existing_checkpoint()

        # enable cluster checkpointing
        # also restores training state
        # hpc checkpoint overrides any other checkpoints loaded before
        if self.cluster is not None:  # pragma: no cover
            self.enable_auto_hpc_walltime_manager()

        # run tiny validation to make sure program won't crash during val
        ref_model.on_sanity_check_start()
        _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)

        # ---------------------------
        # CORE TRAINING LOOP
        # ---------------------------
        self.__train()

    def __train(self):

View File

@ -2,6 +2,14 @@ import torch
class ModelHooks(torch.nn.Module):

    def on_sanity_check_start(self):
        """
        Called before starting the validation sanity check.
        :return:
        """
        pass

    def on_batch_start(self, data_batch):
        pass
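As a usage sketch (not part of this diff): the hook is an ordinary method on `ModelHooks`, so a test or script can replace it to run a check right before the sanity validation pass, which happens after any checkpoint restore. The names `exp`, `checkpoint_callback`, `LightningTestModel`, and `get_hparams` below are assumed to be set up as in the test added at the end of this commit.
``` {.python}
# hedged sketch, reusing the helpers from the new test in this commit
model = LightningTestModel(get_hparams())
trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint_callback)

def check_resumed_state():
    # on_sanity_check_start fires after any checkpoint restore but before
    # the sanity validation pass, so the restored epoch is visible here
    assert trainer.current_epoch > 0, 'expected to resume from a saved checkpoint'

# swap the hook in and trigger it via fit()
model.on_sanity_check_start = check_resumed_state
trainer.fit(model)
```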

View File

@ -60,6 +60,22 @@ class TrainerIO(object):
        # do the actual save
        torch.save(checkpoint, filepath)

    def restore(self, checkpoint_path, on_gpu):
        if on_gpu:
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

        # load training state (affects trainer only)
        self.restore_training_state(checkpoint)

        # load model state
        model = self.__get_model()

        # load the state_dict on the model automatically
        model.load_state_dict(checkpoint['state_dict'])

    def dump_checkpoint(self):
        checkpoint = {
@ -200,15 +216,15 @@ class TrainerIO(object):
        # call model hook
        model.on_hpc_load(checkpoint)

    def max_ckpt_in_folder(self, path):
    def max_ckpt_in_folder(self, path, name_key='ckpt_'):
        files = os.listdir(path)
        files = [x for x in files if 'ckpt_' in x]
        files = [x for x in files if name_key in x]
        if len(files) == 0:
            return 0

        ckpt_vs = []
        for name in files:
            name = name.split('ckpt_')[-1]
            name = name.split(name_key)[-1]
            name = re.sub('[^0-9]', '', name)
            ckpt_vs.append(int(name))

View File

@ -26,6 +26,73 @@ np.random.seed(SEED)
# ------------------------------------------------------------------------
# TESTS
# ------------------------------------------------------------------------
def test_cpu_restore_training():
    """
    Verify continue training session on CPU
    :return:
    """
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    # exp file to get meta
    test_exp_version = 10
    exp = get_exp(False, version=test_exp_version)
    exp.argparse(hparams)
    exp.save()

    trainer_options = dict(
        max_nb_epochs=2,
        val_check_interval=0.50,
        val_percent_check=0.2,
        train_percent_check=0.2,
        experiment=exp,
        checkpoint_callback=ModelCheckpoint(save_dir)
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    real_global_epoch = trainer.current_epoch

    # training complete
    assert result == 1, 'cpu model failed to complete'

    # wipe-out trainer and model
    # retrain with not much data... this simulates picking training back up after slurm
    # we want to see if the weights come back correctly
    new_exp = get_exp(False, version=test_exp_version)
    trainer_options = dict(
        max_nb_epochs=2,
        val_check_interval=0.50,
        val_percent_check=0.2,
        train_percent_check=0.2,
        experiment=new_exp,
        checkpoint_callback=ModelCheckpoint(save_dir),
    )
    trainer = Trainer(**trainer_options)
    model = LightningTestModel(hparams)

    # set the epoch start hook so we can predict before the model does the full training
    def assert_good_acc():
        assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0

        # if model and state loaded correctly, predictions will be good even though we
        # haven't trained with the new loaded model
        trainer.model.eval()
        run_prediction(trainer.val_dataloader, trainer.model)

    model.on_sanity_check_start = assert_good_acc

    # by calling fit again, we trigger training, loading weights from the cluster
    # and our hook to predict using current model before any more weight updates
    trainer.fit(model)

    clear_save_dir()


def test_amp_gpu_ddp():
    """
    Make sure DDP + AMP work
@ -56,6 +123,8 @@ def test_amp_gpu_ddp():
    run_gpu_model_test(trainer_options, model, hparams)


def test_cpu_slurm_save_load():
    """
    Verify model save/load/checkpoint on CPU
@ -622,10 +691,10 @@ def get_model():
    return model, hparams


def get_exp(debug=True):
def get_exp(debug=True, version=None):
    # set up exp object without actually saving logs
    root_dir = os.path.dirname(os.path.realpath(__file__))
    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir')
    exp = Experiment(debug=debug, save_dir=root_dir, name='tests_tt_dir', version=version)
    return exp