Enable `self.log` in most functions. (#4969)
* refactor
* solve pyright
* remove logging in batch_start functions
* update docs
* update doc
* resolve bug
* update
* correct script
* resolve on comments
parent 9b1afa8c87
commit 2e838e6dd8
@@ -6,7 +6,7 @@
 .. role:: hidden
     :class: hidden-section

 .. _logging:
@@ -57,9 +57,11 @@ Logging from a LightningModule

 Lightning offers automatic log functionalities for logging scalars, or manual logging for anything else.

-Automatic logging
+Automatic Logging
 =================
-Use the :func:`~pytorch_lightning.core.lightning.LightningModule.log` method to log from anywhere in a :ref:`lightning_module`.
+Use the :func:`~pytorch_lightning.core.lightning.LightningModule.log`
+method to log from anywhere in a :ref:`lightning_module` and :ref:`callbacks`
+except functions with `batch_start` in their names.

 .. code-block:: python
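
To make the updated wording concrete, here is a minimal sketch of what it allows: `self.log` called from a LightningModule hook and from a Callback hook. The metric names and the toy model below are illustrative, not taken from this diff; the `*_batch_start` hooks are the stated exception.

import torch
from torch import nn
from pytorch_lightning import Callback, LightningModule


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self(x), y)
        # scalar is captured automatically for the logger, progress bar and checkpoint callbacks
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class LoggingCallback(Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        # callbacks can log through the module's ``log`` method as well;
        # hooks whose name contains ``batch_start`` do not support it
        pl_module.log("dummy_callback_metric", 1.0)
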
@@ -95,6 +97,9 @@ The :func:`~pytorch_lightning.core.lightning.LightningModule.log` method has a
 argument of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` or in the graphs plotted to the logger of your choice.

+If your work requires to log in an unsupported function, please open an issue with a clear description of why it is blocking you.
+
 Manual logging
 ==============
 If you want to log anything that is not a scalar, like histograms, text, images, etc... you may need to use the logger object directly.
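
The "Manual logging" paragraph refers to using the logger object directly for non-scalars. A minimal sketch, assuming the default TensorBoard logger (whose `experiment` attribute is the underlying `SummaryWriter`) and a hypothetical `preds` key in the step outputs:

import torch
from torch import nn
from pytorch_lightning import LightningModule


class LitHistogramModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {"preds": self(x)}

    def validation_epoch_end(self, outputs):
        # anything that is not a scalar goes through the logger object itself;
        # with the default TensorBoardLogger, ``experiment`` is a SummaryWriter
        preds = torch.cat([out["preds"] for out in outputs])
        self.logger.experiment.add_histogram("val_preds", preds, self.current_epoch)
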
@@ -144,8 +149,8 @@ Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`
     def experiment(self):
         # Return the experiment object associated with this logger.
         pass

-    @property
+    @property
     def version(self):
         # Return the experiment version, int or str.
         return '0.1'
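
The snippet above is part of the custom-logger example in the docs. For context, a fuller sketch of such a logger, assuming the `LightningLoggerBase` interface and the `rank_zero_experiment` decorator referenced in the surrounding docs; the methods beyond `experiment` and `version` are the usual abstract hooks of that interface, not shown in this diff.

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only


class MyLogger(LightningLoggerBase):
    @property
    def name(self):
        return "my_logger"

    @property
    @rank_zero_experiment
    def experiment(self):
        # Return the experiment object associated with this logger.
        pass

    @property
    def version(self):
        # Return the experiment version, int or str.
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        # metrics is a dictionary of metric names and values
        pass
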
@@ -238,7 +243,7 @@ if you are using a logger. These defaults can be customized by overriding the
 :func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module.

 .. code-block:: python

     def get_progress_bar_dict(self):
         # don't show the version number
         items = super().get_progress_bar_dict()
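
The snippet above is cut off at the hunk boundary. In the Lightning docs of this era the override typically continues by dropping the version entry; a sketch, assuming "v_num" is the progress-bar key to remove:

from pytorch_lightning import LightningModule


class LitProgressBarModel(LightningModule):
    def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items
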
@@ -339,7 +339,9 @@ class EpochResultStore:
         self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info)

+        # update logged_metrics, progress_bar_metrics, callback_metrics
+        self.update_logger_connector()

         if "epoch_end" in fx_name:
             self.update_logger_connector()

             self.reset_model()
@@ -355,18 +357,19 @@ class EpochResultStore:
         logger_connector = self.trainer.logger_connector

         callback_metrics = {}
-        is_train = self._stage == LoggerStages.TRAIN
+        batch_pbar_metrics = {}
+        batch_log_metrics = {}
+        is_train = self._stage in LoggerStages.TRAIN.value

         if not self._has_batch_loop_finished:
             # get pbar
             batch_pbar_metrics = self.get_latest_batch_pbar_metrics()
             logger_connector.add_progress_bar_metrics(batch_pbar_metrics)
-            batch_log_metrics = self.get_latest_batch_log_metrics()

+            if is_train:
+                # Only log and add to callback epoch step during evaluation, test.
+                batch_log_metrics = self.get_latest_batch_log_metrics()
             logger_connector.logged_metrics.update(batch_log_metrics)

             callback_metrics.update(batch_pbar_metrics)
             callback_metrics.update(batch_log_metrics)
         else:
@@ -393,6 +396,9 @@ class EpochResultStore:
             logger_connector.callback_metrics.update(callback_metrics)
             logger_connector.callback_metrics.pop("epoch", None)

+        batch_pbar_metrics.pop("debug_epoch", None)
+        return batch_pbar_metrics, batch_log_metrics
+
     def run_batch_from_func_name(self, func_name) -> Dict:
         results = [getattr(hook_result, func_name) for hook_result in self._internals.values()]
         results = [func(include_forked_originals=False) for func in results]
@@ -587,11 +587,13 @@ class LoggerConnector:
         return gathered_epoch_outputs

     def log_train_step_metrics(self, batch_output):
+        _, batch_log_metrics = self.cached_results.update_logger_connector()
         # when metrics should be logged
         if self.should_update_logs or self.trainer.fast_dev_run:
             # logs user requested information to logger
-            metrics = self.cached_results.get_latest_batch_log_metrics()
             grad_norm_dic = batch_output.grad_norm_dic
-            if len(metrics) > 0 or len(grad_norm_dic) > 0:
-                self.log_metrics(metrics, grad_norm_dic, log_train_step_metrics=True)
-                self.callback_metrics.update(metrics)
+            if grad_norm_dic is None:
+                grad_norm_dic = {}
+            if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
+                self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True)
+                self.callback_metrics.update(batch_log_metrics)
@@ -106,9 +106,9 @@ class EvaluationLoop(object):

     def on_evaluation_end(self, *args, **kwargs):
         if self.testing:
-            self.trainer.call_hook('on_test_end', *args, capture=True, **kwargs)
+            self.trainer.call_hook('on_test_end', *args, **kwargs)
         else:
-            self.trainer.call_hook('on_validation_end', *args, capture=True, **kwargs)
+            self.trainer.call_hook('on_validation_end', *args, **kwargs)

     def reload_evaluation_dataloaders(self):
         model = self.trainer.get_model()
@@ -329,9 +329,9 @@ class EvaluationLoop(object):
     def on_evaluation_epoch_end(self, *args, **kwargs):
         # call the callback hook
         if self.testing:
-            self.trainer.call_hook('on_test_epoch_end', *args, capture=True, **kwargs)
+            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
         else:
-            self.trainer.call_hook('on_validation_epoch_end', *args, capture=True, **kwargs)
+            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)

     def log_evaluation_step_metrics(self, output, batch_idx):
         if self.trainer.running_sanity_check:
@@ -346,10 +346,8 @@ class EvaluationLoop(object):
         self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)

     def __log_result_step_metrics(self, step_log_metrics, step_pbar_metrics, batch_idx):
-        cached_batch_log_metrics = \
-            self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics()
-        cached_batch_pbar_metrics = \
-            self.trainer.logger_connector.cached_results.get_latest_batch_pbar_metrics()
+        cached_results = self.trainer.logger_connector.cached_results
+        cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()

         step_log_metrics.update(cached_batch_log_metrics)
         step_pbar_metrics.update(cached_batch_pbar_metrics)
@@ -855,6 +855,8 @@ class Trainer(
         model.setup(stage_name)

     def _reset_result_and_set_hook_fx_name(self, hook_name):
+        if "batch_start" in hook_name:
+            return True
         model_ref = self.get_model()
         if model_ref is not None:
             # used to track current hook name called
@@ -868,10 +870,9 @@ class Trainer(
             # capture logging for this hook
             self.logger_connector.cache_logged_metrics()

-    def call_hook(self, hook_name, *args, capture=False, **kwargs):
+    def call_hook(self, hook_name, *args, **kwargs):
         # set hook_name to model + reset Result obj
-        if capture:
-            self._reset_result_and_set_hook_fx_name(hook_name)
+        skip = self._reset_result_and_set_hook_fx_name(hook_name)

         # always profile hooks
         with self.profiler.profile(hook_name):
@@ -894,7 +895,7 @@ class Trainer(
                 accelerator_hook = getattr(self.accelerator_backend, hook_name)
                 output = accelerator_hook(*args, **kwargs)

-            if capture:
+            if not skip:
                 self._cache_logged_metrics()
             return output
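
Taken together, the Trainer hunks above replace the opt-in `capture=True` flag with an always-on capture that is skipped only for hooks whose name contains `batch_start`. A simplified, standalone sketch of that control flow follows; it is not the actual Trainer code, and the toy class, the `fn` argument, and the profiler stand-in are illustrative only.

from contextlib import contextmanager


class ToyTrainer:
    """Illustration of the new call_hook flow: every hook is profiled and its
    logged metrics are cached, except hooks whose name contains 'batch_start'."""

    def __init__(self):
        self.cached = []

    def _reset_result_and_set_hook_fx_name(self, hook_name):
        # returns True when metric caching should be skipped for this hook
        return "batch_start" in hook_name

    @contextmanager
    def profile(self, name):
        # stand-in for self.profiler.profile(hook_name)
        yield

    def _cache_logged_metrics(self, hook_name):
        self.cached.append(hook_name)

    def call_hook(self, hook_name, fn, *args, **kwargs):
        skip = self._reset_result_and_set_hook_fx_name(hook_name)
        with self.profile(hook_name):
            output = fn(*args, **kwargs)
        if not skip:
            self._cache_logged_metrics(hook_name)
        return output


trainer = ToyTrainer()
trainer.call_hook("on_train_batch_start", lambda: None)
trainer.call_hook("on_train_batch_end", lambda: None)
print(trainer.cached)  # ['on_train_batch_end'] -> batch_start hooks are not captured
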
@@ -825,8 +825,8 @@ class TrainLoop:
         # inform logger the batch loop has finished
         self.trainer.logger_connector.on_train_epoch_end()

-        self.trainer.call_hook('on_epoch_end', capture=True)
-        self.trainer.call_hook('on_train_epoch_end', epoch_output, capture=True)
+        self.trainer.call_hook('on_epoch_end')
+        self.trainer.call_hook('on_train_epoch_end', epoch_output)

     def increment_accumulated_grad_global_step(self):
         num_accumulated_batches_reached = self._accumulated_batches_reached()
@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Test deprecated functionality which will be removed in vX.Y.Z"""
 import sys
 from argparse import ArgumentParser
@@ -472,8 +472,6 @@ def test_log_works_in_val_callback(tmpdir):
                 "forked": False,
                 "func_name": func_name}

-        """
-
         def on_validation_start(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_validation_start', 1, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -486,6 +484,7 @@ def test_log_works_in_val_callback(tmpdir):
             self.make_logging(pl_module, 'on_validation_epoch_start', 3, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)

+        """
         def on_batch_start(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -493,6 +492,7 @@ def test_log_works_in_val_callback(tmpdir):
         def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
             self.make_logging(pl_module, 'on_validation_batch_start', 5, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
+        """

         def on_batch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices,
@@ -510,8 +510,6 @@ def test_log_works_in_val_callback(tmpdir):
             self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False],
                               on_epochs=self.choices, prob_bars=self.choices)

-        """
-
         def on_validation_epoch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_validation_epoch_end', 9, on_steps=[False],
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -541,16 +539,14 @@ def test_log_works_in_val_callback(tmpdir):
     trainer.fit(model)
     trainer.test()

-    """
     assert test_callback.funcs_called_count["on_epoch_start"] == 1
-    assert test_callback.funcs_called_count["on_batch_start"] == 1
+    # assert test_callback.funcs_called_count["on_batch_start"] == 1
     assert test_callback.funcs_called_count["on_batch_end"] == 1
     assert test_callback.funcs_called_count["on_validation_start"] == 1
     assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1
-    assert test_callback.funcs_called_count["on_validation_batch_start"] == 4
+    # assert test_callback.funcs_called_count["on_validation_batch_start"] == 4
     assert test_callback.funcs_called_count["on_validation_batch_end"] == 4
     assert test_callback.funcs_called_count["on_epoch_end"] == 1
-    """
     assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1

     # Make sure the func_name exists within callback_metrics. If not, we missed some
@@ -662,7 +658,6 @@ def test_log_works_in_test_callback(tmpdir):
                 "forked": False,
                 "func_name": func_name}

-        """
         def on_test_start(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_test_start', 1, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -675,11 +670,8 @@ def test_log_works_in_test_callback(tmpdir):
             self.make_logging(pl_module, 'on_test_epoch_start', 3, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)

-        def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
-            self.make_logging(pl_module, 'on_test_batch_start', 4, on_steps=self.choices,
-                              on_epochs=self.choices, prob_bars=self.choices)
-
-        def on_test_step_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-            self.make_logging(pl_module, 'on_test_step_end', 5, on_steps=self.choices,
+        def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+            self.make_logging(pl_module, 'on_test_batch_end', 5, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)

         # used to make sure aggregation works fine.
@@ -690,7 +682,6 @@ def test_log_works_in_test_callback(tmpdir):
         def on_epoch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_epoch_end', 6, on_steps=[False],
                               on_epochs=self.choices, prob_bars=self.choices)
-        """

         def on_test_epoch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_test_epoch_end', 7, on_steps=[False],
@@ -728,13 +719,11 @@ def test_log_works_in_test_callback(tmpdir):
     )
     trainer.fit(model)
     trainer.test()

-    """
     assert test_callback.funcs_called_count["on_test_start"] == 1
     assert test_callback.funcs_called_count["on_epoch_start"] == 2
     assert test_callback.funcs_called_count["on_test_epoch_start"] == 1
-    assert test_callback.funcs_called_count["on_test_batch_start"] == 4
-    assert test_callback.funcs_called_count["on_test_step_end"] == 4
-    """
     assert test_callback.funcs_called_count["on_test_batch_end"] == 4
     assert test_callback.funcs_called_count["on_test_epoch_end"] == 1

     # Make sure the func_name exists within callback_metrics. If not, we missed some
@@ -558,7 +558,7 @@ def test_log_works_in_train_callback(tmpdir):
                 "prog_bar": prog_bar,
                 "forked": False,
                 "func_name": func_name}
-        """

         def on_train_start(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_train_start', 1, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -571,15 +571,6 @@ def test_log_works_in_train_callback(tmpdir):
             self.make_logging(pl_module, 'on_train_epoch_start', 3, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)

-        def on_batch_start(self, trainer, pl_module):
-            self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices,
-                              on_epochs=self.choices, prob_bars=self.choices)
-
-        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
-            self.make_logging(pl_module, 'on_train_batch_start', 5, on_steps=self.choices,
-                              on_epochs=self.choices, prob_bars=self.choices)
-
         def on_batch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices,
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -592,7 +583,6 @@ def test_log_works_in_train_callback(tmpdir):
             # with func = np.mean if on_epoch else func = np.max
             self.count += 1

-        """
         def on_epoch_end(self, trainer, pl_module):
             self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False],
                               on_epochs=self.choices, prob_bars=self.choices)
@@ -629,17 +619,12 @@ def test_log_works_in_train_callback(tmpdir):
     )
     trainer.fit(model)

-    """
     assert test_callback.funcs_called_count["on_train_start"] == 1
     assert test_callback.funcs_called_count["on_epoch_start"] == 2
     assert test_callback.funcs_called_count["on_train_epoch_start"] == 2
-    assert test_callback.funcs_called_count["on_batch_start"] == 4
-    assert test_callback.funcs_called_count["on_train_batch_start"] == 4
     assert test_callback.funcs_called_count["on_batch_end"] == 4
-    assert test_callback.funcs_called_count["on_epoch_end"] == 2
     assert test_callback.funcs_called_count["on_train_batch_end"] == 4
-
-    """
+    assert test_callback.funcs_called_count["on_epoch_end"] == 2
     assert test_callback.funcs_called_count["on_train_epoch_end"] == 2