lightning/pytorch_lightning/root_module/model_saving.py

262 lines
7.6 KiB
Python
Raw Normal View History

2019-03-31 01:45:16 +00:00
import os
import re
import torch
2019-08-07 14:14:59 +00:00
from pytorch_lightning.pt_overrides.override_data_parallel import (
2019-08-06 10:08:31 +00:00
LightningDistributedDataParallel, LightningDataParallel)
2019-03-31 01:45:16 +00:00
2019-03-31 01:45:16 +00:00
class ModelIO(object):
    """
    Checkpoint-related hooks a model can override.

    Every default implementation here does nothing and returns None.
    """

    def on_load_checkpoint(self, checkpoint):
        """
        Hook giving the model a chance to read extra data from a checkpoint
        before its state_dict is restored.

        :param checkpoint: the loaded checkpoint dict
        :return: None
        """

    def on_save_checkpoint(self, checkpoint):
        """
        Hook giving the model a chance to add extra entries to a checkpoint.

        The checkpoint already contains the model's 'state_dict' when this is called.

        :param checkpoint: the checkpoint dict being built
        :return: None
        """

    # -------------------------
    # OPTIONAL HOOKS
    # -------------------------
    def on_hpc_save(self, checkpoint):
        """
        Hook run while an HPC (Slurm) checkpoint is being saved, after the
        checkpoint dict is built and before it is written to disk.

        :param checkpoint: the checkpoint dict about to be saved
        :return: None
        """

    def on_hpc_load(self, checkpoint):
        """
        Hook run when an HPC (Slurm) checkpoint is loaded, after the trainer has
        restored training state and the model's state_dict.

        :param checkpoint: the loaded checkpoint dict
        :return: None
        """
2019-03-31 01:45:16 +00:00
class TrainerIO(object):

    def __get_model(self):
        """
        Return the wrapped model, unwrapping LightningDataParallel /
        LightningDistributedDataParallel via their ``.module`` attribute.
        """
        is_dp_module = isinstance(self.model, (LightningDistributedDataParallel,
                                               LightningDataParallel))
        model = self.model.module if is_dp_module else self.model
        return model

    @staticmethod
    def _load_checkpoint_file(filepath, on_gpu):
        """Read a checkpoint with torch.load; map storages to CPU unless ``on_gpu``."""
        if on_gpu:
            return torch.load(filepath)
        # identity map_location keeps every storage where it was deserialized (CPU)
        return torch.load(filepath, map_location=lambda storage, loc: storage)

    # --------------------
    # MODEL SAVE CHECKPOINT
    # --------------------
    def save_checkpoint(self, filepath):
        """
        Build the full checkpoint dict and write it to disk.

        :param filepath: destination file for torch.save
        """
        checkpoint = self.dump_checkpoint()

        # do the actual save
        torch.save(checkpoint, filepath)

    def restore(self, checkpoint_path, on_gpu):
        """
        Load a checkpoint file and restore trainer state and model weights.

        :param checkpoint_path: path of a file written by save_checkpoint
        :param on_gpu: when False, tensors are loaded onto CPU
        """
        checkpoint = self._load_checkpoint_file(checkpoint_path, on_gpu)

        # load training state (affects trainer only)
        self.restore_training_state(checkpoint)

        # load the state_dict on the model automatically
        model = self.__get_model()
        model.load_state_dict(checkpoint['state_dict'])

    def dump_checkpoint(self):
        """
        Collect trainer, callback, optimizer, scheduler and model state into a dict.

        The model's ``on_save_checkpoint`` hook is called last and may add keys.

        :return: dict with 'epoch', 'global_step', 'optimizer_states',
            'lr_schedulers', 'state_dict', plus 'checkpoint_callback_best' and
            'early_stop_callback_wait'/'early_stop_callback_patience' when the
            corresponding callbacks are set
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step
        }

        if self.checkpoint_callback is not None:
            checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best

        if self.early_stop_callback is not None:
            checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
            checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience

        # save optimizers
        checkpoint['optimizer_states'] = [opt.state_dict() for opt in self.optimizers]

        # save lr schedulers
        checkpoint['lr_schedulers'] = [sched.state_dict() for sched in self.lr_schedulers]

        # add the state_dict from the model
        model = self.__get_model()
        checkpoint['state_dict'] = model.state_dict()

        # give the model a chance to add a few things
        model.on_save_checkpoint(checkpoint)

        return checkpoint

    # --------------------
    # HPC IO
    # --------------------
    def enable_auto_hpc_walltime_manager(self):
        """
        Register hpc_save / hpc_load with ``self.cluster`` so it can checkpoint
        and restore automatically.

        Does nothing when ``self.cluster`` is None. Registration only happens on
        process rank 0.
        NOTE(review): assumes ``self.checkpoint_callback`` is set whenever
        ``self.cluster`` is — confirm.
        """
        if self.cluster is None:
            return

        # allow test tube to handle model check pointing automatically
        # only if proc 0 so we don't trigger world_size resubmits
        if self.proc_rank == 0:
            self.cluster.set_checkpoint_save_function(
                self.hpc_save,
                kwargs={
                    'folderpath': self.checkpoint_callback.filepath,
                    'experiment': self.experiment
                }
            )

            self.cluster.set_checkpoint_load_function(
                self.hpc_load,
                kwargs={
                    'folderpath': self.checkpoint_callback.filepath,
                    'on_gpu': self.on_gpu
                }
            )

    def restore_training_state(self, checkpoint):
        """
        Restore trainer state from a checkpoint dict.

        This covers callback state, step/epoch counters, optimizer states and
        LR-scheduler states. The model's weights are not touched here; the model
        gets its chance to update separately. Optimizers and schedulers are
        paired with the saved states via zip, so any surplus on either side is
        ignored.

        :param checkpoint: dict produced by dump_checkpoint
        """
        if self.checkpoint_callback is not None:
            self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']

        if self.early_stop_callback is not None:
            self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
            self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']

        self.global_step = checkpoint['global_step']
        self.current_epoch = checkpoint['epoch']

        # restore the optimizers
        for optimizer, opt_state in zip(self.optimizers, checkpoint['optimizer_states']):
            optimizer.load_state_dict(opt_state)

        # restore the lr schedulers
        for scheduler, lrs_state in zip(self.lr_schedulers, checkpoint['lr_schedulers']):
            scheduler.load_state_dict(lrs_state)

    # ----------------------------------
    # PRIVATE OPS
    # ----------------------------------
    def hpc_save(self, folderpath, experiment):
        """
        Write an HPC checkpoint ``hpc_ckpt_<n>.ckpt`` into ``folderpath``.

        ``n`` is one more than the highest existing ``hpc_ckpt_`` number in the
        folder. The experiment is saved and closed first. The model's
        ``on_hpc_save`` hook can modify the checkpoint before it is written.

        :param folderpath: checkpoint folder (created if missing)
        :param experiment: object providing save() and close()
        :return: path of the written checkpoint
        """
        # make sure the checkpoint folder exists
        os.makedirs(folderpath, exist_ok=True)

        # save exp to make sure we get all the metrics
        experiment.save()

        # close experiment to avoid issues
        experiment.close()

        # count only previous HPC checkpoints so other '*ckpt_*' files in the same
        # folder (e.g. epoch checkpoints) cannot skew the number hpc_load looks up
        ckpt_number = self.max_ckpt_in_folder(folderpath, name_key='hpc_ckpt_') + 1
        filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)

        # give model a chance to do something on hpc_save
        model = self.__get_model()
        checkpoint = self.dump_checkpoint()
        model.on_hpc_save(checkpoint)

        # do the actual save
        torch.save(checkpoint, filepath)

        return filepath

    def hpc_load(self, folderpath, on_gpu):
        """
        Load the highest-numbered ``hpc_ckpt_<n>.ckpt`` from ``folderpath``.

        Restores trainer state and model weights, then calls the model's
        ``on_hpc_load`` hook.

        :param folderpath: folder that hpc_save wrote to
        :param on_gpu: when False, tensors are loaded onto CPU
        """
        last_ckpt = self.max_ckpt_in_folder(folderpath, name_key='hpc_ckpt_')
        filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, last_ckpt)
        checkpoint = self._load_checkpoint_file(filepath, on_gpu)

        # load training state (affects trainer only)
        self.restore_training_state(checkpoint)

        # load the state_dict on the model automatically
        model = self.__get_model()
        model.load_state_dict(checkpoint['state_dict'])

        # call model hook
        model.on_hpc_load(checkpoint)

    def max_ckpt_in_folder(self, path, name_key='ckpt_'):
        """
        Return the highest checkpoint number among files in ``path`` whose name
        contains ``name_key``.

        The number is made from all digits after the last ``name_key`` in the
        file name (e.g. 'hpc_ckpt_12.ckpt' -> 12). Matching files without digits
        there are ignored.

        :param path: existing folder to scan
        :param name_key: substring identifying checkpoint files
        :return: the largest number found, or 0 if there is none
        """
        ckpt_vs = []
        for name in os.listdir(path):
            if name_key not in name:
                continue
            digits = re.sub(r'[^0-9]', '', name.split(name_key)[-1])
            # e.g. 'ckpt_best.ckpt' has no number; skip it instead of int('') crashing
            if digits:
                ckpt_vs.append(int(digits))
        return max(ckpt_vs, default=0)
2019-07-24 19:44:04 +00:00
def load_hparams_from_tags_csv(tags_csv):
    """
    Rebuild an hparams Namespace from a tags CSV with 'key' and 'value' columns.

    The file is read with every cell as text, and each value is then parsed on its
    own by ``convert``. Without this, pandas would pick one dtype for the whole
    'value' column (bool or float when every row looks alike) before ``convert``
    sees it.

    :param tags_csv: path to the CSV file
    :return: argparse.Namespace with one attribute per row
    """
    from argparse import Namespace
    import pandas as pd

    # dtype=str: disable column-wide type inference; convert() types each value
    tags_df = pd.read_csv(tags_csv, dtype=str)
    rows = tags_df.to_dict(orient='records')
    ns_dict = {row['key']: convert(row['value']) for row in rows}
    return Namespace(**ns_dict)
def convert(val):
    """
    Parse a raw hparam value from a tags CSV into a Python value.

    For strings:
    - 'true' / 'false' (any case) become ``True`` / ``False``;
    - otherwise the first of ``int`` and ``float`` that parses the string wins;
    - otherwise the string itself is returned.

    Non-string values are returned unchanged. They are already typed, and feeding
    them to ``int()`` would silently truncate floats (0.5 -> 0), turn bools into
    ints, and raise TypeError for None.

    :param val: value to convert
    :return: the converted value
    """
    if not isinstance(val, str):
        return val

    lowered = val.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False

    for constructor in (int, float):
        try:
            return constructor(val)
        except ValueError:
            pass
    return str(val)