Fix loggers and update docs (#964)

* Fix loggers and update docs

* Update trainer.py
Ethan Harris 2020-02-27 20:54:06 +00:00 committed by GitHub
parent 27a3be0287
commit f5e0df390c
4 changed files with 39 additions and 22 deletions


@@ -22,13 +22,13 @@ To use CometLogger as your logger do the following.
     )
     trainer = Trainer(logger=comet_logger)
 
-The CometLogger is available anywhere in your LightningModule
+The CometLogger is available anywhere except ``__init__`` in your LightningModule
 
 .. code-block:: python
 
     class MyModule(pl.LightningModule):
-        def __init__(self, ...):
+        def any_lightning_module_function_or_hook(self, ...):
             some_img = fake_image()
             self.logger.experiment.add_image('generated_images', some_img, 0)
@@ -52,13 +52,13 @@ To use Neptune.ai as your logger do the following.
     )
     trainer = Trainer(logger=neptune_logger)
 
-The Neptune.ai is available anywhere in your LightningModule
+The Neptune.ai logger is available anywhere except ``__init__`` in your LightningModule
 
 .. code-block:: python
 
     class MyModule(pl.LightningModule):
-        def __init__(self, ...):
+        def any_lightning_module_function_or_hook(self, ...):
             some_img = fake_image()
             self.logger.experiment.add_image('generated_images', some_img, 0)
@@ -76,13 +76,13 @@ To use `Tensorboard <https://pytorch.org/docs/stable/tensorboard.html>`_ as your logger do the following.
     logger = TensorBoardLogger("tb_logs", name="my_model")
     trainer = Trainer(logger=logger)
 
-The TensorBoardLogger is available anywhere in your LightningModule
+The TensorBoardLogger is available anywhere except ``__init__`` in your LightningModule
 
 .. code-block:: python
 
     class MyModule(pl.LightningModule):
-        def __init__(self, ...):
+        def any_lightning_module_function_or_hook(self, ...):
             some_img = fake_image()
             self.logger.experiment.add_image('generated_images', some_img, 0)
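
All five doc hunks in this file make the same point, so here is one hedged, runnable version of the pattern for the TensorBoard case, where ``experiment`` is a ``torch.utils.tensorboard.SummaryWriter`` and ``add_image`` is a real method. The docs' ``fake_image()`` is replaced by a random CHW tensor, and the module body is a sketch rather than a complete training setup:

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.loggers import TensorBoardLogger

    class MyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            # self.logger is NOT usable here: the Trainer only binds it
            # once fit() begins (see the trainer.py hunks below)

        def any_lightning_module_function_or_hook(self):
            some_img = torch.rand(3, 64, 64)  # stand-in for fake_image()
            self.logger.experiment.add_image('generated_images', some_img, 0)

    logger = TensorBoardLogger("tb_logs", name="my_model")
    trainer = pl.Trainer(logger=logger)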
@@ -102,13 +102,13 @@ To use TestTube as your logger do the following.
     logger = TestTubeLogger("tb_logs", name="my_model")
     trainer = Trainer(logger=logger)
 
-The TestTubeLogger is available anywhere in your LightningModule
+The TestTubeLogger is available anywhere except ``__init__`` in your LightningModule
 
 .. code-block:: python
 
     class MyModule(pl.LightningModule):
-        def __init__(self, ...):
+        def any_lightning_module_function_or_hook(self, ...):
             some_img = fake_image()
             self.logger.experiment.add_image('generated_images', some_img, 0)
@@ -127,13 +127,13 @@ To use Wandb as your logger do the following.
     wandb_logger = WandbLogger()
     trainer = Trainer(logger=wandb_logger)
 
-The Wandb logger is available anywhere in your LightningModule
+The Wandb logger is available anywhere except ``__init__`` in your LightningModule
 
 .. code-block:: python
 
     class MyModule(pl.LightningModule):
-        def __init__(self, ...):
+        def any_lightning_module_function_or_hook(self, ...):
             some_img = fake_image()
             self.logger.experiment.add_image('generated_images', some_img, 0)
@@ -151,12 +151,17 @@ PyTorch-Lightning supports use of multiple loggers, just pass a list to the `Trainer`.
     logger2 = TestTubeLogger("tt_logs", name="my_model")
     trainer = Trainer(logger=[logger1, logger2])
 
-The loggers are available as a list anywhere in your LightningModule
+The loggers are available as a list anywhere except ``__init__`` in your LightningModule
 
 .. code-block:: python
 
     class MyModule(pl.LightningModule):
-        def __init__(self, ...):
+        def any_lightning_module_function_or_hook(self, ...):
             some_img = fake_image()
+            # Option 1
             self.logger.experiment[0].add_image('generated_images', some_img, 0)
+            # Option 2
+            self.logger[0].experiment.add_image('generated_images', some_img, 0)


@@ -1,6 +1,7 @@
 """
 Lightning supports most popular logging frameworks (TensorBoard, Comet, Weights & Biases, etc.).
-To use a logger, simply pass it into the trainer.
+To use a logger, simply pass it into the trainer. To use multiple loggers, pass in a ``list``
+or ``tuple`` of loggers.
 
 .. code-block:: python
@@ -14,14 +15,19 @@ To use a logger, simply pass it into the trainer.
     comet_logger = loggers.CometLogger()
     trainer = Trainer(logger=comet_logger)
 
+    # or pass a list
+    tb_logger = loggers.TensorBoardLogger()
+    comet_logger = loggers.CometLogger()
+    trainer = Trainer(logger=[tb_logger, comet_logger])
+
-.. note:: All loggers log by default to `os.getcwd()`. To change the path without creating a logger set
-    Trainer(default_save_path='/your/path/to/save/checkpoints')
+.. note:: All loggers log by default to ``os.getcwd()``. To change the path without creating a logger set
+    ``Trainer(default_save_path='/your/path/to/save/checkpoints')``
 
 Custom logger
 -------------
 You can implement your own logger by writing a class that inherits from
-`LightningLoggerBase`. Use the `rank_zero_only` decorator to make sure that
+``LightningLoggerBase``. Use the ``rank_zero_only`` decorator to make sure that
 only the first process in DDP training logs data.
 
 .. code-block:: python
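
The example block itself is collapsed in this diff view (its tail, ``# finishes goes here``, shows up as context in the next hunk). For reference, a minimal custom logger sketch along the lines this docstring describes, assuming the method names of ``LightningLoggerBase`` at the time:

    from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only

    class MyLogger(LightningLoggerBase):

        @rank_zero_only
        def log_hyperparams(self, params):
            # params is typically an argparse.Namespace; record it here
            pass

        @rank_zero_only
        def log_metrics(self, metrics, step):
            # metrics is a dict of metric names to values
            pass

        def save(self):
            # Optional. Any code necessary to save logger data goes here
            pass

        @rank_zero_only
        def finalize(self, status):
            # Optional. Any code that needs to be run after training
            # finishes goes here
            pass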
@@ -52,13 +58,13 @@ only the first process in DDP training logs data.
             # finishes goes here
 
-If you write a logger than may be useful to others, please send
+If you write a logger that may be useful to others, please send
 a pull request to add it to Lightning!
 
 Using loggers
 -------------
-Call the logger anywhere from your LightningModule by doing:
+Call the logger anywhere except ``__init__`` in your LightningModule by doing:
 
 .. code-block:: python
@@ -69,6 +75,8 @@ Call the logger anywhere from your LightningModule by doing:
         def any_lightning_module_function_or_hook(...):
             self.logger.experiment.add_histogram(...)
 
+Read more in the `Experiment Logging use case <./experiment_logging.html>`_.
+
 Supported Loggers
 -----------------
 """
@@ -77,7 +85,7 @@ from os import environ
 from .base import LightningLoggerBase, LoggerCollection, rank_zero_only
 from .tensorboard import TensorBoardLogger
 
-__all__ = ['TensorBoardLogger', 'LoggerCollection']
+__all__ = ['TensorBoardLogger']
 
 try:
     # needed to prevent ImportError and duplicated logs.


@@ -100,6 +100,9 @@ class LoggerCollection(LightningLoggerBase):
         super().__init__()
         self._logger_iterable = logger_iterable
 
+    def __getitem__(self, index: int) -> LightningLoggerBase:
+        return [logger for logger in self._logger_iterable][index]
+
     @property
     def experiment(self) -> List[Any]:
         return [logger.experiment for logger in self._logger_iterable]
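
With ``__getitem__`` in place, the two options shown in the docs hunk above resolve to the same object. A minimal sketch, assuming two ``TensorBoardLogger`` instances (chosen to avoid optional dependencies) and that ``TensorBoardLogger`` caches its ``SummaryWriter``:

    from pytorch_lightning.loggers import TensorBoardLogger
    from pytorch_lightning.loggers.base import LoggerCollection

    loggers = [TensorBoardLogger("tb_logs", name="a"),
               TensorBoardLogger("tb_logs", name="b")]
    collection = LoggerCollection(loggers)

    # .experiment builds a list of each logger's experiment object
    writers = collection.experiment

    # the new __getitem__ returns the i-th logger itself
    first_logger = collection[0]

    # both routes reach the same underlying SummaryWriter
    assert writers[0] is first_logger.experiment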


@@ -937,6 +937,9 @@ class Trainer(TrainerIOMixin,
             # feed to .fit()
         """
+        # bind logger
+        model.logger = self.logger
+
         # Fit begin callbacks
         self.on_fit_start()
@@ -1065,10 +1068,8 @@ class Trainer(TrainerIOMixin,
         # set local properties on the model
         self.copy_trainer_model_properties(ref_model)
 
-        # link up experiment object
+        # log hyper-parameters
         if self.logger is not None:
-            ref_model.logger = self.logger
-
             # save exp to get started
             if hasattr(ref_model, "hparams"):
                 self.logger.log_hyperparams(ref_model.hparams)
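
Taken together, these two trainer.py hunks move the logger binding from the middle of ``run_pretrain_routine`` to the top of ``fit()``, before any callbacks or model hooks fire, which is what makes the logger "available anywhere except ``__init__``" as the docs now say. A hedged sketch of what this enables, assuming a ``TensorBoardLogger`` and that ``on_epoch_start`` was an available module hook in this version:

    import pytorch_lightning as pl

    class MyModule(pl.LightningModule):
        def on_epoch_start(self):
            # self.logger was bound as soon as fit() started, so even
            # hooks that fire before any training step can log
            self.logger.experiment.add_text('status', 'epoch starting', 0)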