From c591013708cb7b049e8f517ae39e21331882b45a Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 22 Sep 2020 18:00:23 -0400
Subject: [PATCH] enable any logged metric to be accessible in callbacks (#3598)

* enable any logged or written metric to be accessible in callbacks

* clarify forward
---
 .../callbacks/model_checkpoint.py             | 25 ++++++++++++-------
 .../trainer/connectors/logger_connector.py    | 11 +++++---
 pytorch_lightning/trainer/logging.py          |  2 ++
 tests/trainer/test_eval_loop_dict_return.py   |  8 +++---
 .../test_trainer_steps_result_return.py       |  9 +++++++
 .../test_validation_steps_result_return.py    |  4 +++
 6 files changed, 43 insertions(+), 16 deletions(-)
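In user terms: metrics that were only logged or shown in the progress bar previously never reached `callback_metrics`, so callbacks could not react to them; after this patch they are merged in. A minimal sketch of the consumer side, assuming the 0.9-era `Callback` hooks (the callback class and the `train_loss` key are illustrative names, not part of this patch):

from pytorch_lightning.callbacks import Callback

class PrintMetricCallback(Callback):
    def on_epoch_end(self, trainer, pl_module):
        # After this patch, logged and progress-bar metrics are merged into
        # callback_metrics, so a purely logged 'train_loss' is visible here.
        metrics = trainer.logger_connector.callback_metrics
        if 'train_loss' in metrics:
            print(f"train_loss as seen by callbacks: {metrics['train_loss']}")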
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 43ef28c7d8..29eb2b1fb5 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -22,6 +22,7 @@ Automatically save model checkpoints during training.
 
 import os
 import re
+from copy import deepcopy
 from typing import Any, Dict, Optional
 
 import numpy as np
@@ -219,7 +220,8 @@ class ModelCheckpoint(Callback):
 
         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
 
-        return monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        val = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        return val
 
     @classmethod
     def _format_checkpoint_name(
@@ -351,7 +353,11 @@ class ModelCheckpoint(Callback):
 
         self.epoch_last_check = epoch
 
-        ckpt_name_metrics = trainer.logger_connector.logged_metrics
+        # anything logged or in callbacks can be in the name
+        ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
+        ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
+        ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
+
         filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
         while self._fs.exists(filepath):
@@ -366,18 +372,19 @@ class ModelCheckpoint(Callback):
 
         if not isinstance(current, torch.Tensor):
             rank_zero_warn(
-                f"The metric you returned {current} must be a `torch.Tensor` instance, checkpoint not saved"
-                f" HINT: what is the value of {self.monitor} in validation_epoch_end()?",
+                f"The metric you returned {self.monitor}={current} must be a `torch.Tensor` "
+                f"instance, checkpoint not saved. HINT: what is the value of {self.monitor}?",
                 RuntimeWarning,
             )
             if current is not None:
-                current = torch.tensor(current)
+                current = torch.tensor(current).to(pl_module.device)
 
         if current is None:
-            rank_zero_warn(
-                f"Can save best model only with {self.monitor} available, skipping.",
-                RuntimeWarning,
-            )
+            m = f"Can save best model only with {self.monitor} available, skipping."
+            if self.monitor == 'checkpoint_on':
+                m = f'No checkpoint_on found. Hint: Did you set it in EvalResult(checkpoint_on=tensor) or ' \
+                    f'TrainResult(checkpoint_on=tensor)?'
+            rank_zero_warn(m, RuntimeWarning)
         elif self.check_monitor_top_k(current):
             self._do_check_save(filepath, current, epoch, trainer, pl_module)
         elif self.verbose:
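Two notes on the hunks above. The `deepcopy` keeps the two `update()` calls from writing callback and progress-bar entries back into the live `logged_metrics` dict; `format_checkpoint_name` only reads the merged result. The practical effect is that any logged metric can now appear in the checkpoint filename. A sketch, assuming the 0.9-era `filepath` argument; `val_acc` is an illustrative metric name, not something this patch defines:

from pytorch_lightning.callbacks import ModelCheckpoint

# '{epoch}' and '{val_acc}' are resolved against the merged ckpt_name_metrics
# dict built above, so a metric that was only logged still formats correctly.
checkpoint_callback = ModelCheckpoint(filepath='checkpoints/{epoch}-{val_acc:.2f}')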
diff --git a/pytorch_lightning/trainer/connectors/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector.py
index 386994bb90..a47a5e99e8 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector.py
@@ -172,8 +172,10 @@ class LoggerConnector:
             # log metrics
             self.trainer.logger_connector.log_metrics(log_metrics, {})
 
-            # track metrics for callbacks
+            # track metrics for callbacks (all prog bar, logged and callback metrics)
             self.trainer.logger_connector.callback_metrics.update(callback_metrics)
+            self.trainer.logger_connector.callback_metrics.update(log_metrics)
+            self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)
 
             if len(dataloader_result_metrics) > 0:
                 eval_loop_results.append(dataloader_result_metrics)
@@ -263,16 +265,18 @@ class LoggerConnector:
         # --------------------------
         # track results
         # --------------------------
-        # add the metrics to the loggers
+        # add the metrics to the loggers and callbacks
        if epoch_log_metrics and len(epoch_log_metrics) > 0:
             self.log_metrics(epoch_log_metrics, {})
+            self.callback_metrics.update(epoch_log_metrics)
 
         # add metrics to callbacks
         self.callback_metrics.update(epoch_callback_metrics)
 
-        # add metrics to progress_bar
+        # add metrics to progress_bar and callbacks
         if len(epoch_progress_bar_metrics) > 0:
             self.add_progress_bar_metrics(epoch_progress_bar_metrics)
+            self.callback_metrics.update(epoch_progress_bar_metrics)
 
     def __auto_reduce_results_on_epoch_end(self, epoch_output):
         epoch_log_metrics = {}
@@ -326,3 +330,4 @@ class LoggerConnector:
         grad_norm_dic = batch_output.grad_norm_dic
         if len(metrics) > 0 or len(grad_norm_dic) > 0:
             self.log_metrics(metrics, grad_norm_dic)
+            self.callback_metrics.update(metrics)
diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index 3aaa84468e..a0cca23af9 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -146,6 +146,8 @@ class TrainerLoggingMixin(ABC):
         # detach all metrics for callbacks to prevent memory leaks
         # no .item() because it will slow things down
         callback_metrics = recursive_detach(callback_metrics)
+        progress_bar_metrics = recursive_detach(progress_bar_metrics)
+        log_metrics = recursive_detach(log_metrics)
 
         # replace loss with checkpoint_on
         if 'loss' in callback_metrics:
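Detaching `progress_bar_metrics` and `log_metrics` matters for the same reason as the existing `callback_metrics` detach: both dicts are now long-lived (they feed `callback_metrics` above), and an attached tensor would keep its whole autograd graph in memory. A simplified, illustrative stand-in for what `recursive_detach` does here (the real utility ships with Lightning; this rewrite only shows the idea):

import torch

def detach_metrics(metrics: dict) -> dict:
    # Detach tensors so cached metrics stop referencing the autograd graph.
    # No .item() calls: .item() forces a device sync and slows training down.
    out = {}
    for name, value in metrics.items():
        if isinstance(value, torch.Tensor):
            out[name] = value.detach()
        elif isinstance(value, dict):
            out[name] = detach_metrics(value)
        else:
            out[name] = value
    return out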
diff --git a/tests/trainer/test_eval_loop_dict_return.py b/tests/trainer/test_eval_loop_dict_return.py
index 98aee4cb00..4b265ddb4a 100644
--- a/tests/trainer/test_eval_loop_dict_return.py
+++ b/tests/trainer/test_eval_loop_dict_return.py
@@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
             assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 8
+    assert len(trainer.logger_connector.callback_metrics) in [8, 9]
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
             assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 9
+    assert len(trainer.logger_connector.callback_metrics) in [9, 10]
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 9
+    assert len(trainer.logger_connector.callback_metrics) in [9, 10]
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 10
+    assert len(trainer.logger_connector.callback_metrics) in [10, 11]
 
     # make sure correct steps were called
     assert model.validation_step_called
diff --git a/tests/trainer/test_trainer_steps_result_return.py b/tests/trainer/test_trainer_steps_result_return.py
index dfd3d1474f..3835c21f75 100644
--- a/tests/trainer/test_trainer_steps_result_return.py
+++ b/tests/trainer/test_trainer_steps_result_return.py
@@ -58,6 +58,7 @@ def test_training_step_result_log_step_only(tmpdir):
     assert len(logged_metrics) == 4
 
     # make sure we are using the correct metrics for callbacks
+    assert len(trainer.logger_connector.callback_metrics) == 8
     assert trainer.logger_connector.callback_metrics['checkpoint_on'] == 171
 
     # make sure pbar metrics are correct ang log metrics did not leak
@@ -123,6 +124,8 @@ def test_training_step_result_log_epoch_only(tmpdir):
     assert not model.training_step_end_called
     assert not model.training_epoch_end_called
 
+    assert len(trainer.logger_connector.callback_metrics) == 12
+
     # make sure correct metrics are logged (one per batch step as requested)
     assert len(trainer.dev_debugger.logged_metrics) == epochs
     epoch_metrics = trainer.dev_debugger.logged_metrics
@@ -198,6 +201,8 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
     assert not model.training_step_end_called
     assert not model.training_epoch_end_called
 
+    assert len(trainer.logger_connector.callback_metrics) == 8
+
     # make sure correct metrics are logged (one per batch step as requested)
     assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs
     epoch_metrics = trainer.dev_debugger.logged_metrics
@@ -323,6 +328,8 @@ def test_training_step_epoch_end_result(tmpdir):
     )
     trainer.fit(model)
 
+    assert len(trainer.logger_connector.callback_metrics) == 11
+
     # make sure correct steps were called
     assert model.training_step_called
     assert not model.training_step_end_called
@@ -403,6 +410,8 @@ def test_no_auto_callbacks_with_train_loop_only(tmpdir):
     )
     trainer.fit(model)
 
+    assert len(trainer.logger_connector.callback_metrics) == 2
+
     all_losses = trainer.dev_debugger.saved_train_losses
     assert len(all_losses) == batches * epochs
 
diff --git a/tests/trainer/test_validation_steps_result_return.py b/tests/trainer/test_validation_steps_result_return.py
index 42b0c1aee3..5db54c9b70 100644
--- a/tests/trainer/test_validation_steps_result_return.py
+++ b/tests/trainer/test_validation_steps_result_return.py
@@ -278,6 +278,8 @@ def test_val_step_epoch_step_metrics(tmpdir):
     )
     trainer.fit(model)
 
+    assert len(trainer.logger_connector.callback_metrics) == 7
+
     # make sure correct steps were called
     assert model.validation_step_called
     assert not model.validation_step_end_called
@@ -352,6 +354,8 @@ def test_val_step_epoch_end_result(tmpdir):
     )
     trainer.fit(model)
 
+    assert len(trainer.logger_connector.callback_metrics) == 6
+
     # make sure correct steps were called
     assert model.validation_step_called
     assert not model.validation_step_end_called
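The relaxed assertions (`== 8` becoming `in [8, 9]`, and so on) allow for the extra keys that can show up once logged and progress-bar metrics are merged into `callback_metrics`. The pattern the new counts pin down, sketched with the 0.9-era Result API (`LitModel` and its loss helper are illustrative, not from this patch):

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical loss helper
        result = pl.TrainResult(minimize=loss)
        # Logged and progress-bar metrics now also land in callback_metrics.
        result.log('train_loss', loss, prog_bar=True)
        return result

After `trainer.fit(model)`, `'train_loss'` should appear in `trainer.logger_connector.callback_metrics`, which is what the counts above are asserting.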