diff --git a/docs/source/logging.rst b/docs/source/logging.rst
index 906240ce6e..79452b0ca8 100644
--- a/docs/source/logging.rst
+++ b/docs/source/logging.rst
@@ -6,7 +6,7 @@
 .. role:: hidden
     :class: hidden-section
 
-
+
 .. _logging:
 
@@ -57,9 +57,11 @@ Logging from a LightningModule
 
 Lightning offers automatic log functionalities for logging scalars, or manual logging for anything else.
 
-Automatic logging
+Automatic Logging
 =================
-Use the :func:`~pytorch_lightning.core.lightning.LightningModule.log` method to log from anywhere in a :ref:`lightning_module`.
+Use the :func:`~pytorch_lightning.core.lightning.LightningModule.log`
+method to log from anywhere in a :ref:`lightning_module` and :ref:`callbacks`,
+except functions with `batch_start` in their names.
 
 .. code-block:: python
 
@@ -95,6 +97,9 @@ The :func:`~pytorch_lightning.core.lightning.LightningModule.log` method has a
     argument of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` or in the graphs plotted to the logger of your choice.
 
+If your work requires logging in an unsupported function, please open an issue with a clear description of why it is blocking you.
+
+
 Manual logging
 ==============
 If you want to log anything that is not a scalar, like histograms, text, images, etc... you may need to use the logger object directly.
@@ -144,8 +149,8 @@ Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`
     def experiment(self):
         # Return the experiment object associated with this logger.
         pass
-
-    @property
+
+    @property
     def version(self):
         # Return the experiment version, int or str.
         return '0.1'
@@ -238,7 +243,7 @@ if you are using a logger.
 These defaults can be customized by overriding the
 :func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module.
 
 .. code-block:: python
-
+
     def get_progress_bar_dict(self):
         # don't show the version number
         items = super().get_progress_bar_dict()
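The documented rule above can be pictured with a short sketch (illustrative only, not part of this diff; the callback and metric names are invented). Logging from a `batch_end` hook is captured as before, while hooks with `batch_start` in their names are now skipped:

```python
from pytorch_lightning import Callback


class IllustrativeCallback(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # Supported: *batch_end* hooks still participate in logging.
        pl_module.log("train_batch_metric", 1.0)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        # Unsupported after this change: hook names containing "batch_start"
        # are skipped by the logging machinery, so a pl_module.log() call
        # made here would not be captured.
        pass
```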
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 9fc0be3c25..8dc993df7d 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -339,7 +339,9 @@ class EpochResultStore:
         self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info)
 
         # update logged_metrics, progress_bar_metrics, callback_metrics
-        self.update_logger_connector()
+
+        if "epoch_end" in fx_name:
+            self.update_logger_connector()
 
         self.reset_model()
 
@@ -355,18 +357,19 @@ class EpochResultStore:
         logger_connector = self.trainer.logger_connector
 
         callback_metrics = {}
-        is_train = self._stage == LoggerStages.TRAIN
+        batch_pbar_metrics = {}
+        batch_log_metrics = {}
+        is_train = self._stage in LoggerStages.TRAIN.value
 
         if not self._has_batch_loop_finished:
             # get pbar
             batch_pbar_metrics = self.get_latest_batch_pbar_metrics()
             logger_connector.add_progress_bar_metrics(batch_pbar_metrics)
+            batch_log_metrics = self.get_latest_batch_log_metrics()
 
             if is_train:
                 # Only log and add to callback epoch step during evaluation, test.
-                batch_log_metrics = self.get_latest_batch_log_metrics()
                 logger_connector.logged_metrics.update(batch_log_metrics)
-                callback_metrics.update(batch_pbar_metrics)
                 callback_metrics.update(batch_log_metrics)
         else:
@@ -393,6 +396,9 @@ class EpochResultStore:
         logger_connector.callback_metrics.update(callback_metrics)
         logger_connector.callback_metrics.pop("epoch", None)
 
+        batch_pbar_metrics.pop("debug_epoch", None)
+        return batch_pbar_metrics, batch_log_metrics
+
     def run_batch_from_func_name(self, func_name) -> Dict:
         results = [getattr(hook_result, func_name) for hook_result in self._internals.values()]
         results = [func(include_forked_originals=False) for func in results]
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 56284afa7a..851a48e014 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -587,11 +587,13 @@ class LoggerConnector:
         return gathered_epoch_outputs
 
     def log_train_step_metrics(self, batch_output):
+        _, batch_log_metrics = self.cached_results.update_logger_connector()
         # when metrics should be logged
         if self.should_update_logs or self.trainer.fast_dev_run:
             # logs user requested information to logger
-            metrics = self.cached_results.get_latest_batch_log_metrics()
             grad_norm_dic = batch_output.grad_norm_dic
-            if len(metrics) > 0 or len(grad_norm_dic) > 0:
-                self.log_metrics(metrics, grad_norm_dic, log_train_step_metrics=True)
-                self.callback_metrics.update(metrics)
+            if grad_norm_dic is None:
+                grad_norm_dic = {}
+            if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
+                self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True)
+                self.callback_metrics.update(batch_log_metrics)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 097727a6be..4b70917c8c 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -106,9 +106,9 @@ class EvaluationLoop(object):
 
     def on_evaluation_end(self, *args, **kwargs):
         if self.testing:
-            self.trainer.call_hook('on_test_end', *args, capture=True, **kwargs)
+            self.trainer.call_hook('on_test_end', *args, **kwargs)
         else:
-            self.trainer.call_hook('on_validation_end', *args, capture=True, **kwargs)
+            self.trainer.call_hook('on_validation_end', *args, **kwargs)
 
     def reload_evaluation_dataloaders(self):
         model = self.trainer.get_model()
@@ -329,9 +329,9 @@ class EvaluationLoop(object):
     def on_evaluation_epoch_end(self, *args, **kwargs):
         # call the callback hook
         if self.testing:
-            self.trainer.call_hook('on_test_epoch_end', *args, capture=True, **kwargs)
+            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
         else:
-            self.trainer.call_hook('on_validation_epoch_end', *args, capture=True, **kwargs)
+            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
 
     def log_evaluation_step_metrics(self, output, batch_idx):
         if self.trainer.running_sanity_check:
@@ -346,10 +346,8 @@ class EvaluationLoop(object):
         self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)
 
     def __log_result_step_metrics(self, step_log_metrics, step_pbar_metrics, batch_idx):
-        cached_batch_log_metrics = \
-            self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics()
-        cached_batch_pbar_metrics = \
-            self.trainer.logger_connector.cached_results.get_latest_batch_pbar_metrics()
+        cached_results = self.trainer.logger_connector.cached_results
+        cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
 
         step_log_metrics.update(cached_batch_log_metrics)
         step_pbar_metrics.update(cached_batch_pbar_metrics)
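The three changes above converge on one contract: `EpochResultStore.update_logger_connector()` both refreshes the connector's state and returns the freshly computed batch metrics, so callers such as `log_train_step_metrics` and `__log_result_step_metrics` no longer re-query the cache. A reduced, self-contained sketch of that shape (a toy class, not the real implementation):

```python
from typing import Dict, Tuple


class TinyResultStore:
    """Toy stand-in for EpochResultStore: cache hook results, then hand the
    latest batch metrics back from a single update call."""

    def __init__(self) -> None:
        self.cached: Dict[str, float] = {}

    def cache(self, name: str, value: float) -> None:
        self.cached[name] = value

    def update_logger_connector(self) -> Tuple[Dict[str, float], Dict[str, float]]:
        # The real store keeps separate progress-bar and logger views of the
        # cache; for brevity both views are copies of the same dict here.
        return dict(self.cached), dict(self.cached)


store = TinyResultStore()
store.cache("loss", 0.25)
batch_pbar_metrics, batch_log_metrics = store.update_logger_connector()
assert batch_log_metrics["loss"] == 0.25
```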
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 81d8577ed1..92e3b6af2e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -855,6 +855,8 @@ class Trainer(
         model.setup(stage_name)
 
     def _reset_result_and_set_hook_fx_name(self, hook_name):
+        if "batch_start" in hook_name:
+            return True
         model_ref = self.get_model()
         if model_ref is not None:
             # used to track current hook name called
@@ -868,10 +870,9 @@ class Trainer(
         # capture logging for this hook
         self.logger_connector.cache_logged_metrics()
 
-    def call_hook(self, hook_name, *args, capture=False, **kwargs):
+    def call_hook(self, hook_name, *args, **kwargs):
         # set hook_name to model + reset Result obj
-        if capture:
-            self._reset_result_and_set_hook_fx_name(hook_name)
+        skip = self._reset_result_and_set_hook_fx_name(hook_name)
 
         # always profile hooks
         with self.profiler.profile(hook_name):
@@ -894,7 +895,7 @@ class Trainer(
                 accelerator_hook = getattr(self.accelerator_backend, hook_name)
                 output = accelerator_hook(*args, **kwargs)
 
-        if capture:
+        if not skip:
             self._cache_logged_metrics()
         return output
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 9a4f324033..679f59c05e 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -825,8 +825,8 @@ class TrainLoop:
         # inform logger the batch loop has finished
         self.trainer.logger_connector.on_train_epoch_end()
 
-        self.trainer.call_hook('on_epoch_end', capture=True)
-        self.trainer.call_hook('on_train_epoch_end', epoch_output, capture=True)
+        self.trainer.call_hook('on_epoch_end')
+        self.trainer.call_hook('on_train_epoch_end', epoch_output)
 
     def increment_accumulated_grad_global_step(self):
         num_accumulated_batches_reached = self._accumulated_batches_reached()
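The `call_hook` rewrite above inverts the old opt-in `capture=True` flag: capturing is now the default, and `_reset_result_and_set_hook_fx_name` opts `batch_start` hooks out by returning a truthy skip flag. A condensed sketch of that control flow (free functions instead of Trainer methods, hook dispatch elided):

```python
def reset_result_and_set_hook_fx_name(hook_name):
    if "batch_start" in hook_name:
        return True  # opt out: logging is not captured for this hook
    # ... track hook_name on the model and reset its Result object (elided)
    return None  # falsy: capture logged metrics once the hook has run


def call_hook(hook_name, run_hook):
    skip = reset_result_and_set_hook_fx_name(hook_name)
    output = run_hook()  # stand-in for module/callback/accelerator dispatch
    if not skip:
        print(f"caching logged metrics for {hook_name}")  # placeholder
    return output


call_hook("on_train_batch_start", lambda: None)  # skipped, nothing cached
call_hook("on_train_batch_end", lambda: None)    # captured as usual
```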
"""Test deprecated functionality which will be removed in vX.Y.Z""" import sys from argparse import ArgumentParser diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index 022796c275..1a92891322 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -472,8 +472,6 @@ def test_log_works_in_val_callback(tmpdir): "forked": False, "func_name": func_name} - """ - def on_validation_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_validation_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) @@ -486,6 +484,7 @@ def test_log_works_in_val_callback(tmpdir): self.make_logging(pl_module, 'on_validation_epoch_start', 3, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) + """ def on_batch_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) @@ -493,6 +492,7 @@ def test_log_works_in_val_callback(tmpdir): def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): self.make_logging(pl_module, 'on_validation_batch_start', 5, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) + """ def on_batch_end(self, trainer, pl_module): self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, @@ -510,8 +510,6 @@ def test_log_works_in_val_callback(tmpdir): self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) - """ - def on_validation_epoch_end(self, trainer, pl_module): self.make_logging(pl_module, 'on_validation_epoch_end', 9, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) @@ -541,16 +539,14 @@ def test_log_works_in_val_callback(tmpdir): trainer.fit(model) trainer.test() - """ assert test_callback.funcs_called_count["on_epoch_start"] == 1 - assert test_callback.funcs_called_count["on_batch_start"] == 1 + # assert test_callback.funcs_called_count["on_batch_start"] == 1 assert test_callback.funcs_called_count["on_batch_end"] == 1 assert test_callback.funcs_called_count["on_validation_start"] == 1 assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1 - assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 + # assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 assert test_callback.funcs_called_count["on_validation_batch_end"] == 4 assert test_callback.funcs_called_count["on_epoch_end"] == 1 - """ assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1 # Make sure the func_name exists within callback_metrics. 
@@ -662,7 +658,6 @@ def test_log_works_in_test_callback(tmpdir):
                 "forked": False,
                 "func_name": func_name}
 
-        """
         def on_test_start(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_test_start', 1, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -675,11 +670,8 @@ def test_log_works_in_test_callback(tmpdir):
             self.make_logging(pl_module, 'on_test_epoch_start', 3, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
 
-        def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
-            self.make_logging(pl_module, 'on_test_batch_start', 4, on_steps=self.choices,
-                              on_epochs=self.choices, prob_bars=self.choices)
-
-        def on_test_step_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-            self.make_logging(pl_module, 'on_test_step_end', 5, on_steps=self.choices,
+        def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+            self.make_logging(pl_module, 'on_test_batch_end', 5, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
 
         # used to make sure aggregation works fine.
@@ -690,7 +682,6 @@ def test_log_works_in_test_callback(tmpdir):
         def on_epoch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_epoch_end', 6, on_steps=[False],
                               on_epochs=self.choices, prob_bars=self.choices)
-        """
 
         def on_test_epoch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_test_epoch_end', 7, on_steps=[False],
@@ -728,13 +719,11 @@
     )
     trainer.fit(model)
     trainer.test()
-    """
+
     assert test_callback.funcs_called_count["on_test_start"] == 1
     assert test_callback.funcs_called_count["on_epoch_start"] == 2
     assert test_callback.funcs_called_count["on_test_epoch_start"] == 1
-    assert test_callback.funcs_called_count["on_test_batch_start"] == 4
-    assert test_callback.funcs_called_count["on_test_step_end"] == 4
-    """
+    assert test_callback.funcs_called_count["on_test_batch_end"] == 4
     assert test_callback.funcs_called_count["on_test_epoch_end"] == 1
 
     # Make sure the func_name exists within callback_metrics. If not, we missed some
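The `funcs_called_count` assertions above rely on bookkeeping done by the test callback's `make_logging` helper, which is defined earlier in the file and not shown in this diff. A minimal sketch of the assumed counting core (the real helper additionally calls `pl_module.log` for every `on_step`/`on_epoch`/`prog_bar` combination):

```python
from collections import defaultdict


class CountingCallbackSketch:
    def __init__(self):
        self.funcs_called_count = defaultdict(int)

    def make_logging(self, pl_module, func_name, idx,
                     on_steps=None, on_epochs=None, prob_bars=None):
        # Count the hook invocation; the assertions in the tests compare
        # these counters against the expected number of hook calls.
        self.funcs_called_count[func_name] += 1
```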
diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
index 9be44c68fa..c148748888 100644
--- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
+++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
@@ -558,7 +558,7 @@ def test_log_works_in_train_callback(tmpdir):
                 "prog_bar": prog_bar,
                 "forked": False,
                 "func_name": func_name}
-        """
+
         def on_train_start(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_train_start', 1, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -571,15 +571,6 @@ def test_log_works_in_train_callback(tmpdir):
             self.make_logging(pl_module, 'on_train_epoch_start', 3, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
 
-        def on_batch_start(self, trainer, pl_module):
-            self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices,
-                              on_epochs=self.choices, prob_bars=self.choices)
-
-        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
-            self.make_logging(pl_module, 'on_train_batch_start', 5, on_steps=self.choices,
-                              on_epochs=self.choices, prob_bars=self.choices)
-
-
         def on_batch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -592,7 +583,6 @@ def test_log_works_in_train_callback(tmpdir):
             # with func = np.mean if on_epoch else func = np.max
             self.count += 1
 
-        """
         def on_epoch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False],
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -629,17 +619,12 @@ def test_log_works_in_train_callback(tmpdir):
     )
     trainer.fit(model)
 
-    """
     assert test_callback.funcs_called_count["on_train_start"] == 1
     assert test_callback.funcs_called_count["on_epoch_start"] == 2
     assert test_callback.funcs_called_count["on_train_epoch_start"] == 2
-    assert test_callback.funcs_called_count["on_batch_start"] == 4
-    assert test_callback.funcs_called_count["on_train_batch_start"] == 4
     assert test_callback.funcs_called_count["on_batch_end"] == 4
     assert test_callback.funcs_called_count["on_epoch_end"] == 2
     assert test_callback.funcs_called_count["on_train_batch_end"] == 4
-
-    """
     assert test_callback.funcs_called_count["on_epoch_end"] == 2
     assert test_callback.funcs_called_count["on_train_epoch_end"] == 2
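Taken together, the tests encode the user-visible behaviour of this diff: hooks whose names contain `batch_start` no longer take part in logging, while their `batch_end` counterparts still do. A minimal end-to-end sketch under that assumption (any `LightningModule` works as `model`; the metric name is invented):

```python
import pytorch_lightning as pl


class BatchEndLogger(pl.Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # Still captured after this change.
        pl_module.log("batch_end_metric", 1.0)


# Hypothetical usage:
# trainer = pl.Trainer(max_epochs=1, callbacks=[BatchEndLogger()])
# trainer.fit(model)
# assert "batch_end_metric" in trainer.callback_metrics
```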