Weights path (#211)

* added docs. removed options. added weights_save option

* removed old restore

* cleaned up save path

* cleaned up save path

* flake8
William Falcon, 2019-09-06 17:01:03 -04:00 (committed by GitHub)
parent 3e74ea15d8
commit 0c7fbc7178
2 changed files with 61 additions and 76 deletions


@@ -59,7 +59,6 @@ class Trainer(TrainerIO):
checkpoint_callback=None,
gradient_clip=0,
process_position=0,
current_gpu_name=0,
nb_gpu_nodes=1,
gpus=None,
log_gpu_memory=False,
@@ -81,42 +80,40 @@ class Trainer(TrainerIO):
use_amp=False,
print_nan_grads=False,
print_weights_summary=True,
weights_save_path=None,
amp_level='O2',
nb_sanity_val_steps=5):
"""
:param experiment: Test-tube experiment
:param early_stop_callback: from pytorch_lightning import EarlyStopping
:param checkpoint_callback: from pytorch_lightning import Checkpoint
:param gradient_clip:
:param process_position:
:param current_gpu_name:
:param nb_gpu_nodes:
:param gpus:
:param log_gpu_memory: Log GPU memory utilization as metric
during training. This can lead to lower performance on some
servers, in particular when `nvidia-smi` is slow.
:param show_progress_bar:
:param overfit_pct:
:param track_grad_norm:
:param check_val_every_n_epoch:
:param fast_dev_run:
:param accumulate_grad_batches:
:param max_nb_epochs:
:param min_nb_epochs:
:param train_percent_check:
:param val_percent_check:
:param test_percent_check:
:param val_check_interval:
:param log_save_interval:
:param add_log_row_interval:
:param distributed_backend:
'do' to use DistributedParallel, 'dp' to use DistributedDataParallel, 'n' to use none
:param use_amp:
:param print_nan_grads:
:param print_weights_summary:
:param amp_level:
:param nb_sanity_val_steps:
:param early_stop_callback: Callback for early stopping
:param checkpoint_callback: Callback for checkpointing
:param gradient_clip: int. 0 means don't clip.
:param process_position: shown in the tqdm bar
:param nb_gpu_nodes: number of GPU nodes
:param gpus: list or string of gpu ids [0, 1] or '0,1'
:param log_gpu_memory: Bool. If true, adds memory logs
:param show_progress_bar: Bool. If true shows tqdm bar
:param overfit_pct: float. uses this much of all datasets
:param track_grad_norm: int. -1 no tracking. Otherwise tracks that norm
:param check_val_every_n_epoch: int. check val every n train epochs
:param fast_dev_run: Bool. If true, runs a single batch through train and val to quickly surface bugs
:param accumulate_grad_batches: int. Accumulates grads every k batches
:param max_nb_epochs: int. Maximum number of training epochs
:param min_nb_epochs: int. Minimum number of training epochs
:param train_percent_check: float. Fraction of the train set to check
:param val_percent_check: float. Fraction of the val set to check
:param test_percent_check: float. Fraction of the test set to check
:param val_check_interval: float. How often within a train epoch to check val (as a fraction of the epoch)
:param log_save_interval: int. Writes logs to disk this often
:param add_log_row_interval: int. How often to add logging rows
:param distributed_backend: str. dp, or ddp.
:param use_amp: Bool. If true uses apex for 16bit precision
:param print_nan_grads: Bool. Prints nan gradients
:param print_weights_summary: Bool. Prints summary of weights
:param weights_save_path: str. Where to save weights if on cluster
:param amp_level: str. Check nvidia docs for level
:param nb_sanity_val_steps: int. How many val steps before a full train loop.
"""
# Transfer params
self.nb_gpu_nodes = nb_gpu_nodes
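
As a quick orientation for the parameters documented above, a minimal usage sketch of the constructor at this point in the API; the test_tube Experiment, the paths, and the `model` object are placeholders (import paths assumed for the library version of that era), not taken from this commit:

    from test_tube import Experiment
    from pytorch_lightning import Trainer

    # illustrative only: exercise a few of the arguments documented above
    exp = Experiment(save_dir='/some/experiment/dir')       # placeholder path
    trainer = Trainer(
        experiment=exp,
        gpus=[0, 1],                                         # list or string of gpu ids
        gradient_clip=0.5,                                   # 0 means don't clip
        checkpoint_callback=None,                            # no callback here, so...
        weights_save_path='/some/weights/dir',               # ...this path is used for saving weights
    )
    trainer.fit(model)                                       # `model` is a LightningModule defined elsewhere
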
@@ -128,7 +125,6 @@ class Trainer(TrainerIO):
self.fast_dev_run = fast_dev_run
self.on_gpu = gpus is not None and torch.cuda.is_available()
self.process_position = process_position
self.current_gpu_name = current_gpu_name
self.print_weights_summary = print_weights_summary
self.max_nb_epochs = max_nb_epochs
self.min_nb_epochs = min_nb_epochs
@@ -157,11 +153,11 @@ class Trainer(TrainerIO):
self.current_epoch = 0
self.total_batches = 0
# configure callbacks
# configure early stop callback
self.early_stop_callback = early_stop_callback
self.checkpoint_callback = checkpoint_callback
if self.checkpoint_callback is not None:
    self.checkpoint_callback.save_function = self.save_checkpoint
# configure weights save path
self.__configure_weights_path(checkpoint_callback, weights_save_path)
# configure experiment
self.experiment = experiment
@@ -204,6 +200,26 @@ class Trainer(TrainerIO):
self.amp_level = amp_level
self.__init_amp(use_amp)
def __configure_weights_path(self, checkpoint_callback, weights_save_path):
    """
    The weights path is resolved with this priority:
    1. The checkpoint_callback's filepath (if a callback is passed in).
    2. The user-provided weights_save_path.
    3. Otherwise os.getcwd().
    """
    self.weights_save_path = weights_save_path

    # configure checkpoint callback
    self.checkpoint_callback = checkpoint_callback
    if self.checkpoint_callback is not None:
        self.checkpoint_callback.save_function = self.save_checkpoint

        # if a checkpoint callback is used, its filepath overrides the weights path
        self.weights_save_path = self.checkpoint_callback.filepath

    # if weights_save_path is still None here, fall back to the current working dir
    if self.weights_save_path is None:
        self.weights_save_path = os.getcwd()
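
To make the resolution order concrete, here is a standalone sketch of the same priority, written as a plain function for illustration (not part of the commit):

    import os

    def resolve_weights_path(checkpoint_callback, weights_save_path):
        """Same priority as above: callback filepath > user-supplied path > os.getcwd()."""
        if checkpoint_callback is not None:
            # a passed-in checkpoint callback always wins
            return checkpoint_callback.filepath
        if weights_save_path is not None:
            return weights_save_path
        return os.getcwd()
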
def __init_amp(self, use_amp):
self.use_amp = use_amp and APEX_AVAILABLE
if self.use_amp:
@@ -289,37 +305,6 @@
# likely not on slurm, so set the slurm managed flag to false
self.is_slurm_managing_tasks = False
def restore_state_if_existing_checkpoint(self):
    # restore trainer state and model if there is a weight for this experiment
    last_epoch = -1
    last_ckpt_name = None

    # do nothing if there's no dir or callback
    no_ckpt_callback = self.checkpoint_callback is None
    if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
        return

    # find last epoch
    checkpoints = os.listdir(self.checkpoint_callback.filepath)
    for name in checkpoints:
        # ignore hpc ckpts
        if 'hpc_' in name:
            continue

        if '.ckpt' in name:
            epoch = name.split('epoch_')[1]
            epoch = int(re.sub('[^0-9]', '', epoch))

            if epoch > last_epoch:
                last_epoch = epoch
                last_ckpt_name = name

    # restore last checkpoint
    if last_ckpt_name is not None:
        last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
        self.restore(last_ckpt_path, self.on_gpu)
        print(f'model and trainer restored from checkpoint: {last_ckpt_path}')
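
The epoch lookup in the method above (and in its new home in TrainerIO) parses the epoch number out of the checkpoint filename; a small illustration with a hypothetical filename:

    import re

    name = '_ckpt_epoch_12.ckpt'                 # hypothetical checkpoint filename
    epoch = name.split('epoch_')[1]              # -> '12.ckpt'
    epoch = int(re.sub('[^0-9]', '', epoch))     # strip non-digits -> 12
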
@property
def data_parallel(self):
return self.use_dp or self.use_ddp
@@ -367,7 +352,7 @@
tqdm_dic.update(self.tqdm_metrics)
if self.on_gpu:
    tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)
    tqdm_dic['gpu'] = '{}'.format(torch.cuda.current_device())
return tqdm_dic
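
With the stored current_gpu_name attribute removed, the progress-bar metric now asks PyTorch for the active device index directly; for reference, a tiny sketch of what that call returns:

    import torch

    if torch.cuda.is_available():
        # index of the GPU currently selected in this process, e.g. 0
        print(torch.cuda.current_device())
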


@@ -28,11 +28,6 @@ class TrainerIO(object):
:param model:
:return:
"""
# do nothing if there's no dir or callback
no_ckpt_callback = self.checkpoint_callback is None
if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
    return
# restore weights if same exp version
self.restore_state_if_checkpoint_exists(model)
@@ -40,6 +35,11 @@ class TrainerIO(object):
self.restore_hpc_weights_if_needed(model)
def restore_state_if_checkpoint_exists(self, model):
# do nothing if there's no dir or callback
no_ckpt_callback = self.checkpoint_callback is None
if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
    return
# restore trainer state and model if there is a weight for this experiment
last_epoch = -1
last_ckpt_name = None
@@ -87,7 +87,7 @@ class TrainerIO(object):
if self.proc_rank == 0:
    # save weights
    print('handling SIGUSR1')
    self.hpc_save(self.checkpoint_callback.filepath, self.experiment)
    self.hpc_save(self.weights_save_path, self.experiment)

    # find job id
    job_id = os.environ['SLURM_JOB_ID']
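
For context, the SIGUSR1 path is normally paired with a handler that saves and then requeues the job before SLURM terminates it; a hedged sketch of such a handler (the scontrol call and the handler name are illustrative, not taken from this diff):

    import os
    import signal
    from subprocess import call

    def sig_handler(signum, frame):
        # illustrative: save weights first (as above), then ask SLURM to requeue this job
        job_id = os.environ['SLURM_JOB_ID']
        call(['scontrol', 'requeue', job_id])    # assumes scontrol is available on the cluster

    # assumes the job was submitted with something like: #SBATCH --signal=SIGUSR1@90
    signal.signal(signal.SIGUSR1, sig_handler)
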
@@ -179,7 +179,7 @@ class TrainerIO(object):
:return:
"""
# look for hpc weights
folderpath = self.checkpoint_callback.filepath
folderpath = self.weights_save_path
if os.path.exists(folderpath):
    files = os.listdir(folderpath)
    hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
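
The lookup above only filters filenames containing 'hpc_ckpt'; a sketch of how the newest such checkpoint could then be selected, assuming names of the form hpc_ckpt_<n>.ckpt (that naming is an assumption, not shown in this diff):

    import os
    import re

    def latest_hpc_ckpt(folderpath):
        """Return the highest-numbered hpc checkpoint path, or None if there are none."""
        files = [x for x in os.listdir(folderpath) if 'hpc_ckpt' in x]
        if not files:
            return None
        # e.g. 'hpc_ckpt_3.ckpt' -> 3; pick the largest number
        newest = max(files, key=lambda name: int(re.sub('[^0-9]', '', name)))
        return os.path.join(folderpath, newest)
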