Weights path (#211)
* added docs. removed options. added weights_save option
* removed old restore
* cleaned up save path
* cleaned up save path
* flake8
parent 3e74ea15d8
commit 0c7fbc7178
@@ -59,7 +59,6 @@ class Trainer(TrainerIO):
checkpoint_callback=None,
gradient_clip=0,
process_position=0,
current_gpu_name=0,
nb_gpu_nodes=1,
gpus=None,
log_gpu_memory=False,
@@ -81,42 +80,40 @@ class Trainer(TrainerIO):
use_amp=False,
print_nan_grads=False,
print_weights_summary=True,
weights_save_path=None,
amp_level='O2',
nb_sanity_val_steps=5):
"""

:param experiment: Test-tube experiment
:param early_stop_callback: from pytorch_lightning import EarlyStopping
:param checkpoint_callback: from pytorch_lightning import Checkpoint
:param gradient_clip:
:param process_position:
:param current_gpu_name:
:param nb_gpu_nodes:
:param gpus:
:param log_gpu_memory: Log GPU memory utilization as metric
during training. This can lead to lower performance on some
servers, in particular when `nvidia-smi` is slow.
:param show_progress_bar:
:param overfit_pct:
:param track_grad_norm:
:param check_val_every_n_epoch:
:param fast_dev_run:
:param accumulate_grad_batches:
:param max_nb_epochs:
:param min_nb_epochs:
:param train_percent_check:
:param val_percent_check:
:param test_percent_check:
:param val_check_interval:
:param log_save_interval:
:param add_log_row_interval:
:param distributed_backend:
'do' to use DistributedParallel, 'dp' to use DistributedDataParallel, 'n' to use none
:param use_amp:
:param print_nan_grads:
:param print_weights_summary:
:param amp_level:
:param nb_sanity_val_steps:
:param early_stop_callback: Callback for early stopping
:param checkpoint_callback: Callback for checkpointing
:param gradient_clip: int. 0 means don't clip.
:param process_position: shown in the tqdm bar
:param nb_gpu_nodes: number of GPU nodes
:param gpus: list or string of gpu ids [0, 1] or '0,1'
:param log_gpu_memory: Bool. If true, adds memory logs
:param show_progress_bar: Bool. If true shows tqdm bar
:param overfit_pct: float. uses this much of all datasets
:param track_grad_norm: int. -1 no tracking. Otherwise tracks that norm
:param check_val_every_n_epoch: int. check val every n train epochs
:param fast_dev_run: Bool. runs full iteration over everything to find bugs
:param accumulate_grad_batches: int. Accumulates grads every k batches
:param max_nb_epochs: int.
:param min_nb_epochs: int.
:param train_percent_check: int. How much of train set to check
:param val_percent_check: int. How much of val set to check
:param test_percent_check: int. How much of test set to check
:param val_check_interval: int. Check val this frequently within a train epoch
:param log_save_interval: int. Writes logs to disk this often
:param add_log_row_interval: int. How often to add logging rows
:param distributed_backend: str. dp, or ddp.
:param use_amp: Bool. If true uses apex for 16bit precision
:param print_nan_grads: Bool. Prints nan gradients
:param print_weights_summary: Bool. Prints summary of weights
:param weights_save_path: str. Where to save weights if on cluster
:param amp_level: str. Check nvidia docs for level
:param nb_sanity_val_steps: int. How many val steps before a full train loop.
"""
# Transfer params
self.nb_gpu_nodes = nb_gpu_nodes
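Going by the new signature and parameter docs above, a minimal usage sketch of the added option; the import path, example directory, and second argument are assumptions for illustration, not part of this commit:

    from pytorch_lightning import Trainer

    # No checkpoint_callback given, so weights go to weights_save_path;
    # if that were also None, the trainer would fall back to os.getcwd().
    trainer = Trainer(weights_save_path='/scratch/my_exp/weights',
                      max_nb_epochs=10)
    # trainer.fit(model)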
@@ -128,7 +125,6 @@ class Trainer(TrainerIO):
self.fast_dev_run = fast_dev_run
self.on_gpu = gpus is not None and torch.cuda.is_available()
self.process_position = process_position
self.current_gpu_name = current_gpu_name
self.print_weights_summary = print_weights_summary
self.max_nb_epochs = max_nb_epochs
self.min_nb_epochs = min_nb_epochs
@@ -157,11 +153,11 @@ class Trainer(TrainerIO):
self.current_epoch = 0
self.total_batches = 0

# configure callbacks
# configure early stop callback
self.early_stop_callback = early_stop_callback
self.checkpoint_callback = checkpoint_callback
if self.checkpoint_callback is not None:
self.checkpoint_callback.save_function = self.save_checkpoint

# configure weights save path
self.__configure_weights_path(checkpoint_callback, weights_save_path)

# configure experiment
self.experiment = experiment
@@ -204,6 +200,26 @@ class Trainer(TrainerIO):
self.amp_level = amp_level
self.__init_amp(use_amp)

def __configure_weights_path(self, checkpoint_callback, weights_save_path):
"""
Weight path set in this priority:
Checkpoint_callback's path (if passed in).
User provided weights_save_path
Otherwise use os.getcwd()
"""
self.weights_save_path = weights_save_path

# configure checkpoint callback
self.checkpoint_callback = checkpoint_callback
if self.checkpoint_callback is not None:
self.checkpoint_callback.save_function = self.save_checkpoint

# if checkpoint callback used, then override the weights path
self.weights_save_path = self.checkpoint_callback.filepath

# if weights_save_path is still none here, set to current working dir
self.weights_save_path = os.getcwd()

def __init_amp(self, use_amp):
self.use_amp = use_amp and APEX_AVAILABLE
if self.use_amp:
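The docstring above states the priority order: the checkpoint callback's path, then the user-supplied weights_save_path, then os.getcwd(). A standalone sketch of that resolution logic; the explicit None guard and the helper name are added for illustration, since the flattened diff does not show indentation or guards:

    import os

    def resolve_weights_path(checkpoint_callback, weights_save_path):
        # start from the user-provided path (may be None)
        path = weights_save_path

        # a checkpoint callback's filepath, if given, takes priority
        if checkpoint_callback is not None:
            path = checkpoint_callback.filepath

        # fall back to the current working directory
        if path is None:
            path = os.getcwd()
        return path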
@@ -289,37 +305,6 @@ class Trainer(TrainerIO):
# likely not on slurm, so set the slurm managed flag to false
self.is_slurm_managing_tasks = False

def restore_state_if_existing_checkpoint(self):
# restore trainer state and model if there is a weight for this experiment
last_epoch = -1
last_ckpt_name = None

# do nothing if there's not dir or callback
no_ckpt_callback = self.checkpoint_callback is None
if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
return

# find last epoch
checkpoints = os.listdir(self.checkpoint_callback.filepath)
for name in checkpoints:
# ignore hpc ckpts
if 'hpc_' in name:
continue

if '.ckpt' in name:
epoch = name.split('epoch_')[1]
epoch = int(re.sub('[^0-9]', '', epoch))

if epoch > last_epoch:
last_epoch = epoch
last_ckpt_name = name

# restore last checkpoint
if last_ckpt_name is not None:
last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
self.restore(last_ckpt_path, self.on_gpu)
print(f'model and trainer restored from checkpoint: {last_ckpt_path}')

@property
def data_parallel(self):
return self.use_dp or self.use_ddp
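The method removed above ("removed old restore" in the commit message; it reappears in TrainerIO as restore_state_if_checkpoint_exists further down) picks the newest checkpoint by parsing the epoch number out of each filename. A self-contained sketch of that scan, assuming names like 'epoch_7.ckpt' as the diff implies:

    import os
    import re

    def find_last_checkpoint(ckpt_dir):
        last_epoch, last_name = -1, None
        for name in os.listdir(ckpt_dir):
            # skip hpc snapshots and anything that is not a checkpoint file
            if 'hpc_' in name or '.ckpt' not in name:
                continue
            # extract the integer after 'epoch_', e.g. 'epoch_7.ckpt' -> 7
            epoch = int(re.sub('[^0-9]', '', name.split('epoch_')[1]))
            if epoch > last_epoch:
                last_epoch, last_name = epoch, name
        return os.path.join(ckpt_dir, last_name) if last_name else None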
@@ -367,7 +352,7 @@ class Trainer(TrainerIO):
tqdm_dic.update(self.tqdm_metrics)

if self.on_gpu:
tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)
tqdm_dic['gpu'] = '{}'.format(torch.cuda.current_device())

return tqdm_dic
@@ -28,11 +28,6 @@ class TrainerIO(object):
:param model:
:return:
"""
# do nothing if there's not dir or callback
no_ckpt_callback = self.checkpoint_callback is None
if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
return

# restore weights if same exp version
self.restore_state_if_checkpoint_exists(model)
@@ -40,6 +35,11 @@ class TrainerIO(object):
self.restore_hpc_weights_if_needed(model)

def restore_state_if_checkpoint_exists(self, model):
# do nothing if there's not dir or callback
no_ckpt_callback = self.checkpoint_callback is None
if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
return

# restore trainer state and model if there is a weight for this experiment
last_epoch = -1
last_ckpt_name = None
@@ -87,7 +87,7 @@ class TrainerIO(object):
if self.proc_rank == 0:
# save weights
print('handling SIGUSR1')
self.hpc_save(self.checkpoint_callback.filepath, self.experiment)
self.hpc_save(self.weights_save_path, self.experiment)

# find job id
job_id = os.environ['SLURM_JOB_ID']
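This hunk switches the SLURM SIGUSR1 handler to save into the resolved weights_save_path rather than the callback's filepath. A rough sketch of how such a handler gets wired up; the handler name, the registration point, and the free-standing trainer variable are assumptions, only hpc_save and the SLURM signal appear in the diff:

    import signal

    def sig_handler(signum, frame):
        # rank-0 process writes an hpc checkpoint to the resolved weights path
        if trainer.proc_rank == 0:
            trainer.hpc_save(trainer.weights_save_path, trainer.experiment)

    # SLURM sends SIGUSR1 shortly before preempting/requeuing the job
    signal.signal(signal.SIGUSR1, sig_handler)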
@@ -179,7 +179,7 @@ class TrainerIO(object):
:return:
"""
# look for hpc weights
folderpath = self.checkpoint_callback.filepath
folderpath = self.weights_save_path
if os.path.exists(folderpath):
files = os.listdir(folderpath)
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
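restore_hpc_weights_if_needed now scans weights_save_path for files containing 'hpc_ckpt'. A small illustrative sketch of picking the newest such snapshot; the 'hpc_ckpt_N.ckpt' naming is an assumption consistent with the 'hpc_' prefix filtered out earlier:

    import os
    import re

    def latest_hpc_checkpoint(weights_save_path):
        # collect hpc snapshots, e.g. 'hpc_ckpt_3.ckpt', and keep the highest number
        if not os.path.exists(weights_save_path):
            return None
        candidates = [f for f in os.listdir(weights_save_path) if 'hpc_ckpt' in f]
        if not candidates:
            return None
        newest = max(candidates, key=lambda f: int(re.sub('[^0-9]', '', f) or 0))
        return os.path.join(weights_save_path, newest)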