Update Trainer property docstrings (#16989)
This commit is contained in:
parent
20374b93f4
commit
279e3add00
|
@ -1150,6 +1150,14 @@ class Trainer:
|
|||
|
||||
@property
|
||||
def log_dir(self) -> Optional[str]:
|
||||
"""The directory for the current experiment. Use this to save images to, etc...
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
img = ...
|
||||
save_img(img, self.trainer.log_dir)
|
||||
"""
|
||||
if len(self.loggers) > 0:
|
||||
if not isinstance(self.loggers[0], TensorBoardLogger):
|
||||
dirpath = self.loggers[0].save_dir
|
||||
|
@ -1163,6 +1171,14 @@ class Trainer:
|
|||
|
||||
@property
|
||||
def is_global_zero(self) -> bool:
|
||||
"""Whether this process is the global zero in multi-node training.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
if self.trainer.is_global_zero:
|
||||
print("in node 0, accelerator 0")
|
||||
"""
|
||||
return self.strategy.is_global_zero
|
||||
|
||||
@property
|
||||
|
@ -1236,7 +1252,7 @@ class Trainer:
|
|||
def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
|
||||
"""Allows you to manage which checkpoint is loaded statefully.
|
||||
|
||||
Examples::
|
||||
.. code-block:: python
|
||||
|
||||
trainer = Trainer()
|
||||
trainer.ckpt_path = "my/checkpoint/file.ckpt"
|
||||
|
@ -1384,11 +1400,13 @@ class Trainer:
|
|||
|
||||
@property
|
||||
def train_dataloader(self) -> TRAIN_DATALOADERS:
|
||||
"""The training dataloader(s) used during ``trainer.fit()``."""
|
||||
if (combined_loader := self.fit_loop._combined_loader) is not None:
|
||||
return combined_loader.iterables
|
||||
|
||||
@property
|
||||
def val_dataloaders(self) -> EVAL_DATALOADERS:
|
||||
"""The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``."""
|
||||
if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None:
|
||||
return combined_loader.iterables
|
||||
elif (combined_loader := self.validate_loop._combined_loader) is not None:
|
||||
|
@ -1396,25 +1414,32 @@ class Trainer:
|
|||
|
||||
@property
|
||||
def test_dataloaders(self) -> EVAL_DATALOADERS:
|
||||
"""The test dataloader(s) used during ``trainer.test()``."""
|
||||
if (combined_loader := self.test_loop._combined_loader) is not None:
|
||||
return combined_loader.iterables
|
||||
|
||||
@property
|
||||
def predict_dataloaders(self) -> EVAL_DATALOADERS:
|
||||
"""The prediction dataloader(s) used during ``trainer.predict()``."""
|
||||
if (combined_loader := self.predict_loop._combined_loader) is not None:
|
||||
return combined_loader.iterables
|
||||
|
||||
@property
|
||||
def num_training_batches(self) -> Union[int, float]:
|
||||
"""The number of training batches that will be used during ``trainer.fit()``."""
|
||||
return self.fit_loop.max_batches
|
||||
|
||||
@property
|
||||
def num_sanity_val_batches(self) -> List[Union[int, float]]:
|
||||
"""The number of validation batches that will be used during the sanity-checking part of
|
||||
``trainer.fit()``."""
|
||||
max_batches = self.fit_loop.epoch_loop.val_loop.max_batches
|
||||
return [min(self.num_sanity_val_steps, batches) for batches in max_batches]
|
||||
|
||||
@property
|
||||
def num_val_batches(self) -> List[Union[int, float]]:
|
||||
"""The number of validation batches that will be used during ``trainer.fit()`` or
|
||||
``trainer.validate()``."""
|
||||
if self.state.fn == TrainerFn.VALIDATING:
|
||||
return self.validate_loop.max_batches
|
||||
# if no trainer.fn is set, assume fit's validation
|
||||
|
@ -1423,10 +1448,12 @@ class Trainer:
|
|||
|
||||
@property
|
||||
def num_test_batches(self) -> List[Union[int, float]]:
|
||||
"""The number of test batches that will be used during ``trainer.test()``."""
|
||||
return self.test_loop.max_batches
|
||||
|
||||
@property
|
||||
def num_predict_batches(self) -> List[Union[int, float]]:
|
||||
"""The number of prediction batches that will be used during ``trainer.predict()``."""
|
||||
return self.predict_loop.max_batches
|
||||
|
||||
@property
|
||||
|
@ -1454,6 +1481,7 @@ class Trainer:
|
|||
|
||||
@property
|
||||
def logger(self) -> Optional[Logger]:
|
||||
"""The first :class:`~lightning.pytorch.loggers.logger.Logger` being used."""
|
||||
return self.loggers[0] if len(self.loggers) > 0 else None
|
||||
|
||||
@logger.setter
|
||||
|
@ -1465,6 +1493,13 @@ class Trainer:
|
|||
|
||||
@property
|
||||
def loggers(self) -> List[Logger]:
|
||||
"""The list of class:`~lightning.pytorch.loggers.logger.Logger` used.
|
||||
|
||||
..code-block:: python
|
||||
|
||||
for logger in trainer.loggers:
|
||||
logger.log_metrics({"foo": 1.0})
|
||||
"""
|
||||
return self._loggers
|
||||
|
||||
@loggers.setter
|
||||
|
@ -1473,14 +1508,36 @@ class Trainer:
|
|||
|
||||
@property
|
||||
def callback_metrics(self) -> _OUT_DICT:
|
||||
"""The metrics available to callbacks.
|
||||
|
||||
This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log`.
|
||||
|
||||
..code-block:: python
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
self.log("a_val", 2.0)
|
||||
|
||||
callback_metrics = trainer.callback_metrics
|
||||
assert callback_metrics["a_val"] == 2.0
|
||||
"""
|
||||
return self._logger_connector.callback_metrics
|
||||
|
||||
@property
|
||||
def logged_metrics(self) -> _OUT_DICT:
|
||||
"""The metrics sent to the loggers.
|
||||
|
||||
This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the
|
||||
:paramref:`~lightning.pytorch.core.module.LightningModule.log.logger` argument set.
|
||||
"""
|
||||
return self._logger_connector.logged_metrics
|
||||
|
||||
@property
|
||||
def progress_bar_metrics(self) -> _PBAR_DICT:
|
||||
"""The metrics sent to the progress bar.
|
||||
|
||||
This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the
|
||||
:paramref:`~lightning.pytorch.core.module.LightningModule.log.prog_bar` argument set.
|
||||
"""
|
||||
return self._logger_connector.progress_bar_metrics
|
||||
|
||||
@property
|
||||
|
@ -1496,18 +1553,18 @@ class Trainer:
|
|||
@property
|
||||
def estimated_stepping_batches(self) -> Union[int, float]:
|
||||
r"""
|
||||
Estimated stepping batches for the complete training inferred from DataLoaders, gradient
|
||||
accumulation factor and distributed setup.
|
||||
The estimated number of batches that will ``optimizer.step()`` during training.
|
||||
|
||||
Examples::
|
||||
This accounts for gradient accumulation and the current trainer configuration. This might sets up your training
|
||||
dataloader if hadn't been set up already.
|
||||
|
||||
..code-block:: python
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = ...
|
||||
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
||||
optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
|
||||
)
|
||||
stepping_batches = self.trainer.estimated_stepping_batches
|
||||
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=stepping_batches)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
"""
|
||||
# infinite training
|
||||
if self.max_epochs == -1:
|
||||
|
|
Loading…
Reference in New Issue