diff --git a/pytorch_lightning/trainer/evaluate_loop.py b/pytorch_lightning/trainer/evaluate_loop.py
index b57324a073..0bde682ef8 100644
--- a/pytorch_lightning/trainer/evaluate_loop.py
+++ b/pytorch_lightning/trainer/evaluate_loop.py
@@ -1,8 +1,6 @@
-import torch
 from pytorch_lightning.trainer.supporters import PredictionCollection
 from pytorch_lightning.core.step_result import Result, EvalResult
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import flatten_dict
 from pytorch_lightning.utilities.model_utils import is_overridden
 
 
@@ -145,28 +143,7 @@ class EvaluationLoop(object):
 
     def log_epoch_metrics(self, eval_results):
         using_eval_result = self.is_using_eval_results()
-        if using_eval_result:
-            if isinstance(eval_results, list):
-                for eval_result in eval_results:
-                    self.trainer.logger_connector.callback_metrics = eval_result.callback_metrics
-            else:
-                self.trainer.logger_connector.callback_metrics = eval_results.callback_metrics
-        else:
-            if isinstance(eval_results, list):
-                for eval_result in eval_results:
-                    # with a scalar return, auto set it to "val_loss" for callbacks
-                    if isinstance(eval_result, torch.Tensor):
-                        flat = {'val_loss': eval_result}
-                    else:
-                        flat = flatten_dict(eval_result)
-                    self.trainer.logger_connector.callback_metrics.update(flat)
-            else:
-                # with a scalar return, auto set it to "val_loss" for callbacks
-                if isinstance(eval_results, torch.Tensor):
-                    flat = {'val_loss': eval_results}
-                else:
-                    flat = flatten_dict(eval_results)
-                self.trainer.logger_connector.callback_metrics.update(flat)
+        self.trainer.logger_connector.on_evaluation_epoch_end(eval_results, using_eval_result)
 
     def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
         model = self.trainer.get_model()
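Note for reviewers: the block deleted above is not dropped, it moves verbatim into LoggerConnector.on_evaluation_epoch_end in the next file. As a reading aid, here is a minimal sketch of the branch it applies to plain (non-EvalResult) outputs; to_callback_metrics is a hypothetical name for illustration, and the sketch assumes flatten_dict collapses nested dicts into a single level, which is how the connector uses it:

    import torch
    from pytorch_lightning.utilities import flatten_dict

    def to_callback_metrics(eval_result):
        # a bare scalar return is auto-named "val_loss" for callbacks
        if isinstance(eval_result, torch.Tensor):
            return {'val_loss': eval_result}
        # dict returns are flattened before updating callback_metrics
        return flatten_dict(eval_result)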
diff --git a/pytorch_lightning/trainer/logger_connector.py b/pytorch_lightning/trainer/logger_connector.py
index 11042b460b..3148f4ac59 100644
--- a/pytorch_lightning/trainer/logger_connector.py
+++ b/pytorch_lightning/trainer/logger_connector.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 import torch
 from pytorch_lightning.core import memory
+from pytorch_lightning.utilities import flatten_dict
+from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.core.step_result import Result
 
 
 class LoggerConnector:
@@ -69,3 +72,166 @@ class LoggerConnector:
                 self.progress_bar_metrics[k] = v
 
         self.trainer.dev_debugger.track_pbar_metrics_history(metrics)
+
+    def on_evaluation_epoch_end(self, eval_results, using_eval_result):
+        if using_eval_result:
+            if isinstance(eval_results, list):
+                for eval_result in eval_results:
+                    self.trainer.logger_connector.callback_metrics = eval_result.callback_metrics
+            else:
+                self.trainer.logger_connector.callback_metrics = eval_results.callback_metrics
+        else:
+            if isinstance(eval_results, list):
+                for eval_result in eval_results:
+                    # with a scalar return, auto set it to "val_loss" for callbacks
+                    if isinstance(eval_result, torch.Tensor):
+                        flat = {'val_loss': eval_result}
+                    else:
+                        flat = flatten_dict(eval_result)
+                    self.trainer.logger_connector.callback_metrics.update(flat)
+            else:
+                # with a scalar return, auto set it to "val_loss" for callbacks
+                if isinstance(eval_results, torch.Tensor):
+                    flat = {'val_loss': eval_results}
+                else:
+                    flat = flatten_dict(eval_results)
+                self.trainer.logger_connector.callback_metrics.update(flat)
+
+    def on_train_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
+        self.log_train_epoch_end_metrics(epoch_output, checkpoint_accumulator,
+                                         early_stopping_accumulator, num_optimizers)
+
+    def log_train_epoch_end_metrics(self,
+                                    epoch_output,
+                                    checkpoint_accumulator,
+                                    early_stopping_accumulator,
+                                    num_optimizers):
+        # epoch output is a list. Each item in that list has all the outputs per optimizer
+        # epoch_output[optimizer_idx][training_step_idx][tbptt_index]
+        # remember that not using truncated backprop is equivalent with truncated backprop of len(1)
+
+        model = self.trainer.get_model()
+
+        epoch_log_metrics = {}
+        epoch_callback_metrics = {}
+        epoch_progress_bar_metrics = {}
+
+        # -----------------------
+        # Calculate epoch callback values if given
+        # -----------------------
+        if checkpoint_accumulator.num_values > 0:
+            epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()
+
+        if early_stopping_accumulator.num_values > 0:
+            epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()
+
+        # ------------------------
+        # determine if using a result obj
+        # ------------------------
+        # [optimizer_idx][training_step_idx][tbptt_index]
+        opt_idx_outputs = epoch_output[0]
+
+        try:
+            sample_obj = opt_idx_outputs[0][0] if isinstance(opt_idx_outputs[0], list) else opt_idx_outputs[0]
+            is_result_obj = len(epoch_output) > 0 and isinstance(sample_obj, Result)
+        except IndexError as e:
+            is_result_obj = False
+
+        # --------------------------
+        # EPOCH END STEP IF DEFINED
+        # --------------------------
+        if is_overridden('training_epoch_end', model=model):
+            self.trainer.global_step += 1
+
+            if is_result_obj:
+                # with result object gather across time and training steps so each opt idx has a single result obj
+                epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)
+
+            if num_optimizers == 1:
+                epoch_output = epoch_output[0]
+
+            # run training_epoch_end
+            # a list with a result per optimizer index
+            epoch_output = model.training_epoch_end(epoch_output)
+
+            if isinstance(epoch_output, Result):
+                epoch_log_metrics = epoch_output.epoch_log_metrics
+                epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
+            else:
+                _processed_outputs = self.trainer.process_output(epoch_output)
+                epoch_progress_bar_metrics = _processed_outputs[1]
+                epoch_log_metrics = _processed_outputs[2]
+                epoch_callback_metrics = _processed_outputs[3]
+
+        # --------------------------
+        # Structured Result (auto epoch end)
+        # --------------------------
+        elif is_result_obj:
+            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
+
+        # --------------------------
+        # track results
+        # --------------------------
+        # add the metrics to the loggers
+        if epoch_log_metrics and len(epoch_log_metrics) > 0:
+            self.log_metrics(epoch_log_metrics, {})
+
+        # add metrics to callbacks
+        self.callback_metrics.update(epoch_callback_metrics)
+
+        # add metrics to progress_bar
+        if len(epoch_progress_bar_metrics) > 0:
+            self.add_progress_bar_metrics(epoch_progress_bar_metrics)
+
+    def __auto_reduce_results_on_epoch_end(self, epoch_output):
+        epoch_log_metrics = {}
+        epoch_progress_bar_metrics = {}
+        for opt_outputs in epoch_output:
+            # reduce across time first
+            time_reduced_outputs = []
+            for train_step_idx in range(len(opt_outputs)):
+                tbptt_outs = opt_outputs[train_step_idx]
+                tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
+                time_reduced_outputs.append(tbptt_outs)
+
+            # reduce across training steps
+            opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)
+            opt_outputs.minimize = opt_outputs.minimize.mean()
+            epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
+            epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics)
+
+        return epoch_log_metrics, epoch_progress_bar_metrics
+
+    def __gather_result_across_time_and_optimizers(self, epoch_output):
+        """
+        Gather results into a single padded tensor per metric where each tensor is gathered across
+        time and across training steps.
+
+        Returns:
+            a list where each element is a Result with the tensors gathered
+        """
+        gathered_epoch_outputs = []
+        for opt_outputs in epoch_output:
+            # gather across time first
+            time_gathered_outputs = []
+            for train_step_idx in range(len(opt_outputs)):
+                tbptt_outs = opt_outputs[train_step_idx]
+                tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs)
+                time_gathered_outputs.append(tbptt_outs)
+
+            # gather across training steps
+            # each metric has dimensions (training_steps, seq_len) (seq_len=1 when no tbptt is used)
+            gathered_opt_output = time_gathered_outputs[0].__class__.padded_gather(time_gathered_outputs)
+            gathered_epoch_outputs.append(gathered_opt_output)
+
+        return gathered_epoch_outputs
+
+    def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
+        # when metrics should be logged
+        should_log_metrics = (batch_idx + 1) % self.trainer.row_log_interval == 0 or self.trainer.should_stop
+        if should_log_metrics or self.trainer.fast_dev_run:
+            # logs user requested information to logger
+            metrics = batch_output.batch_log_metrics
+            grad_norm_dic = batch_output.grad_norm_dic
+            if len(metrics) > 0 or len(grad_norm_dic) > 0:
+                self.log_metrics(metrics, grad_norm_dic)
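The methods added above all walk the nested structure epoch_output[optimizer_idx][training_step_idx][tbptt_index]. Here is a toy illustration (not PR code) of the two-stage reduction that __auto_reduce_results_on_epoch_end performs, with plain tensors and torch.stack(...).mean() standing in for Result.reduce_across_time and Result.reduce_on_epoch_end:

    import torch

    # 2 optimizers x 3 training steps x 2 tbptt splits, each holding a scalar loss
    epoch_output = [
        [[torch.rand(()) for _ in range(2)] for _ in range(3)]
        for _ in range(2)
    ]

    epoch_reduced = []
    for opt_outputs in epoch_output:
        # reduce across time (tbptt splits) first ...
        time_reduced = [torch.stack(tbptt_outs).mean() for tbptt_outs in opt_outputs]
        # ... then across training steps, giving one value per optimizer
        epoch_reduced.append(torch.stack(time_reduced).mean())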
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 2f26bf54af..cfef6a0712 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -109,137 +109,6 @@ class TrainerTrainLoopMixin(ABC):
 
         return epoch_end_outputs
 
-    def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
-        # epoch output is a list. Each item in that list has all the outputs per optimizer
-        # epoch_output[optimizer_idx][training_step_idx][tbptt_index]
-        # remember that not using truncated backprop is equivalent with truncated backprop of len(1)
-
-        model = self.get_model()
-
-        epoch_log_metrics = {}
-        epoch_callback_metrics = {}
-        epoch_progress_bar_metrics = {}
-
-        # -----------------------
-        # Calculate epoch callback values if given
-        # -----------------------
-        if checkpoint_accumulator.num_values > 0:
-            epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()
-
-        if early_stopping_accumulator.num_values > 0:
-            epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()
-
-        # ------------------------
-        # determine if using a result obj
-        # ------------------------
-        # [optimizer_idx][training_step_idx][tbptt_index]
-        opt_idx_outputs = epoch_output[0]
-
-        try:
-            sample_obj = opt_idx_outputs[0][0] if isinstance(opt_idx_outputs[0], list) else opt_idx_outputs[0]
-            is_result_obj = len(epoch_output) > 0 and isinstance(sample_obj, Result)
-        except IndexError as e:
-            is_result_obj = False
-
-        # --------------------------
-        # EPOCH END STEP IF DEFINED
-        # --------------------------
-        if is_overridden('training_epoch_end', model=model):
-            self.global_step += 1
-
-            if is_result_obj:
-                # with result object gather across time and training steps so each opt idx has a single result obj
-                epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)
-
-            if num_optimizers == 1:
-                epoch_output = epoch_output[0]
-
-            # run training_epoch_end
-            # a list with a result per optimizer index
-            epoch_output = model.training_epoch_end(epoch_output)
-
-            if isinstance(epoch_output, Result):
-                epoch_log_metrics = epoch_output.epoch_log_metrics
-                epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
-            else:
-                _processed_outputs = self.process_output(epoch_output)
-                epoch_progress_bar_metrics = _processed_outputs[1]
-                epoch_log_metrics = _processed_outputs[2]
-                epoch_callback_metrics = _processed_outputs[3]
-
-        # --------------------------
-        # Structured Result (auto epoch end)
-        # --------------------------
-        elif is_result_obj:
-            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
-
-        # --------------------------
-        # track results
-        # --------------------------
-        # add the metrics to the loggers
-        if epoch_log_metrics and len(epoch_log_metrics) > 0:
-            self.logger_connector.log_metrics(epoch_log_metrics, {})
-
-        # add metrics to callbacks
-        self.logger_connector.callback_metrics.update(epoch_callback_metrics)
-
-        # add metrics to progress_bar
-        if len(epoch_progress_bar_metrics) > 0:
-            self.logger_connector.add_progress_bar_metrics(epoch_progress_bar_metrics)
-
-    def __auto_reduce_results_on_epoch_end(self, epoch_output):
-        epoch_log_metrics = {}
-        epoch_progress_bar_metrics = {}
-        for opt_outputs in epoch_output:
-            # reduce across time first
-            time_reduced_outputs = []
-            for train_step_idx in range(len(opt_outputs)):
-                tbptt_outs = opt_outputs[train_step_idx]
-                tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
-                time_reduced_outputs.append(tbptt_outs)
-
-            # reduce across training steps
-            opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)
-            opt_outputs.minimize = opt_outputs.minimize.mean()
-            epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
-            epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics)
-
-        return epoch_log_metrics, epoch_progress_bar_metrics
-
-    def __gather_result_across_time_and_optimizers(self, epoch_output):
-        """
-        Gather results into a single padded tensor per metric where each tensor is gathered across
-        time and across training steps.
-
-        Returns:
-            a list where each element is a Result with the tensors gathered
-        """
-        gathered_epoch_outputs = []
-        for opt_outputs in epoch_output:
-            # gather across time first
-            time_gathered_outputs = []
-            for train_step_idx in range(len(opt_outputs)):
-                tbptt_outs = opt_outputs[train_step_idx]
-                tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs)
-                time_gathered_outputs.append(tbptt_outs)
-
-            # gather across training steps
-            # each metric has dimensions (training_steps, seq_len) (seq_len=1 when no tbptt is used)
-            gathered_opt_output = time_gathered_outputs[0].__class__.padded_gather(time_gathered_outputs)
-            gathered_epoch_outputs.append(gathered_opt_output)
-
-        return gathered_epoch_outputs
-
-    def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output):
-        # when metrics should be logged
-        should_log_metrics = (batch_idx + 1) % self.row_log_interval == 0 or self.should_stop
-        if should_log_metrics or self.fast_dev_run:
-            # logs user requested information to logger
-            metrics = batch_output.batch_log_metrics
-            grad_norm_dic = batch_output.grad_norm_dic
-            if len(metrics) > 0 or len(grad_norm_dic) > 0:
-                self.logger_connector.log_metrics(metrics, grad_norm_dic)
-
     def save_loggers_in_training_loop(self, batch_idx):
         # when loggers should save to disk
         should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or self.should_stop
diff --git a/pytorch_lightning/trainer/training_loop_temp.py b/pytorch_lightning/trainer/training_loop_temp.py
index 73b86d4909..0124c4fc17 100644
--- a/pytorch_lightning/trainer/training_loop_temp.py
+++ b/pytorch_lightning/trainer/training_loop_temp.py
@@ -370,7 +370,7 @@ class TrainLoop:
             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
             # -----------------------------------------
-            self.trainer.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
+            self.trainer.logger_connector.save_train_loop_metrics_to_loggers(batch_idx, batch_output)
 
             # update LR schedulers
             monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
@@ -391,7 +391,7 @@ class TrainLoop:
                 break
 
         # process epoch outputs
-        self.trainer.run_training_epoch_end(
+        self.trainer.logger_connector.on_train_epoch_end(
             epoch_output,
             self.checkpoint_accumulator,
             self.early_stopping_accumulator,
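Taken together, the refactor makes LoggerConnector the single owner of train and eval epoch-end logging. A condensed view of the surface the loops now call (signatures copied from the hunks above, bodies elided; this is a summary, not code from the PR):

    class LoggerConnector:
        def on_evaluation_epoch_end(self, eval_results, using_eval_result): ...
        def on_train_epoch_end(self, epoch_output, checkpoint_accumulator,
                               early_stopping_accumulator, num_optimizers): ...
        def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output): ...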