From 0c7fbc7178d426c26403f5722380a5199a0260a4 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 6 Sep 2019 17:01:03 -0400
Subject: [PATCH] Weights path (#211)

* added docs. removed options. added weights_save option

* removed old restore

* cleaned up save path

* cleaned up save path

* flake8
---
 pytorch_lightning/trainer/trainer.py    | 123 +++++++++++-------
 pytorch_lightning/trainer/trainer_io.py |  14 +--
 2 files changed, 61 insertions(+), 76 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ffb793cae4..aa673abd1f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -59,7 +59,6 @@ class Trainer(TrainerIO):
                  checkpoint_callback=None,
                  gradient_clip=0,
                  process_position=0,
-                 current_gpu_name=0,
                  nb_gpu_nodes=1,
                  gpus=None,
                  log_gpu_memory=False,
@@ -81,42 +80,40 @@ class Trainer(TrainerIO):
                  use_amp=False,
                  print_nan_grads=False,
                  print_weights_summary=True,
+                 weights_save_path=None,
                  amp_level='O2',
                  nb_sanity_val_steps=5):
         """

         :param experiment: Test-tube experiment
-        :param early_stop_callback: from pytorch_lightning import EarlyStopping
-        :param checkpoint_callback: from pytorch_lightning import Checkpoint
-        :param gradient_clip:
-        :param process_position:
-        :param current_gpu_name:
-        :param nb_gpu_nodes:
-        :param gpus:
-        :param log_gpu_memory: Log GPU memory utilization as metric
-            during training. This can lead to lower performance on some
-            servers, in particular when `nvidia-smi` is slow.
-        :param show_progress_bar:
-        :param overfit_pct:
-        :param track_grad_norm:
-        :param check_val_every_n_epoch:
-        :param fast_dev_run:
-        :param accumulate_grad_batches:
-        :param max_nb_epochs:
-        :param min_nb_epochs:
-        :param train_percent_check:
-        :param val_percent_check:
-        :param test_percent_check:
-        :param val_check_interval:
-        :param log_save_interval:
-        :param add_log_row_interval:
-        :param distributed_backend:
-            'do' to use DistributedParallel, 'dp' to use DistributedDataParallel, 'n' to use none
-        :param use_amp:
-        :param print_nan_grads:
-        :param print_weights_summary:
-        :param amp_level:
-        :param nb_sanity_val_steps:
+        :param early_stop_callback: Callback for early stopping
+        :param checkpoint_callback: Callback for checkpointing
+        :param gradient_clip: int. 0 means don't clip.
+        :param process_position: shown in the tqdm bar
+        :param nb_gpu_nodes: number of GPU nodes
+        :param gpus: list or string of gpu ids [0, 1] or '0,1'
+        :param log_gpu_memory: Bool. If true, adds memory logs
+        :param show_progress_bar: Bool. If true shows tqdm bar
+        :param overfit_pct: float. uses this much of all datasets
+        :param track_grad_norm: int. -1 no tracking. Otherwise tracks that norm
+        :param check_val_every_n_epoch: int. check val every n train epochs
+        :param fast_dev_run: Bool. runs full iteration over everything to find bugs
+        :param accumulate_grad_batches: int. Accumulates grads every k batches
+        :param max_nb_epochs: int. Max number of epochs to train for
+        :param min_nb_epochs: int. Min number of epochs to train for
+        :param train_percent_check: float. How much of the train set to check
+        :param val_percent_check: float. How much of the val set to check
+        :param test_percent_check: float. How much of the test set to check
+        :param val_check_interval: float. Check val this frequently within a train epoch
+        :param log_save_interval: int. Writes logs to disk this often
+        :param add_log_row_interval: int. How often to add logging rows
+        :param distributed_backend: str. 'dp' or 'ddp'.
+        :param use_amp: Bool. If true uses apex for 16bit precision
+        :param print_nan_grads: Bool. Prints nan gradients
+        :param print_weights_summary: Bool. Prints summary of weights
+        :param weights_save_path: str. Where to save weights if on cluster
+        :param amp_level: str. Check nvidia docs for level
+        :param nb_sanity_val_steps: int. How many val steps before a full train loop.
         """
         # Transfer params
         self.nb_gpu_nodes = nb_gpu_nodes
@@ -128,7 +125,6 @@ class Trainer(TrainerIO):
         self.fast_dev_run = fast_dev_run
         self.on_gpu = gpus is not None and torch.cuda.is_available()
         self.process_position = process_position
-        self.current_gpu_name = current_gpu_name
         self.print_weights_summary = print_weights_summary
         self.max_nb_epochs = max_nb_epochs
         self.min_nb_epochs = min_nb_epochs
@@ -157,11 +153,11 @@ class Trainer(TrainerIO):
         self.current_epoch = 0
         self.total_batches = 0

-        # configure callbacks
+        # configure early stop callback
         self.early_stop_callback = early_stop_callback
-        self.checkpoint_callback = checkpoint_callback
-        if self.checkpoint_callback is not None:
-            self.checkpoint_callback.save_function = self.save_checkpoint
+
+        # configure weights save path
+        self.__configure_weights_path(checkpoint_callback, weights_save_path)

         # configure experiment
         self.experiment = experiment
@@ -204,6 +200,26 @@ class Trainer(TrainerIO):
         self.amp_level = amp_level
         self.__init_amp(use_amp)

+    def __configure_weights_path(self, checkpoint_callback, weights_save_path):
+        """
+        Weight path set in this priority:
+        1. Checkpoint_callback's path (if passed in)
+        2. User provided weights_save_path
+        3. Otherwise use os.getcwd()
+        """
+        self.weights_save_path = weights_save_path
+
+        # configure checkpoint callback
+        self.checkpoint_callback = checkpoint_callback
+        if self.checkpoint_callback is not None:
+            self.checkpoint_callback.save_function = self.save_checkpoint
+
+            # if checkpoint callback used, then override the weights path
+            self.weights_save_path = self.checkpoint_callback.filepath
+
+        # if weights_save_path is still none here, set to current working dir
+        if self.weights_save_path is None:
+            self.weights_save_path = os.getcwd()
+
     def __init_amp(self, use_amp):
         self.use_amp = use_amp and APEX_AVAILABLE
         if self.use_amp:
@@ -289,37 +305,6 @@ class Trainer(TrainerIO):
             # likely not on slurm, so set the slurm managed flag to false
             self.is_slurm_managing_tasks = False

-    def restore_state_if_existing_checkpoint(self):
-        # restore trainer state and model if there is a weight for this experiment
-        last_epoch = -1
-        last_ckpt_name = None
-
-        # do nothing if there's not dir or callback
-        no_ckpt_callback = self.checkpoint_callback is None
-        if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
-            return
-
-        # find last epoch
-        checkpoints = os.listdir(self.checkpoint_callback.filepath)
-        for name in checkpoints:
-            # ignore hpc ckpts
-            if 'hpc_' in name:
-                continue
-
-            if '.ckpt' in name:
-                epoch = name.split('epoch_')[1]
-                epoch = int(re.sub('[^0-9]', '', epoch))
-
-                if epoch > last_epoch:
-                    last_epoch = epoch
-                    last_ckpt_name = name
-
-        # restore last checkpoint
-        if last_ckpt_name is not None:
-            last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
-            self.restore(last_ckpt_path, self.on_gpu)
-            print(f'model and trainer restored from checkpoint: {last_ckpt_path}')
-
     @property
     def data_parallel(self):
         return self.use_dp or self.use_ddp
@@ -367,7 +352,7 @@ class Trainer(TrainerIO):
         tqdm_dic.update(self.tqdm_metrics)

         if self.on_gpu:
-            tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)
+            tqdm_dic['gpu'] = '{}'.format(torch.cuda.current_device())

         return tqdm_dic

diff --git a/pytorch_lightning/trainer/trainer_io.py b/pytorch_lightning/trainer/trainer_io.py
index 3fc0ba5d69..f5a869a883 100644
--- a/pytorch_lightning/trainer/trainer_io.py
+++ b/pytorch_lightning/trainer/trainer_io.py
@@ -28,11 +28,6 @@ class TrainerIO(object):
         :param model:
         :return:
         """
-        # do nothing if there's not dir or callback
-        no_ckpt_callback = self.checkpoint_callback is None
-        if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
-            return
-
         # restore weights if same exp version
         self.restore_state_if_checkpoint_exists(model)

@@ -40,6 +35,11 @@ class TrainerIO(object):
         self.restore_hpc_weights_if_needed(model)

     def restore_state_if_checkpoint_exists(self, model):
+        # do nothing if there's not dir or callback
+        no_ckpt_callback = self.checkpoint_callback is None
+        if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
+            return
+
         # restore trainer state and model if there is a weight for this experiment
         last_epoch = -1
         last_ckpt_name = None
@@ -87,7 +87,7 @@ class TrainerIO(object):
         if self.proc_rank == 0:
             # save weights
             print('handling SIGUSR1')
-            self.hpc_save(self.checkpoint_callback.filepath, self.experiment)
+            self.hpc_save(self.weights_save_path, self.experiment)

             # find job id
             job_id = os.environ['SLURM_JOB_ID']
@@ -179,7 +179,7 @@ class TrainerIO(object):
         :return:
         """
         # look for hpc weights
-        folderpath = self.checkpoint_callback.filepath
+        folderpath = self.weights_save_path
         if os.path.exists(folderpath):
             files = os.listdir(folderpath)
             hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]
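
Note for reviewers (not part of the commit): the snippet below is a minimal standalone sketch of the priority order that the new __configure_weights_path helper applies when resolving where weights go. The DummyCheckpointCallback class and the example paths are hypothetical stand-ins for a real checkpoint callback exposing a filepath attribute.

import os


class DummyCheckpointCallback:
    """Hypothetical stand-in for a checkpoint callback with a `filepath` attribute."""

    def __init__(self, filepath):
        self.filepath = filepath


def resolve_weights_path(checkpoint_callback=None, weights_save_path=None):
    """Mirror the priority used by Trainer.__configure_weights_path."""
    # 1. a checkpoint callback's filepath wins if a callback was passed in
    # 2. otherwise use the user-provided weights_save_path
    # 3. otherwise fall back to the current working directory
    path = weights_save_path
    if checkpoint_callback is not None:
        path = checkpoint_callback.filepath
    if path is None:
        path = os.getcwd()
    return path


if __name__ == '__main__':
    ckpt = DummyCheckpointCallback('/tmp/ckpts')
    print(resolve_weights_path(ckpt, '/tmp/weights'))   # -> /tmp/ckpts
    print(resolve_weights_path(None, '/tmp/weights'))   # -> /tmp/weights
    print(resolve_weights_path())                       # -> os.getcwd()

With this resolution in place, the SLURM handlers in trainer_io.py (hpc_save and the hpc weight lookup) read and write self.weights_save_path instead of requiring a checkpoint callback to be configured.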