diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 8e0a037fe6..248046d270 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -1150,6 +1150,14 @@ class Trainer:
 
     @property
     def log_dir(self) -> Optional[str]:
+        """The directory for the current experiment. Use this to save images, etc.
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                img = ...
+                save_img(img, self.trainer.log_dir)
+        """
         if len(self.loggers) > 0:
             if not isinstance(self.loggers[0], TensorBoardLogger):
                 dirpath = self.loggers[0].save_dir
@@ -1163,6 +1171,14 @@ class Trainer:
 
     @property
     def is_global_zero(self) -> bool:
+        """Whether this process is the global zero in multi-node training.
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                if self.trainer.is_global_zero:
+                    print("in node 0, accelerator 0")
+        """
         return self.strategy.is_global_zero
 
     @property
@@ -1236,7 +1252,7 @@ class Trainer:
     def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
         """Allows you to manage which checkpoint is loaded statefully.
 
-        Examples::
+        .. code-block:: python
 
             trainer = Trainer()
             trainer.ckpt_path = "my/checkpoint/file.ckpt"
@@ -1384,11 +1400,13 @@ class Trainer:
 
     @property
     def train_dataloader(self) -> TRAIN_DATALOADERS:
+        """The training dataloader(s) used during ``trainer.fit()``."""
         if (combined_loader := self.fit_loop._combined_loader) is not None:
             return combined_loader.iterables
 
     @property
     def val_dataloaders(self) -> EVAL_DATALOADERS:
+        """The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``."""
         if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None:
             return combined_loader.iterables
         elif (combined_loader := self.validate_loop._combined_loader) is not None:
@@ -1396,25 +1414,32 @@ class Trainer:
 
     @property
     def test_dataloaders(self) -> EVAL_DATALOADERS:
+        """The test dataloader(s) used during ``trainer.test()``."""
         if (combined_loader := self.test_loop._combined_loader) is not None:
             return combined_loader.iterables
 
     @property
     def predict_dataloaders(self) -> EVAL_DATALOADERS:
+        """The prediction dataloader(s) used during ``trainer.predict()``."""
         if (combined_loader := self.predict_loop._combined_loader) is not None:
             return combined_loader.iterables
 
     @property
     def num_training_batches(self) -> Union[int, float]:
+        """The number of training batches that will be used during ``trainer.fit()``."""
         return self.fit_loop.max_batches
 
     @property
    def num_sanity_val_batches(self) -> List[Union[int, float]]:
+        """The number of validation batches that will be used during the sanity-checking part of
+        ``trainer.fit()``."""
         max_batches = self.fit_loop.epoch_loop.val_loop.max_batches
         return [min(self.num_sanity_val_steps, batches) for batches in max_batches]
 
     @property
     def num_val_batches(self) -> List[Union[int, float]]:
+        """The number of validation batches that will be used during ``trainer.fit()`` or
+        ``trainer.validate()``."""
         if self.state.fn == TrainerFn.VALIDATING:
             return self.validate_loop.max_batches
         # if no trainer.fn is set, assume fit's validation
@@ -1423,10 +1448,12 @@ class Trainer:
 
     @property
     def num_test_batches(self) -> List[Union[int, float]]:
+        """The number of test batches that will be used during ``trainer.test()``."""
         return self.test_loop.max_batches
 
     @property
     def num_predict_batches(self) -> List[Union[int, float]]:
+        """The number of prediction batches that will be used during ``trainer.predict()``."""
         return self.predict_loop.max_batches
 
     @property
@@ -1454,6 +1481,7 @@ class Trainer:
 
     @property
     def logger(self) -> Optional[Logger]:
+        """The first :class:`~lightning.pytorch.loggers.logger.Logger` being used."""
         return self.loggers[0] if len(self.loggers) > 0 else None
 
     @logger.setter
@@ -1465,6 +1493,13 @@ class Trainer:
 
     @property
     def loggers(self) -> List[Logger]:
+        """The list of :class:`~lightning.pytorch.loggers.logger.Logger` instances used.
+
+        .. code-block:: python
+
+            for logger in trainer.loggers:
+                logger.log_metrics({"foo": 1.0})
+        """
         return self._loggers
 
     @loggers.setter
@@ -1473,14 +1508,36 @@ class Trainer:
 
     @property
     def callback_metrics(self) -> _OUT_DICT:
+        """The metrics available to callbacks.
+
+        This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log`.
+
+        .. code-block:: python
+
+            def training_step(self, batch, batch_idx):
+                self.log("a_val", 2.0)
+
+            callback_metrics = trainer.callback_metrics
+            assert callback_metrics["a_val"] == 2.0
+        """
         return self._logger_connector.callback_metrics
 
     @property
     def logged_metrics(self) -> _OUT_DICT:
+        """The metrics sent to the loggers.
+
+        This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the
+        :paramref:`~lightning.pytorch.core.module.LightningModule.log.logger` argument set.
+        """
         return self._logger_connector.logged_metrics
 
     @property
     def progress_bar_metrics(self) -> _PBAR_DICT:
+        """The metrics sent to the progress bar.
+
+        This includes metrics logged via :meth:`~lightning.pytorch.core.module.LightningModule.log` with the
+        :paramref:`~lightning.pytorch.core.module.LightningModule.log.prog_bar` argument set.
+        """
         return self._logger_connector.progress_bar_metrics
 
     @property
@@ -1496,18 +1553,18 @@ class Trainer:
     @property
     def estimated_stepping_batches(self) -> Union[int, float]:
         r"""
-        Estimated stepping batches for the complete training inferred from DataLoaders, gradient
-        accumulation factor and distributed setup.
+        The estimated number of batches that will ``optimizer.step()`` during training.
 
-        Examples::
+        This accounts for gradient accumulation and the current trainer configuration. This might set up your training
+        dataloader if it hasn't been set up already.
+
+        .. code-block:: python
 
             def configure_optimizers(self):
                 optimizer = ...
-                scheduler = torch.optim.lr_scheduler.OneCycleLR(
-                    optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
-                )
+                stepping_batches = self.trainer.estimated_stepping_batches
+                scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=stepping_batches)
                 return [optimizer], [scheduler]
-
         """
         # infinite training
         if self.max_epochs == -1:
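For context, the sketch below is an illustrative example, not part of the patch above, of how the properties documented in this diff are typically read from a ``LightningModule``. The module, layer sizes, and metric names are placeholders chosen for the example.

.. code-block:: python

    import lightning.pytorch as pl
    import torch


    class DemoModule(pl.LightningModule):
        """Placeholder module that only exists to exercise the documented Trainer properties."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self.layer(x), y)
            # Values logged here show up later in ``trainer.callback_metrics``.
            self.log("train_loss", loss)
            # Restrict filesystem writes to rank zero, using ``log_dir`` as the target directory.
            if self.trainer.is_global_zero and self.trainer.log_dir is not None:
                pass  # e.g. save sample images under self.trainer.log_dir
            return loss

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=1e-3)
            # ``estimated_stepping_batches`` supplies the total step count that OneCycleLR needs.
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
            )
            return [optimizer], [scheduler]

After ``trainer.fit(DemoModule(), ...)`` returns, ``trainer.callback_metrics["train_loss"]`` holds the last logged value, and iterating ``trainer.loggers`` reaches every configured logger, mirroring the examples added in the docstrings above.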