From e05586c4b2a54eca5d939c2328a4d1cdedb8a522 Mon Sep 17 00:00:00 2001
From: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
Date: Sat, 22 Feb 2020 17:06:48 -0500
Subject: [PATCH] extract training teardown into method, catch KeyboardInterrupt (#856)

Co-authored-by: Jirka Borovec
---
 pytorch_lightning/trainer/trainer.py       |   3 -
 pytorch_lightning/trainer/training_loop.py | 164 +++++++++++----------
 2 files changed, 88 insertions(+), 79 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e2d76943a9..63e60ad330 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1060,9 +1060,6 @@ class Trainer(TrainerIOMixin,
         # CORE TRAINING LOOP
         self.train()
 
-        # summarize profile results
-        self.profiler.describe()
-
     def test(self, model=None):
         r"""
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 87626d8cdb..6d12d6fe6f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -155,6 +155,7 @@ When this flag is enabled each batch is split into sequences of size truncated_b
 import copy
 import warnings
 from abc import ABC, abstractmethod
+import logging as log
 
 import numpy as np
 
@@ -307,98 +308,95 @@ class TrainerTrainLoopMixin(ABC):
     def train(self):
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                       ' but will start from "0" in v0.8.0.', DeprecationWarning)
+        # get model
         model = self.get_model()
-        # run all epochs
-        for epoch in range(self.current_epoch, self.max_epochs):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            if (self.use_ddp or self.use_tpu) \
-                    and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
-                self.get_train_dataloader().sampler.set_epoch(epoch)
+        try:
+            # run all epochs
+            for epoch in range(self.current_epoch, self.max_epochs):
+                # set seed for distributed sampler (enables shuffling for each epoch)
+                if (self.use_ddp or self.use_tpu) \
+                        and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
+                    self.get_train_dataloader().sampler.set_epoch(epoch)
 
-            # get model
-            model = self.get_model()
+                # get model
+                model = self.get_model()
 
-            # update training progress in trainer and model
-            model.current_epoch = epoch
-            self.current_epoch = epoch
+                # update training progress in trainer and model
+                model.current_epoch = epoch
+                self.current_epoch = epoch
 
-            total_val_batches = 0
-            is_val_epoch = False
-            if not self.disable_validation:
-                # val can be checked multiple times in epoch
-                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
-                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
-                total_val_batches = self.num_val_batches * val_checks_per_epoch
+                total_val_batches = 0
+                is_val_epoch = False
+                if not self.disable_validation:
+                    # val can be checked multiple times in epoch
+                    is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+                    val_checks_per_epoch = self.num_training_batches // self.val_check_batch
+                    val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
+                    total_val_batches = self.num_val_batches * val_checks_per_epoch
 
-            # total batches includes multiple val checks
-            self.total_batches = self.num_training_batches + total_val_batches
-            self.batch_loss_value = 0  # accumulated grads
+                # total batches includes multiple val checks
+                self.total_batches = self.num_training_batches + total_val_batches
+                self.batch_loss_value = 0  # accumulated grads
 
-            if self.fast_dev_run:
-                # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
-                num_iterations = 2
-            elif self.is_iterable_train_dataloader:
-                # for iterable train loader, the progress bar never ends
-                num_iterations = None
-            else:
-                num_iterations = self.total_batches
+                if self.fast_dev_run:
+                    # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
+                    num_iterations = 2
+                elif self.is_iterable_train_dataloader:
+                    # for iterable train loader, the progress bar never ends
+                    num_iterations = None
+                else:
+                    num_iterations = self.total_batches
 
-            # reset progress bar
-            # .reset() doesn't work on disabled progress bar so we should check
-            if not self.main_progress_bar.disable:
-                self.main_progress_bar.reset(num_iterations)
-            desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
-            self.main_progress_bar.set_description(desc)
+                # reset progress bar
+                # .reset() doesn't work on disabled progress bar so we should check
+                if not self.main_progress_bar.disable:
+                    self.main_progress_bar.reset(num_iterations)
+                desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
+                self.main_progress_bar.set_description(desc)
 
-            # changing gradient according accumulation_scheduler
-            self.accumulation_scheduler.on_epoch_begin()
+                # changing gradient according accumulation_scheduler
+                self.accumulation_scheduler.on_epoch_begin()
 
-            # -----------------
-            # RUN TNG EPOCH
-            # -----------------
-            self.run_training_epoch()
+                # -----------------
+                # RUN TNG EPOCH
+                # -----------------
+                self.run_training_epoch()
 
-            # update LR schedulers
-            if self.lr_schedulers is not None:
-                for lr_scheduler in self.lr_schedulers:
-                    lr_scheduler.step()
-            if self.reduce_lr_on_plateau_scheduler is not None:
-                val_loss = self.callback_metrics.get('val_loss')
-                if val_loss is None:
-                    avail_metrics = ','.join(list(self.callback_metrics.keys()))
-                    m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
-                        f'which is not available. Available metrics are: {avail_metrics}'
-                    raise MisconfigurationException(m)
-                self.reduce_lr_on_plateau_scheduler.step(val_loss)
+                # update LR schedulers
+                if self.lr_schedulers is not None:
+                    for lr_scheduler in self.lr_schedulers:
+                        lr_scheduler.step()
+                if self.reduce_lr_on_plateau_scheduler is not None:
+                    val_loss = self.callback_metrics.get('val_loss')
+                    if val_loss is None:
+                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
+                        m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
+                            f'which is not available. Available metrics are: {avail_metrics}'
+                        raise MisconfigurationException(m)
+                    self.reduce_lr_on_plateau_scheduler.step(val_loss)
 
-            if self.max_steps and self.max_steps == self.global_step:
-                self.main_progress_bar.close()
-                model.on_train_end()
-                return
-
-            # early stopping
-            met_min_epochs = epoch >= self.min_epochs - 1
-            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-
-            if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
-                    ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
-                should_stop = self.early_stop_callback.on_epoch_end()
-                # stop training
-                stop = should_stop and met_min_epochs
-                if stop:
+                if self.max_steps and self.max_steps == self.global_step:
                     self.main_progress_bar.close()
-                    with self.profiler.profile('on_train_end'):
-                        model.on_train_end()
+                    model.on_train_end()
                     return
 
-        self.main_progress_bar.close()
+                # early stopping
+                met_min_epochs = epoch >= self.min_epochs - 1
+                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
 
-        with self.profiler.profile('on_train_end'):
-            model.on_train_end()
+                if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
+                        ((met_min_epochs and met_min_steps) or self.fast_dev_run)):
+                    should_stop = self.early_stop_callback.on_epoch_end()
+                    # stop training
+                    stop = should_stop and met_min_epochs
+                    if stop:
+                        self.run_training_teardown()
+                        return
 
-        if self.logger is not None:
-            self.logger.finalize("success")
+            self.run_training_teardown()
+        except KeyboardInterrupt:
+            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
+            self.run_training_teardown()
 
     def run_training_epoch(self):
         # before epoch hook
@@ -622,6 +620,20 @@ class TrainerTrainLoopMixin(ABC):
 
         return 0, grad_norm_dic, all_log_metrics
 
+    def run_training_teardown(self):
+        model = self.get_model()
+
+        self.main_progress_bar.close()
+
+        with self.profiler.profile('on_train_end'):
+            model.on_train_end()
+
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+        # summarize profile results
+        self.profiler.describe()
+
     def training_forward(self, batch, batch_idx, opt_idx, hiddens):
         """
         Handle forward for each training case (distributed, single gpu, etc...)
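
For readers who want the control-flow change in isolation, the following is a minimal sketch (a hypothetical SketchTrainer, not the actual pytorch_lightning Trainer) of the pattern this patch introduces: the epoch loop is wrapped in try/except KeyboardInterrupt, and every exit path (normal completion, early stop, or Ctrl-C) is routed through a single run_training_teardown() method.

import logging as log


class SketchTrainer:
    """Hypothetical stand-in for the Trainer; only the shutdown pattern is shown."""

    def __init__(self, max_epochs=3):
        self.max_epochs = max_epochs

    def run_training_epoch(self, epoch):
        # stand-in for the real per-epoch training work
        print(f'training epoch {epoch}')

    def run_training_teardown(self):
        # single place to close progress bars, call model.on_train_end(),
        # finalize loggers, and summarize profiler results
        print('teardown: cleaning up training state')

    def train(self):
        try:
            for epoch in range(self.max_epochs):
                self.run_training_epoch(epoch)
            self.run_training_teardown()
        except KeyboardInterrupt:
            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
            self.run_training_teardown()


if __name__ == '__main__':
    SketchTrainer().train()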