Enable `self.log` in most functions. (#4969)

* refactor

* fix pyright errors

* remove logging in batch_start functions

* update docs

* update doc

* resolve bug

* update

* correct script

* address review comments
chaton 2020-12-06 13:01:43 +00:00 committed by GitHub
parent 9b1afa8c87
commit 2e838e6dd8
9 changed files with 62 additions and 63 deletions

View File

@@ -6,7 +6,7 @@
.. role:: hidden
:class: hidden-section
.. _logging:
@@ -57,9 +57,11 @@ Logging from a LightningModule
Lightning offers automatic logging for scalars, and manual logging for anything else.
Automatic logging
Automatic Logging
=================
Use the :func:`~pytorch_lightning.core.lightning.LightningModule.log` method to log from anywhere in a :ref:`lightning_module`.
Use the :func:`~pytorch_lightning.core.lightning.LightningModule.log`
method to log from anywhere in a :ref:`lightning_module` and :ref:`callbacks`
except functions with `batch_start` in their names.
.. code-block:: python
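    # illustrative sketch only -- the original snippet was cut off by the diff
    # context; ``compute_loss`` is a hypothetical helper, not a Lightning API
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        # self.log works in LightningModule hooks and in callbacks,
        # except hooks with ``batch_start`` in their names
        self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss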
@@ -95,6 +97,9 @@ The :func:`~pytorch_lightning.core.lightning.LightningModule.log` method has a
argument of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` or in the graphs plotted to the logger of your choice.
If your work requires logging in an unsupported function, please open an issue with a clear description of why it is blocking you.
Manual logging
==============
If you want to log anything that is not a scalar, such as histograms, text, or images, you may need to use the logger object directly.
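For example, a minimal sketch assuming the default ``TensorBoardLogger``, whose ``experiment`` attribute is a ``torch.utils.tensorboard.SummaryWriter`` (``self.layer`` is a hypothetical module attribute):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        tensorboard = self.logger.experiment
        # SummaryWriter.add_histogram records a tensor's distribution at a step
        tensorboard.add_histogram('weights', self.layer.weight, self.global_step)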
@@ -144,8 +149,8 @@ Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`
def experiment(self):
# Return the experiment object associated with this logger.
pass
@property
def version(self):
# Return the experiment version, int or str.
return '0.1'
@@ -238,7 +243,7 @@ if you are using a logger. These defaults can be customized by overriding the
:func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module.
.. code-block:: python
def get_progress_bar_dict(self):
    # don't show the version number
    items = super().get_progress_bar_dict()
    items.pop("v_num", None)
    return items

View File

@@ -339,7 +339,9 @@ class EpochResultStore:
self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info)
# update logged_metrics, progress_bar_metrics, callback_metrics
self.update_logger_connector()
if "epoch_end" in fx_name:
self.update_logger_connector()
self.reset_model()
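After this change, hook results are cached on every call, but the logger connector is refreshed (and the model's result object reset) only when an ``*_epoch_end`` hook fires. A runnable paraphrase of that flush policy, using a hypothetical stand-in rather than the real ``EpochResultStore``:

.. code-block:: python

    class ResultCache:
        # hypothetical stand-in for the new flush policy, not the real class
        def __init__(self):
            self._internals = {}
            self.flush_count = 0

        def update_logger_connector(self):
            # stands in for pushing cached metrics to the logger connector
            self.flush_count += 1

        def cache_result(self, fx_name, hook_result):
            self._internals.setdefault(fx_name, []).append(hook_result)
            if "epoch_end" in fx_name:  # flush only at epoch boundaries
                self.update_logger_connector()

    cache = ResultCache()
    cache.cache_result("on_train_batch_end", {"loss": 0.1})   # cached only
    cache.cache_result("on_train_epoch_end", {"loss": 0.2})   # cached and flushed
    assert cache.flush_count == 1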
@@ -355,18 +357,19 @@ class EpochResultStore:
logger_connector = self.trainer.logger_connector
callback_metrics = {}
is_train = self._stage == LoggerStages.TRAIN
batch_pbar_metrics = {}
batch_log_metrics = {}
is_train = self._stage in LoggerStages.TRAIN.value
if not self._has_batch_loop_finished:
# get pbar
batch_pbar_metrics = self.get_latest_batch_pbar_metrics()
logger_connector.add_progress_bar_metrics(batch_pbar_metrics)
batch_log_metrics = self.get_latest_batch_log_metrics()
if is_train:
# batch-level log metrics are pushed here only during training;
# evaluation and test log at epoch end
batch_log_metrics = self.get_latest_batch_log_metrics()
logger_connector.logged_metrics.update(batch_log_metrics)
callback_metrics.update(batch_pbar_metrics)
callback_metrics.update(batch_log_metrics)
else:
@@ -393,6 +396,9 @@ class EpochResultStore:
logger_connector.callback_metrics.update(callback_metrics)
logger_connector.callback_metrics.pop("epoch", None)
batch_pbar_metrics.pop("debug_epoch", None)
return batch_pbar_metrics, batch_log_metrics
def run_batch_from_func_name(self, func_name) -> Dict:
results = [getattr(hook_result, func_name) for hook_result in self._internals.values()]
results = [func(include_forked_originals=False) for func in results]

View File

@@ -587,11 +587,13 @@ class LoggerConnector:
return gathered_epoch_outputs
def log_train_step_metrics(self, batch_output):
_, batch_log_metrics = self.cached_results.update_logger_connector()
# when metrics should be logged
if self.should_update_logs or self.trainer.fast_dev_run:
# logs user requested information to logger
metrics = self.cached_results.get_latest_batch_log_metrics()
grad_norm_dic = batch_output.grad_norm_dic
if len(metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(metrics, grad_norm_dic, log_train_step_metrics=True)
self.callback_metrics.update(metrics)
if grad_norm_dic is None:
grad_norm_dic = {}
if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True)
self.callback_metrics.update(batch_log_metrics)

View File

@@ -106,9 +106,9 @@ class EvaluationLoop(object):
def on_evaluation_end(self, *args, **kwargs):
if self.testing:
self.trainer.call_hook('on_test_end', *args, capture=True, **kwargs)
self.trainer.call_hook('on_test_end', *args, **kwargs)
else:
self.trainer.call_hook('on_validation_end', *args, capture=True, **kwargs)
self.trainer.call_hook('on_validation_end', *args, **kwargs)
def reload_evaluation_dataloaders(self):
model = self.trainer.get_model()
@@ -329,9 +329,9 @@ class EvaluationLoop(object):
def on_evaluation_epoch_end(self, *args, **kwargs):
# call the callback hook
if self.testing:
self.trainer.call_hook('on_test_epoch_end', *args, capture=True, **kwargs)
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
else:
self.trainer.call_hook('on_validation_epoch_end', *args, capture=True, **kwargs)
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
def log_evaluation_step_metrics(self, output, batch_idx):
if self.trainer.running_sanity_check:
@@ -346,10 +346,8 @@ class EvaluationLoop(object):
self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)
def __log_result_step_metrics(self, step_log_metrics, step_pbar_metrics, batch_idx):
cached_batch_log_metrics = \
self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics()
cached_batch_pbar_metrics = \
self.trainer.logger_connector.cached_results.get_latest_batch_pbar_metrics()
cached_results = self.trainer.logger_connector.cached_results
cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
step_log_metrics.update(cached_batch_log_metrics)
step_pbar_metrics.update(cached_batch_pbar_metrics)

View File

@@ -855,6 +855,8 @@ class Trainer(
model.setup(stage_name)
def _reset_result_and_set_hook_fx_name(self, hook_name):
if "batch_start" in hook_name:
return True
model_ref = self.get_model()
if model_ref is not None:
# used to track current hook name called
@@ -868,10 +870,9 @@ class Trainer(
# capture logging for this hook
self.logger_connector.cache_logged_metrics()
def call_hook(self, hook_name, *args, capture=False, **kwargs):
def call_hook(self, hook_name, *args, **kwargs):
# set hook_name to model + reset Result obj
if capture:
self._reset_result_and_set_hook_fx_name(hook_name)
skip = self._reset_result_and_set_hook_fx_name(hook_name)
# always profile hooks
with self.profiler.profile(hook_name):
@@ -894,7 +895,7 @@ class Trainer(
accelerator_hook = getattr(self.accelerator_backend, hook_name)
output = accelerator_hook(*args, **kwargs)
if capture:
if not skip:
self._cache_logged_metrics()
return output
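Net effect of the ``Trainer`` changes: the opt-in ``capture`` flag is gone, ``call_hook`` always attempts to capture logged metrics, and only hooks with ``batch_start`` in their names are skipped. A runnable paraphrase with a hypothetical stand-in class:

.. code-block:: python

    class MiniTrainer:
        # hypothetical stand-in for the new call_hook control flow
        def __init__(self):
            self.captured = []

        def _reset_result_and_set_hook_fx_name(self, hook_name):
            # True means: skip logging capture for this hook
            return "batch_start" in hook_name

        def call_hook(self, hook_name, hook, *args, **kwargs):
            skip = self._reset_result_and_set_hook_fx_name(hook_name)
            output = hook(*args, **kwargs)
            if not skip:
                # stands in for self._cache_logged_metrics()
                self.captured.append(hook_name)
            return output

    t = MiniTrainer()
    t.call_hook("on_train_batch_start", lambda: None)  # capture skipped
    t.call_hook("on_train_batch_end", lambda: None)    # capture runs
    assert t.captured == ["on_train_batch_end"]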

View File

@@ -825,8 +825,8 @@ class TrainLoop:
# inform logger the batch loop has finished
self.trainer.logger_connector.on_train_epoch_end()
self.trainer.call_hook('on_epoch_end', capture=True)
self.trainer.call_hook('on_train_epoch_end', epoch_output, capture=True)
self.trainer.call_hook('on_epoch_end')
self.trainer.call_hook('on_train_epoch_end', epoch_output)
def increment_accumulated_grad_global_step(self):
num_accumulated_batches_reached = self._accumulated_batches_reached()

View File

@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z"""
import sys
from argparse import ArgumentParser

View File

@@ -472,8 +472,6 @@ def test_log_works_in_val_callback(tmpdir):
"forked": False,
"func_name": func_name}
"""
def on_validation_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_validation_start', 1, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
@@ -486,6 +484,7 @@ def test_log_works_in_val_callback(tmpdir):
self.make_logging(pl_module, 'on_validation_epoch_start', 3, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
"""
def on_batch_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
@@ -493,6 +492,7 @@ def test_log_works_in_val_callback(tmpdir):
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_validation_batch_start', 5, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
"""
def on_batch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices,
@@ -510,8 +510,6 @@ def test_log_works_in_val_callback(tmpdir):
self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)
"""
def on_validation_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_validation_epoch_end', 9, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)
@@ -541,16 +539,14 @@ def test_log_works_in_val_callback(tmpdir):
trainer.fit(model)
trainer.test()
"""
assert test_callback.funcs_called_count["on_epoch_start"] == 1
assert test_callback.funcs_called_count["on_batch_start"] == 1
# assert test_callback.funcs_called_count["on_batch_start"] == 1
assert test_callback.funcs_called_count["on_batch_end"] == 1
assert test_callback.funcs_called_count["on_validation_start"] == 1
assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1
assert test_callback.funcs_called_count["on_validation_batch_start"] == 4
# assert test_callback.funcs_called_count["on_validation_batch_start"] == 4
assert test_callback.funcs_called_count["on_validation_batch_end"] == 4
assert test_callback.funcs_called_count["on_epoch_end"] == 1
"""
assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1
# Make sure the func_name exists within callback_metrics. If not, we missed some
@@ -662,7 +658,6 @@ def test_log_works_in_test_callback(tmpdir):
"forked": False,
"func_name": func_name}
"""
def on_test_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_test_start', 1, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
@@ -675,11 +670,8 @@ def test_log_works_in_test_callback(tmpdir):
self.make_logging(pl_module, 'on_test_epoch_start', 3, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_test_batch_start', 4, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
def on_test_step_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_test_step_end', 5, on_steps=self.choices,
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_test_batch_end', 5, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
# used to make sure aggregation works fine.
@@ -690,7 +682,6 @@ def test_log_works_in_test_callback(tmpdir):
def on_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_epoch_end', 6, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)
"""
def on_test_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_test_epoch_end', 7, on_steps=[False],
@@ -728,13 +719,11 @@ def test_log_works_in_test_callback(tmpdir):
)
trainer.fit(model)
trainer.test()
"""
assert test_callback.funcs_called_count["on_test_start"] == 1
assert test_callback.funcs_called_count["on_epoch_start"] == 2
assert test_callback.funcs_called_count["on_test_epoch_start"] == 1
assert test_callback.funcs_called_count["on_test_batch_start"] == 4
assert test_callback.funcs_called_count["on_test_step_end"] == 4
"""
assert test_callback.funcs_called_count["on_test_batch_end"] == 4
assert test_callback.funcs_called_count["on_test_epoch_end"] == 1
# Make sure the func_name exists within callback_metrics. If not, we missed some

View File

@@ -558,7 +558,7 @@ def test_log_works_in_train_callback(tmpdir):
"prog_bar": prog_bar,
"forked": False,
"func_name": func_name}
"""
def on_train_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_train_start', 1, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
@@ -571,15 +571,6 @@ def test_log_works_in_train_callback(tmpdir):
self.make_logging(pl_module, 'on_train_epoch_start', 3, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
def on_batch_start(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
self.make_logging(pl_module, 'on_train_batch_start', 5, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
def on_batch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices,
on_epochs=self.choices, prob_bars=self.choices)
@@ -592,7 +583,6 @@ def test_log_works_in_train_callback(tmpdir):
# with func = np.mean if on_epoch else func = np.max
self.count += 1
"""
def on_epoch_end(self, trainer, pl_module):
self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False],
on_epochs=self.choices, prob_bars=self.choices)
@@ -629,17 +619,12 @@ def test_log_works_in_train_callback(tmpdir):
)
trainer.fit(model)
"""
assert test_callback.funcs_called_count["on_train_start"] == 1
assert test_callback.funcs_called_count["on_epoch_start"] == 2
assert test_callback.funcs_called_count["on_train_epoch_start"] == 2
assert test_callback.funcs_called_count["on_batch_start"] == 4
assert test_callback.funcs_called_count["on_train_batch_start"] == 4
assert test_callback.funcs_called_count["on_batch_end"] == 4
assert test_callback.funcs_called_count["on_epoch_end"] == 2
assert test_callback.funcs_called_count["on_train_batch_end"] == 4
"""
assert test_callback.funcs_called_count["on_epoch_end"] == 2
assert test_callback.funcs_called_count["on_train_epoch_end"] == 2