Update Trainer property docstrings (#16989)

Carlos Mocholí 2023-03-09 01:12:06 +01:00 committed by GitHub
parent 20374b93f4
commit 279e3add00
1 changed file with 65 additions and 8 deletions


@@ -1150,6 +1150,14 @@ class Trainer:
    @property
    def log_dir(self) -> Optional[str]:
        """The directory for the current experiment. Use this to save images to, etc...

        .. code-block:: python

            def training_step(self, batch, batch_idx):
                img = ...
                save_img(img, self.trainer.log_dir)
        """
        if len(self.loggers) > 0:
            if not isinstance(self.loggers[0], TensorBoardLogger):
                dirpath = self.loggers[0].save_dir
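
The same property is also handy outside of ``training_step``, for example from a callback. A minimal usage sketch, not part of this commit; the ``SaveMarker`` callback and the file name are illustrative only:

    import os
    from lightning.pytorch.callbacks import Callback

    class SaveMarker(Callback):
        def on_train_end(self, trainer, pl_module):
            # the property is typed Optional[str], so guard against None before writing
            if trainer.log_dir is not None:
                with open(os.path.join(trainer.log_dir, "done.txt"), "w") as f:
                    f.write("training finished")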
@@ -1163,6 +1171,14 @@ class Trainer:
    @property
    def is_global_zero(self) -> bool:
        """Whether this process is the global zero in multi-node training.

        .. code-block:: python

            def training_step(self, batch, batch_idx):
                if self.trainer.is_global_zero:
                    print("in node 0, accelerator 0")
        """
        return self.strategy.is_global_zero
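
A common pattern is gating one-off side effects on the global-zero rank so they run once per cluster rather than once per process. A hedged sketch, not part of this commit; ``download_assets`` is a hypothetical helper:

    def on_fit_start(self):
        if self.trainer.is_global_zero:
            download_assets()  # hypothetical helper, executed only on global rank 0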
    @property
@@ -1236,7 +1252,7 @@ class Trainer:
    def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
        """Allows you to manage which checkpoint is loaded statefully.

        Examples::
        .. code-block:: python

            trainer = Trainer()
            trainer.ckpt_path = "my/checkpoint/file.ckpt"
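
To complete the picture, a hedged usage sketch, not part of this commit: after assigning the path, a later ``trainer.fit()`` call is expected to resume from that checkpoint without passing an explicit ``ckpt_path`` argument.

    trainer = Trainer()
    trainer.ckpt_path = "my/checkpoint/file.ckpt"
    trainer.fit(model)  # `model` is some LightningModule; presumably restores from the assigned checkpoint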
@@ -1384,11 +1400,13 @@ class Trainer:
    @property
    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """The training dataloader(s) used during ``trainer.fit()``."""
        if (combined_loader := self.fit_loop._combined_loader) is not None:
            return combined_loader.iterables

    @property
    def val_dataloaders(self) -> EVAL_DATALOADERS:
        """The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``."""
        if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None:
            return combined_loader.iterables
        elif (combined_loader := self.validate_loop._combined_loader) is not None:
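
These properties return the dataloader structure that was originally handed to the trainer, which makes them convenient for introspection from hooks. A minimal sketch, not part of this commit, assuming the common case of a single map-style training ``DataLoader``:

    def on_train_start(self):
        # with a single DataLoader, the property returns it directly
        train_dl = self.trainer.train_dataloader
        print(f"training on {len(train_dl.dataset)} samples")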
@@ -1396,25 +1414,32 @@ class Trainer:
    @property
    def test_dataloaders(self) -> EVAL_DATALOADERS:
        """The test dataloader(s) used during ``trainer.test()``."""
        if (combined_loader := self.test_loop._combined_loader) is not None:
            return combined_loader.iterables

    @property
    def predict_dataloaders(self) -> EVAL_DATALOADERS:
        """The prediction dataloader(s) used during ``trainer.predict()``."""
        if (combined_loader := self.predict_loop._combined_loader) is not None:
            return combined_loader.iterables
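
The same pattern applies to the evaluation loops. A hedged sketch, not part of this commit, assuming a single sized test ``DataLoader`` was passed to ``trainer.test()``; the ``ReportTestSize`` callback is illustrative only:

    from lightning.pytorch.callbacks import Callback

    class ReportTestSize(Callback):
        def on_test_start(self, trainer, pl_module):
            test_dl = trainer.test_dataloaders
            print(f"evaluating {len(test_dl)} test batches")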
    @property
    def num_training_batches(self) -> Union[int, float]:
        """The number of training batches that will be used during ``trainer.fit()``."""
        return self.fit_loop.max_batches

    @property
    def num_sanity_val_batches(self) -> List[Union[int, float]]:
        """The number of validation batches that will be used during the sanity-checking part of
        ``trainer.fit()``."""
        max_batches = self.fit_loop.epoch_loop.val_loop.max_batches
        return [min(self.num_sanity_val_steps, batches) for batches in max_batches]

    @property
    def num_val_batches(self) -> List[Union[int, float]]:
        """The number of validation batches that will be used during ``trainer.fit()`` or
        ``trainer.validate()``."""
        if self.state.fn == TrainerFn.VALIDATING:
            return self.validate_loop.max_batches
        # if no trainer.fn is set, assume fit's validation
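
Note that ``num_training_batches`` can be ``float("inf")`` for iterable-style datasets without a length, so consumers should guard for that. A minimal sketch, not part of this commit; ``warmup_steps`` is an illustrative attribute:

    def on_train_start(self):
        n = self.trainer.num_training_batches
        if n != float("inf"):
            # e.g. warm up the learning rate over the first 10% of an epoch
            self.warmup_steps = int(0.1 * n)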
@@ -1423,10 +1448,12 @@ class Trainer:
    @property
    def num_test_batches(self) -> List[Union[int, float]]:
        """The number of test batches that will be used during ``trainer.test()``."""
        return self.test_loop.max_batches

    @property
    def num_predict_batches(self) -> List[Union[int, float]]:
        """The number of prediction batches that will be used during ``trainer.predict()``."""
        return self.predict_loop.max_batches
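
Both properties return one entry per dataloader, so the values can be summed for a total-work estimate. A hedged sketch, not part of this commit:

    def on_predict_start(self):
        total = sum(self.trainer.num_predict_batches)
        print(f"predicting {total} batches across all dataloaders")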
    @property
@@ -1454,6 +1481,7 @@ class Trainer:
    @property
    def logger(self) -> Optional[Logger]:
        """The first :class:`~lightning.pytorch.loggers.logger.Logger` being used."""
        return self.loggers[0] if len(self.loggers) > 0 else None
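
Because the property falls back to ``None`` when no logger is configured, call sites should guard before use. A minimal sketch, not part of this commit:

    if trainer.logger is not None:
        trainer.logger.log_hyperparams({"seed": 42})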
    @logger.setter
@@ -1465,6 +1493,13 @@ class Trainer:
    @property
    def loggers(self) -> List[Logger]:
        """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used.

        .. code-block:: python

            for logger in trainer.loggers:
                logger.log_metrics({"foo": 1.0})
        """
        return self._loggers

    @loggers.setter
@@ -1473,14 +1508,36 @@ class Trainer:
    @property
    def callback_metrics(self) -> _OUT_DICT:
        """The metrics available to callbacks.

        This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log`.

        .. code-block:: python

            def training_step(self, batch, batch_idx):
                self.log("a_val", 2.0)

            callback_metrics = trainer.callback_metrics
            assert callback_metrics["a_val"] == 2.0
        """
        return self._logger_connector.callback_metrics

    @property
    def logged_metrics(self) -> _OUT_DICT:
        """The metrics sent to the loggers.

        This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the
        :paramref:`~lightning.pytorch.core.module.LightningModule.log.logger` argument set.
        """
        return self._logger_connector.logged_metrics

    @property
    def progress_bar_metrics(self) -> _PBAR_DICT:
        """The metrics sent to the progress bar.

        This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the
        :paramref:`~lightning.pytorch.core.module.LightningModule.log.prog_bar` argument set.
        """
        return self._logger_connector.progress_bar_metrics
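
A hedged sketch of how the three metric dictionaries relate, not part of this commit: values logged with ``logger=True`` land in ``logged_metrics``, values logged with ``prog_bar=True`` land in ``progress_bar_metrics``, and both are visible in ``callback_metrics``. The ``InspectMetrics`` callback is illustrative only:

    from lightning.pytorch.callbacks import Callback

    class InspectMetrics(Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # assumes the module called self.log("loss", ..., logger=True, prog_bar=True)
            print(trainer.logged_metrics.get("loss"))
            print(trainer.progress_bar_metrics.get("loss"))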
    @property
@@ -1496,18 +1553,18 @@ class Trainer:
    @property
    def estimated_stepping_batches(self) -> Union[int, float]:
        r"""
        Estimated stepping batches for the complete training inferred from DataLoaders, gradient
        accumulation factor and distributed setup.
        The estimated number of batches that will ``optimizer.step()`` during training.

        Examples::
        This accounts for gradient accumulation and the current trainer configuration. This may set up your training
        dataloader if it hasn't been set up already.

        .. code-block:: python

            def configure_optimizers(self):
                optimizer = ...
                scheduler = torch.optim.lr_scheduler.OneCycleLR(
                    optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
                )
                stepping_batches = self.trainer.estimated_stepping_batches
                scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=stepping_batches)
                return [optimizer], [scheduler]
        """
        # infinite training
        if self.max_epochs == -1: