Docs fixes (#19529)

awaelchli 2024-02-26 12:06:08 +01:00 committed by GitHub
parent 2e512d4b2e
commit 2a827f3f6f
5 changed files with 23 additions and 24 deletions

View File

@ -167,9 +167,11 @@ In distributed training cases where a model is running across many machines, Lig
trainer = Trainer(strategy="ddp")
model = MyLightningModule(hparams)
trainer.fit(model)
# Saves only on the main process
# Handles strategy-specific saving logic like XLA, FSDP, DeepSpeed etc.
trainer.save_checkpoint("example.ckpt")
Not using :meth:`~lightning.pytorch.trainer.trainer.Trainer.save_checkpoint` can lead to unexpected behavior and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the Trainer's save functionality.
If using custom saving functions cannot be avoided, we recommend using the :func:`~lightning.pytorch.utilities.rank_zero.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
model-parallel distributed strategies such as DeepSpeed or sharded training.
By using :meth:`~lightning.pytorch.trainer.trainer.Trainer.save_checkpoint` instead of ``torch.save``, you make your code agnostic to the distributed training strategy being used.
It will ensure that checkpoints are saved correctly in a multi-process setting, avoiding race conditions, deadlocks and other common issues that normally require boilerplate code to handle properly.
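If a custom saving routine truly cannot be avoided, a minimal sketch of guarding it with the decorator might look like the following (the function name and checkpoint path are illustrative, and it assumes every rank holds identical state):

.. code-block:: python

    import torch

    from lightning.pytorch.utilities.rank_zero import rank_zero_only


    @rank_zero_only
    def save_state_dict(model, path="custom.ckpt"):
        # The decorator turns this into a no-op on all non-main processes,
        # so only one file is written and no rank races to overwrite it.
        torch.save(model.state_dict(), path)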

View File

@ -16,7 +16,9 @@ Install lightning inside a virtual env or conda environment with pip
python -m pip install lightning
--------------
----
******************
Install with Conda
@ -66,17 +68,17 @@ Install future patch releases from the source. Note that the patch release conta
^^^^^^^^^^^^^^^^^^^^^^
Custom PyTorch Version
^^^^^^^^^^^^^^^^^^^^^^
To use any PyTorch version visit the `PyTorch Installation Page <https://pytorch.org/get-started/locally/#start-locally>`_.
You can find the list of supported PyTorch versions in our :ref:`compatibility matrix <versioning:Compatibility matrix>`.
----
*******************************************
Optimized for ML workflows (lightning Apps)
Optimized for ML workflows (Lightning Apps)
*******************************************
If you are deploying workflows built with Lightning in production and require fewer dependencies, try using the optimized `lightning[apps]` package:
If you are deploying workflows built with Lightning in production and require fewer dependencies, try using the optimized ``lightning[apps]`` package:
.. code-block:: bash

View File

@ -89,13 +89,12 @@ class ModelCheckpoint(Checkpoint):
in a deterministic manner. Default: ``None``.
save_top_k: if ``save_top_k == k``,
the best k models according to the quantity monitored will be saved.
if ``save_top_k == 0``, no models are saved.
if ``save_top_k == -1``, all models are saved.
If ``save_top_k == 0``, no models are saved.
If ``save_top_k == -1``, all models are saved.
Please note that the monitors are checked every ``every_n_epochs`` epochs.
if ``save_top_k >= 2`` and the callback is called multiple
times inside an epoch, the name of the saved file will be
appended with a version count starting with ``v1``
unless ``enable_version_counter`` is set to False.
If ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, and the filename remains
unchanged, the name of the saved file will be appended with a version count starting with ``v1`` to avoid
collisions unless ``enable_version_counter`` is set to False.
mode: one of {min, max}.
If ``save_top_k != 0``, the decision to overwrite the current save file is made
based on either the maximization or the minimization of the monitored quantity.
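To make the interplay of ``save_top_k`` and ``mode`` concrete, a minimal sketch (it assumes the model logs a ``val_loss`` metric; all other arguments are left at their defaults):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    # Keep only the 3 checkpoints with the lowest logged ``val_loss``;
    # worse checkpoints are removed as better ones come in.
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min")
    trainer = Trainer(callbacks=[checkpoint_callback])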

View File

@ -85,6 +85,10 @@ class ModelHooks:
batch: The batched data as it is returned by the training DataLoader.
batch_idx: the index of the batch
Note:
The value ``outputs["loss"]`` here will be the normalized value w.r.t ``accumulate_grad_batches`` of the
loss returned from ``training_step``.
"""
def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
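As an illustration of the ``outputs["loss"]`` note above, a minimal sketch of overriding the hook (the class name is made up, and it assumes ``training_step`` returns a dict containing a ``"loss"`` key):

.. code-block:: python

    from lightning.pytorch import LightningModule


    class LitModel(LightningModule):
        def on_train_batch_end(self, outputs, batch, batch_idx):
            # With ``accumulate_grad_batches=4``, ``outputs["loss"]`` is the loss
            # returned from ``training_step`` divided by 4.
            self.log("loss_per_micro_batch", outputs["loss"])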

View File

@ -1285,20 +1285,12 @@ class LightningModule(
Examples::
    # DEFAULT
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        # Add your custom logic to run directly before `optimizer.step()`
        optimizer.step(closure=optimizer_closure)

    # Learning rate warm-up
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        # update params
        optimizer.step(closure=optimizer_closure)

        # manually warm up lr without a scheduler
        if self.trainer.global_step < 500:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.learning_rate

        # Add your custom logic to run directly after `optimizer.step()`
"""
optimizer.step(closure=optimizer_closure)