262 lines
7.6 KiB
Python
262 lines
7.6 KiB
Python
import os
|
|
import re
|
|
|
|
import torch
|
|
|
|
from pytorch_lightning.pt_overrides.override_data_parallel import (
|
|
LightningDistributedDataParallel, LightningDataParallel)
|
|
|
|
|
|
class ModelIO(object):
    """
    Checkpoint hooks with empty default implementations.

    Every hook receives the ``checkpoint`` object, does nothing and returns
    ``None``; the defaults exist so a subclass only has to override the hooks
    it actually needs.
    """

    def on_load_checkpoint(self, checkpoint):
        """
        Inspect or consume a checkpoint that is being loaded.

        Meant as the model's opportunity to pull extra items out of the
        checkpoint before its state_dict is restored.

        :param checkpoint: the checkpoint being loaded
        :return: None
        """

    def on_save_checkpoint(self, checkpoint):
        """
        Add extra items to a checkpoint that is about to be saved.

        The model's state_dict is expected to be in ``checkpoint`` already.

        :param checkpoint: the checkpoint being saved
        :return: None
        """

    # -------------------------
    # OPTIONAL HOOKS
    # -------------------------
    def on_hpc_save(self, checkpoint):
        """
        Run custom logic right before the Slurm manager saves the model.

        :param checkpoint: the checkpoint about to be written
        :return: None
        """

    def on_hpc_load(self, checkpoint):
        """
        Run custom logic right before the Slurm manager loads the model.

        :param checkpoint: the checkpoint that was read
        :return: None
        """
|
|
|
|
|
|
class TrainerIO(object):
    """
    Checkpoint saving and restoring for a trainer object.

    The class sets no attributes of its own. Its methods read (and, when
    restoring, write) these attributes on ``self``: ``model``,
    ``current_epoch``, ``global_step``, ``optimizers``, ``lr_schedulers``,
    ``checkpoint_callback``, ``early_stop_callback``, ``cluster``,
    ``proc_rank``, ``experiment`` and ``on_gpu``.
    """

    def __get_model(self):
        """
        Return the bare model, unwrapping ``self.model.module`` when
        ``self.model`` is one of the Lightning data-parallel wrappers.
        """
        wrappers = (LightningDistributedDataParallel, LightningDataParallel)
        if isinstance(self.model, wrappers):
            return self.model.module
        return self.model

    @staticmethod
    def __load_checkpoint_file(path, on_gpu):
        """
        Read a checkpoint file with ``torch.load``.

        :param path: checkpoint file to read
        :param on_gpu: when falsy, every storage is kept where it was
            deserialized instead of being moved to the device it was saved from
        :return: the deserialized checkpoint
        """
        # NOTE: torch.load unpickles the file, so only checkpoints from a
        # trusted source should be loaded here.
        if on_gpu:
            return torch.load(path)
        return torch.load(path, map_location=lambda storage, loc: storage)

    def __apply_checkpoint(self, checkpoint):
        """
        Restore trainer state first, then the model weights.

        :param checkpoint: mapping produced by ``dump_checkpoint``
        :return: the (unwrapped) model the weights were loaded into
        """
        # load training state (affects trainer only)
        self.restore_training_state(checkpoint)

        # load the state_dict on the model automatically
        model = self.__get_model()
        model.load_state_dict(checkpoint['state_dict'])
        return model

    # --------------------
    # MODEL SAVE CHECKPOINT
    # --------------------
    def save_checkpoint(self, filepath):
        """
        Write the current trainer and model state to ``filepath``.

        :param filepath: destination handed to ``torch.save``
        """
        torch.save(self.dump_checkpoint(), filepath)

    def restore(self, checkpoint_path, on_gpu):
        """
        Load ``checkpoint_path`` and restore trainer state and model weights.

        :param checkpoint_path: checkpoint file to read
        :param on_gpu: passed through to the loader; when falsy the
            checkpoint is loaded with an identity ``map_location``
        """
        checkpoint = self.__load_checkpoint_file(checkpoint_path, on_gpu)
        self.__apply_checkpoint(checkpoint)

    def dump_checkpoint(self):
        """
        Collect everything that goes into a checkpoint.

        :return: dict with ``epoch``, ``global_step``, ``optimizer_states``,
            ``lr_schedulers`` and ``state_dict``, plus the callback entries
            when the corresponding callback is set, plus whatever the model's
            ``on_save_checkpoint`` hook adds
        """
        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step
        }

        if self.checkpoint_callback is not None:
            checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best

        if self.early_stop_callback is not None:
            checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
            checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience

        # save optimizers
        checkpoint['optimizer_states'] = [
            optimizer.state_dict() for optimizer in self.optimizers
        ]

        # save lr schedulers
        checkpoint['lr_schedulers'] = [
            scheduler.state_dict() for scheduler in self.lr_schedulers
        ]

        # add the state_dict from the model
        model = self.__get_model()
        checkpoint['state_dict'] = model.state_dict()

        # give the model a chance to add a few things
        model.on_save_checkpoint(checkpoint)

        return checkpoint

    # --------------------
    # HPC IO
    # --------------------
    def enable_auto_hpc_walltime_manager(self):
        """
        Register ``hpc_save`` / ``hpc_load`` with ``self.cluster``.

        Does nothing when ``self.cluster`` is None.
        """
        if self.cluster is None:
            return

        # allow test tube to handle model check pointing automatically
        # only if proc 0 so we don't trigger world_size resubmits
        # NOTE(review): the original indentation of this method was lost;
        # both registrations are assumed to sit under the proc_rank check -
        # confirm against upstream.
        if self.proc_rank == 0:
            self.cluster.set_checkpoint_save_function(
                self.hpc_save,
                kwargs={
                    'folderpath': self.checkpoint_callback.filepath,
                    'experiment': self.experiment
                }
            )

            self.cluster.set_checkpoint_load_function(
                self.hpc_load,
                kwargs={
                    'folderpath': self.checkpoint_callback.filepath,
                    'on_gpu': self.on_gpu
                }
            )

    def restore_training_state(self, checkpoint):
        """
        Restore trainer state.
        Model will get its chance to update

        Callback state is only restored for callbacks that are set on
        ``self``; the matching checkpoint keys must then be present.

        :param checkpoint: mapping produced by ``dump_checkpoint``
        :return: None
        """
        if self.checkpoint_callback is not None:
            self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']

        if self.early_stop_callback is not None:
            self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
            self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']

        self.global_step = checkpoint['global_step']
        self.current_epoch = checkpoint['epoch']

        # restore the optimizers (zip stops at the shorter of the two lists)
        optimizer_states = checkpoint['optimizer_states']
        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
            optimizer.load_state_dict(opt_state)

        # restore the lr schedulers
        lr_schedulers = checkpoint['lr_schedulers']
        for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
            scheduler.load_state_dict(lrs_state)

    # ----------------------------------
    # PRIVATE OPS
    # ----------------------------------
    def hpc_save(self, folderpath, experiment):
        """
        Save a numbered ``hpc_ckpt_<n>.ckpt`` file into ``folderpath``.

        :param folderpath: checkpoint folder; created when missing
        :param experiment: object whose ``save()`` and ``close()`` are called
            before the checkpoint is written
        :return: path of the checkpoint file that was written
        """
        # make sure the checkpoint folder exists
        os.makedirs(folderpath, exist_ok=True)

        # save exp to make sure we get all the metrics
        experiment.save()

        # close experiment to avoid issues
        experiment.close()

        # next free checkpoint number in the folder
        ckpt_number = self.max_ckpt_in_folder(folderpath) + 1
        filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)

        # give model a chance to do something on hpc_save
        model = self.__get_model()
        checkpoint = self.dump_checkpoint()

        model.on_hpc_save(checkpoint)

        # do the actual save
        torch.save(checkpoint, filepath)

        return filepath

    def hpc_load(self, folderpath, on_gpu):
        """
        Load the highest-numbered ``hpc_ckpt_<n>.ckpt`` from ``folderpath``.

        :param folderpath: folder holding the hpc checkpoints
        :param on_gpu: passed through to the loader; when falsy the
            checkpoint is loaded with an identity ``map_location``
        """
        filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, self.max_ckpt_in_folder(folderpath))

        checkpoint = self.__load_checkpoint_file(filepath, on_gpu)
        model = self.__apply_checkpoint(checkpoint)

        # call model hook
        model.on_hpc_load(checkpoint)

    def max_ckpt_in_folder(self, path, name_key='ckpt_'):
        """
        Return the largest checkpoint number found among the files in ``path``.

        A file counts when its name contains ``name_key``; its number is made
        of every digit that follows the last occurrence of ``name_key``.
        Matching files without any digit after the key are ignored rather
        than crashing the scan.

        :param path: folder to list
        :param name_key: substring that marks a checkpoint file
        :return: the highest number found, or 0 when there is none
        """
        versions = []
        for fname in os.listdir(path):
            if name_key not in fname:
                continue
            digits = re.sub('[^0-9]', '', fname.split(name_key)[-1])
            # e.g. 'ckpt_best.ckpt' leaves no digits; int('') would raise
            if digits:
                versions.append(int(digits))

        return max(versions, default=0)
|
|
|
|
|
|
def load_hparams_from_tags_csv(tags_csv):
    """
    Build an ``argparse.Namespace`` from a CSV file of hyperparameters.

    The CSV is read with pandas and must provide a ``key`` and a ``value``
    column; every row becomes one attribute named after ``key`` whose value
    is ``value`` passed through ``convert``.

    :param tags_csv: path (or anything ``pandas.read_csv`` accepts)
    :return: Namespace holding one attribute per CSV row
    """
    import pandas as pd
    from argparse import Namespace

    records = pd.read_csv(tags_csv).to_dict(orient='records')

    hparams = {}
    for record in records:
        hparams[record['key']] = convert(record['value'])

    return Namespace(**hparams)
|
|
|
|
|
|
def convert(val):
    """
    Coerce a raw value into the most specific simple Python type.

    Strings: ``'true'`` / ``'false'`` (any casing) become booleans, otherwise
    the string is parsed as ``int``, then as ``float``, and returned
    unchanged when neither parse succeeds.

    Non-strings: booleans are returned as they are; a float is turned into an
    ``int`` only when it holds a whole number, so fractional values such as
    ``0.02`` are no longer truncated to ``0`` (pandas may hand over floats
    when a whole column is numeric); anything else is tried as ``int`` and
    ``float`` and returned unchanged when it cannot be converted.

    :param val: value to coerce
    :return: bool, int, float, or ``val`` itself
    """
    # bool is a subclass of int, so it must be handled before int(val)
    if isinstance(val, bool):
        return val

    if isinstance(val, str):
        lowered = val.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False
    elif isinstance(val, float):
        # is_integer() is False for nan/inf, which therefore pass through
        return int(val) if val.is_integer() else val

    for constructor in (int, float):
        try:
            return constructor(val)
        except (TypeError, ValueError, OverflowError):
            continue
    return val
|