.. role:: hidden
    :class: hidden-section

.. testsetup:: *

    import os

    from lightning.pytorch import Trainer, LightningModule, seed_everything

.. _trainer:

Trainer
=======

Once you've organized your PyTorch code into a :class:`~lightning.pytorch.core.LightningModule`, the ``Trainer`` automates everything else.

The ``Trainer`` achieves the following:

1. You maintain control over all aspects via PyTorch code in your :class:`~lightning.pytorch.core.LightningModule`.

2. The trainer uses best practices embedded by contributors and users
   from top AI labs such as Facebook AI Research, NYU, MIT, Stanford, etc.

3. The trainer allows disabling any key part that you don't want automated.

|

-----------

Basic use
---------

This is the basic use of the trainer:

.. code-block:: python

    model = MyLightningModule()
    trainer = Trainer()
    trainer.fit(model, train_dataloader, val_dataloader)

--------

Under the hood
--------------

The Lightning ``Trainer`` does much more than just "training". Under the hood, it handles all loop details for you. Some examples include:

- Automatically enabling/disabling grads
- Running the training, validation and test dataloaders
- Calling the Callbacks at the appropriate times
- Putting batches and computations on the correct devices

Here's the pseudocode for what the trainer does under the hood (showing the train loop only):

.. code-block:: python

    # enable grads
    torch.set_grad_enabled(True)

    losses = []
    for batch in train_dataloader:
        # calls hooks like this one
        on_train_batch_start()

        # train step
        loss = training_step(batch)

        # clear gradients
        optimizer.zero_grad()

        # backward
        loss.backward()

        # update parameters
        optimizer.step()

        losses.append(loss)

--------

Trainer in Python scripts
-------------------------

In Python scripts, it's recommended you use a main function to call the Trainer.

.. code-block:: python

    from argparse import ArgumentParser


    def main(hparams):
        model = LightningModule()
        trainer = Trainer(accelerator=hparams.accelerator, devices=hparams.devices)
        trainer.fit(model)


    if __name__ == "__main__":
        parser = ArgumentParser()
        parser.add_argument("--accelerator", default=None)
        parser.add_argument("--devices", default=None)
        args = parser.parse_args()

        main(args)

So you can run it like so:

.. code-block:: bash

    python main.py --accelerator 'gpu' --devices 2

.. note::

    Pro-tip: You don't need to define all flags manually.
    You can let the :doc:`LightningCLI <../cli/lightning_cli>` create the Trainer and model with arguments supplied from the CLI.
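
    A minimal sketch of that route (the ``MyModel`` LightningModule referenced here is only an illustration):

    .. code-block:: python

        from lightning.pytorch.cli import LightningCLI

        # builds the Trainer and the model from command-line arguments, e.g.
        #   python main.py fit --trainer.max_epochs=10
        def cli_main():
            LightningCLI(MyModel)


        if __name__ == "__main__":
            cli_main()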

If you want to stop a training run early, you can press "Ctrl + C" on your keyboard.
The trainer will catch the ``KeyboardInterrupt`` and attempt a graceful shutdown. The trainer object will also set
an attribute ``interrupted`` to ``True`` in such cases. If you have a callback which shuts down compute
resources, for example, you can conditionally run the shutdown logic for only uninterrupted runs by overriding :meth:`lightning.pytorch.Callback.on_exception`.
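
For example, a minimal sketch of such a callback (the ``shutdown_compute_resources`` helper is hypothetical):

.. code-block:: python

    from lightning.pytorch.callbacks import Callback


    class CleanupCallback(Callback):
        def on_exception(self, trainer, pl_module, exception):
            # only release compute resources when the run was not interrupted manually
            if not trainer.interrupted:
                shutdown_compute_resources()  # hypothetical helper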

------------

Validation
----------

You can perform an evaluation epoch over the validation set, outside of the training loop,
using :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`. This might be
useful if you want to collect new metrics from a model right at its initialization
or after it has already been trained.

.. code-block:: python

    trainer.validate(model=model, dataloaders=val_dataloaders)

------------

Testing
-------

Once you're done training, feel free to run the test set!
(Only right before publishing your paper or pushing to production)

.. code-block:: python

    trainer.test(dataloaders=test_dataloaders)

------------

Reproducibility
---------------

To ensure full reproducibility from run to run you need to set seeds for pseudo-random generators,
and set the ``deterministic`` flag in the ``Trainer``.

Example::

    from lightning.pytorch import Trainer, seed_everything

    # sets seeds for numpy, torch and python.random
    seed_everything(42, workers=True)

    model = Model()
    trainer = Trainer(deterministic=True)

By setting ``workers=True`` in :func:`~lightning.pytorch.seed_everything`, Lightning derives
unique seeds across all dataloader workers and processes for :mod:`torch`, :mod:`numpy` and stdlib
:mod:`random` number generators. When turned on, it ensures that e.g. data augmentations are not repeated across workers.

-------

.. _trainer_flags:

Trainer flags
-------------

accelerator
^^^^^^^^^^^

Supports passing different accelerator types (``"cpu", "gpu", "tpu", "hpu", "auto"``)
as well as custom accelerator instances.

.. code-block:: python

    # CPU accelerator
    trainer = Trainer(accelerator="cpu")

    # Training with GPU Accelerator using 2 GPUs
    trainer = Trainer(devices=2, accelerator="gpu")

    # Training with TPU Accelerator using 8 TPU cores
    trainer = Trainer(devices=8, accelerator="tpu")

    # Training with GPU Accelerator using the DistributedDataParallel strategy
    trainer = Trainer(devices=4, accelerator="gpu", strategy="ddp")

.. note:: The ``"auto"`` option recognizes the machine you are on, and selects the appropriate ``Accelerator``.

.. code-block:: python

    # If your machine has GPUs, it will use the GPU Accelerator for training
    trainer = Trainer(devices=2, accelerator="auto")

You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs.

Example::

    class MyOwnAcc(CPUAccelerator):
        ...


    Trainer(accelerator=MyOwnAcc())

.. note::
    If the ``devices`` flag is not defined, it will assume ``devices`` to be ``"auto"`` and fetch the ``auto_device_count``
    from the accelerator.

    .. code-block:: python

        # This is part of the built-in `CUDAAccelerator`
        class CUDAAccelerator(Accelerator):
            """Accelerator for GPU devices."""

            @staticmethod
            def auto_device_count() -> int:
                """Get the devices when set to auto."""
                return torch.cuda.device_count()


        # Training with GPU Accelerator using total number of gpus available on the system
        Trainer(accelerator="gpu")

accumulate_grad_batches
^^^^^^^^^^^^^^^^^^^^^^^

Accumulates gradients over k batches before stepping the optimizer.

.. testcode::

    # default used by the Trainer (no accumulation)
    trainer = Trainer(accumulate_grad_batches=1)

Example::

    # accumulate every 4 batches (effective batch size is batch*4)
    trainer = Trainer(accumulate_grad_batches=4)

See also: :ref:`gradient_accumulation` to enable more fine-grained accumulation schedules.
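
For instance, a sketch of a per-epoch schedule using the ``GradientAccumulationScheduler`` callback (the schedule values below are illustrative):

.. code-block:: python

    from lightning.pytorch.callbacks import GradientAccumulationScheduler

    # accumulate 8 batches starting at epoch 0, 4 starting at epoch 4, and stop accumulating at epoch 8
    scheduler = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
    trainer = Trainer(callbacks=[scheduler])
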

benchmark
^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/benchmark.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/benchmark.jpg
    :width: 400
    :muted:
2022-05-31 19:22:19 +00:00
The value (`` True `` or `` False `` ) to set `` torch.backends.cudnn.benchmark `` to. The value for
`` torch.backends.cudnn.benchmark `` set in the current session will be used (`` False `` if not manually set).
2023-09-20 17:09:34 +00:00
If :paramref: `~lightning.pytorch.trainer.trainer.Trainer.deterministic` is set to `` True `` , this will default to `` False `` .
2022-05-31 19:22:19 +00:00
You can read more about the interaction of `` torch.backends.cudnn.benchmark `` and `` torch.backends.cudnn.deterministic ``
2022-02-24 19:06:03 +00:00
`here <https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking> `__

Setting this flag to ``True`` can increase the speed of your system if your input sizes don't
change. However, if they do, then it might make your system slower. The cuDNN auto-tuner will try to find the best
algorithm for the hardware when a new input size is encountered. This might also increase the memory usage.
Read more about it `here <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`__.

Example::

    # defaults to the current value of torch.backends.cudnn.benchmark (normally False)
    trainer = Trainer(benchmark=None)  # default

    # you can overwrite the value
    trainer = Trainer(benchmark=True)

deterministic
^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/deterministic.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/deterministic.jpg
    :width: 400
    :muted:

This flag sets the ``torch.backends.cudnn.deterministic`` flag.
Might make your system slower, but ensures reproducibility.

For more info check the `PyTorch docs <https://pytorch.org/docs/stable/notes/randomness.html>`_.

Example::

    # default used by the Trainer
    trainer = Trainer(deterministic=False)

callbacks
^^^^^^^^^

This argument can be used to add a :class:`~lightning.pytorch.callbacks.callback.Callback` or a list of them.
Callbacks run sequentially in the order defined here
with the exception of :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks which run
after all others to ensure all states are saved to the checkpoints.

.. code-block:: python

    # single callback
    trainer = Trainer(callbacks=PrintCallback())

    # a list of callbacks
    trainer = Trainer(callbacks=[PrintCallback()])

Example::

    from lightning.pytorch.callbacks import Callback


    class PrintCallback(Callback):
        def on_train_start(self, trainer, pl_module):
            print("Training is started!")

        def on_train_end(self, trainer, pl_module):
            print("Training is done.")


Model-specific callbacks can also be added inside the ``LightningModule`` through
:meth:`~lightning.pytorch.core.LightningModule.configure_callbacks`.
Callbacks returned in this hook will extend the list initially given to the ``Trainer`` argument, and replace
the trainer callbacks should there be two or more of the same type.
:class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks always run last.
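
A minimal sketch of this hook (the ``EarlyStopping`` configuration shown is only an illustration):

.. code-block:: python

    from lightning.pytorch.callbacks import EarlyStopping


    class LitModel(LightningModule):
        def configure_callbacks(self):
            # added on top of the callbacks passed to the Trainer, replacing any of the same type
            return [EarlyStopping(monitor="val_loss", mode="min")]
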
check_val_every_n_epoch
^^^^^^^^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/check_val_every_n_epoch.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/check_val_every_n_epoch.jpg
    :width: 400
    :muted:

Check val every n train epochs.

Example::

    # default used by the Trainer
    trainer = Trainer(check_val_every_n_epoch=1)

    # run val loop every 10 training epochs
    trainer = Trainer(check_val_every_n_epoch=10)

default_root_dir
^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/default_root_dir.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/default%E2%80%A8_root_dir.jpg
    :width: 400
    :muted:

Default path for logs and weights when no logger or
:class:`lightning.pytorch.callbacks.ModelCheckpoint` callback is passed. On
certain clusters you might want to separate where logs and checkpoints are
stored. If you don't, use this argument for convenience. Paths can be local
paths or remote paths such as ``s3://bucket/path`` or ``hdfs://path/``. Credentials
will need to be set up to use remote filepaths.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(default_root_dir=os.getcwd())

devices
^^^^^^^

Number of devices to train on (``int``), which devices to train on (``list`` or ``str``), or ``"auto"``.

.. code-block:: python

    # Training with CPU Accelerator using 2 processes
    trainer = Trainer(devices=2, accelerator="cpu")

    # Training with GPU Accelerator using GPUs 1 and 3
    trainer = Trainer(devices=[1, 3], accelerator="gpu")

    # Training with TPU Accelerator using 8 TPU cores
    trainer = Trainer(devices=8, accelerator="tpu")

.. tip:: The ``"auto"`` option recognizes the devices to train on, depending on the ``Accelerator`` being used.

.. code-block:: python

    # Use whatever hardware your machine has available
    trainer = Trainer(devices="auto", accelerator="auto")

    # Training with CPU Accelerator using 1 process
    trainer = Trainer(devices="auto", accelerator="cpu")

    # Training with TPU Accelerator using 8 TPU cores
    trainer = Trainer(devices="auto", accelerator="tpu")

.. note::
    If the ``devices`` flag is not defined, it will assume ``devices`` to be ``"auto"`` and fetch the ``auto_device_count``
    from the accelerator.

    .. code-block:: python

        # This is part of the built-in `CUDAAccelerator`
        class CUDAAccelerator(Accelerator):
            """Accelerator for GPU devices."""

            @staticmethod
            def auto_device_count() -> int:
                """Get the devices when set to auto."""
                return torch.cuda.device_count()


        # Training with GPU Accelerator using total number of gpus available on the system
        Trainer(accelerator="gpu")

enable_checkpointing
^^^^^^^^^^^^^^^^^^^^

By default Lightning saves a checkpoint for you in your current working directory, with the state of your last training epoch.
Checkpoints capture the exact value of all parameters used by a model.
To disable automatic checkpointing, set this to ``False``.

.. code-block:: python

    # default used by Trainer, saves the most recent model to a single checkpoint after each epoch
    trainer = Trainer(enable_checkpointing=True)

    # turn off automatic checkpointing
    trainer = Trainer(enable_checkpointing=False)

You can override the default behavior by initializing the :class:`~lightning.pytorch.callbacks.ModelCheckpoint`
callback, and adding it to the :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks` list.
See :doc:`Saving and Loading Checkpoints <../common/checkpointing>` for how to customize checkpointing.

.. testcode::

    from lightning.pytorch.callbacks import ModelCheckpoint

    # Init ModelCheckpoint callback, monitoring 'val_loss'
    checkpoint_callback = ModelCheckpoint(monitor="val_loss")

    # Add your callback to the callbacks list
    trainer = Trainer(callbacks=[checkpoint_callback])

fast_dev_run
^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/fast_dev_run.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/fast_dev_run.jpg
    :width: 400
    :muted:

Runs ``n`` batch(es) if set to ``n`` (int), or 1 batch if set to ``True``, to ensure your code will execute without errors. This
applies to fitting, validating, testing, and predicting. This flag is **only** recommended for debugging purposes and
should not be used to limit the number of batches to run.

.. code-block:: python

    # default used by the Trainer
    trainer = Trainer(fast_dev_run=False)

    # runs only 1 training and 1 validation batch and the program ends
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(...)

    # runs 7 predict batches and program ends
    trainer = Trainer(fast_dev_run=7)
    trainer.predict(...)

This argument is different from ``limit_{train,val,test,predict}_batches`` because side effects are avoided to reduce the
impact to subsequent runs. These are the changes enabled:

- Sets ``Trainer(max_epochs=1)``.
- Sets ``Trainer(max_steps=...)`` to 1 or the number passed.
- Sets ``Trainer(num_sanity_val_steps=0)``.
- Sets ``Trainer(val_check_interval=1.0)``.
- Sets ``Trainer(check_val_every_n_epoch=1)``.
- Disables all loggers.
- Disables passing logged metrics to loggers.
- The :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` callbacks will not trigger.
- The :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callbacks will not trigger.
- Sets ``limit_{train,val,test,predict}_batches`` to 1 or the number passed.
- Disables the tuning callbacks (:class:`~lightning.pytorch.callbacks.batch_size_finder.BatchSizeFinder`, :class:`~lightning.pytorch.callbacks.lr_finder.LearningRateFinder`).
- If using the CLI, the configuration file is not saved.


gradient_clip_val
^^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/gradient_clip_val.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/gradient+_clip_val.jpg
    :width: 400
    :muted:

Gradient clipping value.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(gradient_clip_val=None)


limit_train_batches
^^^^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/limit_batches.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/limit_train_batches.jpg
    :width: 400
    :muted:

How much of training dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(limit_train_batches=1.0)

Example::

    # default used by the Trainer
    trainer = Trainer(limit_train_batches=1.0)

    # run through only 25% of the training set each epoch
    trainer = Trainer(limit_train_batches=0.25)

    # run through only 10 batches of the training set each epoch
    trainer = Trainer(limit_train_batches=10)

limit_test_batches
^^^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/limit_batches.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/limit_test_batches.jpg
    :width: 400
    :muted:

How much of test dataset to check.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(limit_test_batches=1.0)

    # run through only 25% of the test set each epoch
    trainer = Trainer(limit_test_batches=0.25)

    # run for only 10 batches
    trainer = Trainer(limit_test_batches=10)

In the case of multiple test dataloaders, the limit applies to each dataloader individually.

limit_val_batches
^^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/limit_batches.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/limit_val_batches.jpg
    :width: 400
    :muted:

How much of validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(limit_val_batches=1.0)

    # run through only 25% of the validation set each epoch
    trainer = Trainer(limit_val_batches=0.25)

    # run for only 10 batches
    trainer = Trainer(limit_val_batches=10)

    # disable validation
    trainer = Trainer(limit_val_batches=0)

In the case of multiple validation dataloaders, the limit applies to each dataloader individually.

log_every_n_steps
^^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/log_every_n_steps.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/log_every_n_steps.jpg
    :width: 400
    :muted:

How often to add logging rows (does not write to disk).

.. testcode::

    # default used by the Trainer
    trainer = Trainer(log_every_n_steps=50)

See Also:
    - :doc:`logging <../extensions/logging>`


logger
^^^^^^

:doc:`Logger <../visualize/loggers>` (or iterable collection of loggers) for experiment tracking. A ``True`` value uses the default ``TensorBoardLogger`` shown below. ``False`` will disable logging.

.. testcode::
    :skipif: not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE

    from lightning.pytorch.loggers import TensorBoardLogger

    # default logger used by trainer (if tensorboard is installed)
    logger = TensorBoardLogger(save_dir=os.getcwd(), version=1, name="lightning_logs")
    Trainer(logger=logger)

max_epochs
^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/min_max_epochs.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/max_epochs.jpg
    :width: 400
    :muted:

Stop training once this number of epochs is reached.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(max_epochs=1000)

If both ``max_epochs`` and ``max_steps`` aren't specified, ``max_epochs`` will default to ``1000``.
To enable infinite training, set ``max_epochs = -1``.

min_epochs
^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/min_max_epochs.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/min_epochs.jpg
    :width: 400
    :muted:

Force training for at least this many epochs.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(min_epochs=1)

max_steps
^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/min_max_steps.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/max_steps.jpg
    :width: 400
    :muted:

Stop training after this number of :ref:`global steps <common/trainer:global_step>`.
Training will stop if either max_steps or max_epochs is reached (whichever comes first).

.. testcode::

    # Default (disabled)
    trainer = Trainer(max_steps=-1)

    # Stop after 100 steps
    trainer = Trainer(max_steps=100)

If ``max_steps`` is not specified, ``max_epochs`` will be used instead (and ``max_epochs`` defaults to
``1000`` if ``max_epochs`` is not specified). To disable this default, set ``max_steps = -1``.

min_steps
^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/min_max_steps.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/min_steps.jpg
    :width: 400
    :muted:

Force training for at least this number of :ref:`global steps <common/trainer:global_step>`.
Trainer will train the model for at least min_steps or min_epochs (whichever comes last).

.. testcode::

    # Default (disabled)
    trainer = Trainer(min_steps=None)

    # Run at least for 100 steps (disable min_epochs)
    trainer = Trainer(min_steps=100, min_epochs=0)

max_time
^^^^^^^^

Set the maximum amount of time for training. Training will get interrupted mid-epoch.
For customizable options use the :class:`~lightning.pytorch.callbacks.timer.Timer` callback.

.. testcode::

    # Default (disabled)
    trainer = Trainer(max_time=None)

    # Stop after 12 hours of training or when reaching 10 epochs (string)
    trainer = Trainer(max_time="00:12:00:00", max_epochs=10)

    # Stop after 1 day and 5 hours (dict)
    trainer = Trainer(max_time={"days": 1, "hours": 5})

In case ``max_time`` is used together with ``min_steps`` or ``min_epochs``, the ``min_*`` requirement
always has precedence.
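
As a sketch of the more customizable route, the :class:`~lightning.pytorch.callbacks.timer.Timer` callback accepts a duration and the interval at which to check it (the values below are illustrative):

.. code-block:: python

    from lightning.pytorch.callbacks import Timer

    # stop after 12 hours of training, checking the elapsed time at every step
    timer = Timer(duration="00:12:00:00", interval="step")
    trainer = Trainer(callbacks=[timer])
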

num_nodes
^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/num_nodes.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/num_nodes.jpg
    :width: 400
    :muted:

Number of GPU nodes for distributed training.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(num_nodes=1)

    # to train on 8 nodes
    trainer = Trainer(num_nodes=8)

num_sanity_val_steps
^^^^^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/num_sanity_val_steps.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/num_sanity%E2%80%A8_val_steps.jp
    :width: 400
    :muted:

Sanity check runs n batches of val before starting the training routine.
This catches any bugs in your validation without having to wait for the first validation check.
The Trainer uses 2 steps by default. Turn it off or modify it here.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(num_sanity_val_steps=2)

    # turn it off
    trainer = Trainer(num_sanity_val_steps=0)

    # check all validation data
    trainer = Trainer(num_sanity_val_steps=-1)

This option will reset the validation dataloader unless ``num_sanity_val_steps=0``.

overfit_batches
^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/overfit_batches.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/overfit_batches.jpg
    :width: 400
    :muted:

Uses this much data of the training & validation set.
If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it.
Useful for quickly debugging or trying to overfit on purpose.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(overfit_batches=0.0)

    # use only 1% of the train & val set
    trainer = Trainer(overfit_batches=0.01)

    # overfit on 10 of the same batches
    trainer = Trainer(overfit_batches=10)

plugins
^^^^^^^

:ref:`Plugins` allow you to connect arbitrary backends, precision libraries, clusters etc. For example:

- :ref:`Checkpoint IO <checkpointing_expert>`
- `TorchElastic <https://pytorch.org/elastic/0.2.2/index.html>`_
- :ref:`Precision Plugins <precision_expert>`

To define your own behavior, subclass the relevant class and pass it in. Here's an example linking up your own
:class:`~lightning.pytorch.plugins.environments.ClusterEnvironment`.

.. code-block:: python

    from lightning.pytorch.plugins.environments import ClusterEnvironment


    class MyCluster(ClusterEnvironment):
        def main_address(self):
            return your_main_address

        def main_port(self):
            return your_main_port

        def world_size(self):
            return the_world_size


    trainer = Trainer(plugins=[MyCluster()], ...)

precision
^^^^^^^^^

There are two different techniques to set the precision: "true" precision and "mixed" precision.

Lightning supports doing floating point operations in 64-bit precision ("double"), 32-bit precision ("full"), or 16-bit precision ("half"), the latter with both regular float16 and `bfloat16 <https://pytorch.org/docs/1.10.0/generated/torch.Tensor.bfloat16.html>`_.
The selected precision has a direct impact on performance and memory usage, depending on your hardware.
Automatic mixed precision settings are denoted by a ``"-mixed"`` suffix, while "true" precision settings have a ``"-true"`` suffix:

.. code-block:: python

    # Default used by the Trainer
    trainer = Trainer(precision="32-true", devices=1)

    # the same as:
    trainer = Trainer(precision="32", devices=1)

    # 16-bit mixed precision (model weights remain in torch.float32)
    trainer = Trainer(precision="16-mixed", devices=1)

    # 16-bit bfloat mixed precision (model weights remain in torch.float32)
    trainer = Trainer(precision="bf16-mixed", devices=1)

    # 8-bit mixed precision via TransformerEngine (model weights get cast to torch.bfloat16)
    trainer = Trainer(precision="transformer-engine", devices=1)

    # 16-bit precision (model weights get cast to torch.float16)
    trainer = Trainer(precision="16-true", devices=1)

    # 16-bit bfloat precision (model weights get cast to torch.bfloat16)
    trainer = Trainer(precision="bf16-true", devices=1)

    # 64-bit (double) precision (model weights get cast to torch.float64)
    trainer = Trainer(precision="64-true", devices=1)

See the :doc:`N-bit precision guide <../common/precision>` for more details.


profiler
^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/profiler.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/profiler.jpg
    :width: 400
    :muted:

Use this to profile individual steps during training and assist in identifying bottlenecks.
See the :doc:`profiler documentation <../tuning/profiler>` for more details.

.. testcode::

    from lightning.pytorch.profilers import SimpleProfiler, AdvancedProfiler

    # default used by the Trainer
    trainer = Trainer(profiler=None)

    # to profile standard training events, equivalent to `profiler=SimpleProfiler()`
    trainer = Trainer(profiler="simple")

    # advanced profiler for function-level stats, equivalent to `profiler=AdvancedProfiler()`
    trainer = Trainer(profiler="advanced")

enable_progress_bar
^^^^^^^^^^^^^^^^^^^

Whether to enable or disable the progress bar. Defaults to True.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(enable_progress_bar=True)

    # disable progress bar
    trainer = Trainer(enable_progress_bar=False)

reload_dataloaders_every_n_epochs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/reload_dataloaders_every_epoch.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/reload_%E2%80%A8dataloaders_%E2%80%A8every_epoch.jpg
    :width: 400
    :muted:

Set to a positive integer to reload dataloaders every n epochs from your currently used data source.
The data source can be a ``LightningModule`` or a ``LightningDataModule``.

.. code-block:: python

    # if 0 (default)
    train_loader = model.train_dataloader()
    # or if using data module: datamodule.train_dataloader()
    for epoch in epochs:
        for batch in train_loader:
            ...

    # if a positive integer
    for epoch in epochs:
        if not epoch % reload_dataloaders_every_n_epochs:
            train_loader = model.train_dataloader()
            # or if using data module: datamodule.train_dataloader()
        for batch in train_loader:
            ...

The pseudocode applies also to the ``val_dataloader``.

.. _replace-sampler-ddp:

use_distributed_sampler
^^^^^^^^^^^^^^^^^^^^^^^

See :paramref:`lightning.pytorch.trainer.Trainer.params.use_distributed_sampler`.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(use_distributed_sampler=True)

By setting it to ``False``, you have to add your own distributed sampler:

.. code-block:: python

    # in your LightningModule or LightningDataModule
    def train_dataloader(self):
        dataset = ...
        # default used by the Trainer
        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
        dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
        return dataloader

strategy
^^^^^^^^

Supports passing different training strategies with aliases (ddp, fsdp, etc.) as well as configured strategies.

.. code-block:: python

    # Data-parallel training with the DDP strategy on 4 GPUs
    trainer = Trainer(strategy="ddp", accelerator="gpu", devices=4)

    # Model-parallel training with the FSDP strategy on 4 GPUs
    trainer = Trainer(strategy="fsdp", accelerator="gpu", devices=4)

Additionally, you can pass a strategy object.

.. code-block:: python

    from lightning.pytorch.strategies import DDPStrategy

    trainer = Trainer(strategy=DDPStrategy(static_graph=True), accelerator="gpu", devices=2)

See Also:
    - :ref:`Multi GPU Training <multi_gpu>`.
    - :doc:`Model Parallel GPU training guide <../advanced/model_parallel>`.
    - :doc:`TPU training guide <../accelerators/tpu>`.

sync_batchnorm
^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/sync_batchnorm.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/sync_batchnorm.jpg
    :width: 400
    :muted:

Enable synchronization between batchnorm layers across all GPUs.

.. testcode::

    trainer = Trainer(sync_batchnorm=True)

val_check_interval
^^^^^^^^^^^^^^^^^^

.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/val_check_interval.mp4
    :poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/val_check_interval.jpg
    :width: 400
    :muted:

How often within one training epoch to check the validation set.
Can specify as float or int.

- pass a ``float`` in the range [0.0, 1.0] to check after a fraction of the training epoch.
- pass an ``int`` to check after a fixed number of training batches. An ``int`` value can only be higher than the number of training
  batches when ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches across epochs or iteration-based training.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(val_check_interval=1.0)

    # check validation set 4 times during a training epoch
    trainer = Trainer(val_check_interval=0.25)

    # check validation set every 1000 training batches in the current epoch
    trainer = Trainer(val_check_interval=1000)

    # check validation set every 1000 training batches across complete epochs or during iteration-based training
    # use this when using iterableDataset and your dataset has no length
    # (ie: production cases with streaming data)
    trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)

.. code-block:: python

    # Here is the computation to estimate the total number of batches seen within an epoch.

    # Find the total number of train batches
    total_train_batches = total_train_samples // (train_batch_size * world_size)

    # Compute how many times we will call validation during the training loop
    val_check_batch = max(1, int(total_train_batches * val_check_interval))
    val_checks_per_epoch = total_train_batches / val_check_batch

    # Find the total number of validation batches
    total_val_batches = total_val_samples // (val_batch_size * world_size)

    # Total number of batches run
    total_fit_batches = total_train_batches + total_val_batches

enable_model_summary
^^^^^^^^^^^^^^^^^^^^

Whether to enable or disable the model summarization. Defaults to True.

.. testcode::

    # default used by the Trainer
    trainer = Trainer(enable_model_summary=True)

    # disable summarization
    trainer = Trainer(enable_model_summary=False)

    # enable custom summarization
    from lightning.pytorch.callbacks import ModelSummary

    trainer = Trainer(enable_model_summary=True, callbacks=[ModelSummary(max_depth=-1)])

inference_mode
^^^^^^^^^^^^^^

Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` mode during evaluation
(``validate``/``test``/``predict``).

.. testcode::

    # default used by the Trainer
    trainer = Trainer(inference_mode=True)

    # Use `torch.no_grad` instead
    trainer = Trainer(inference_mode=False)

With :func:`torch.inference_mode` disabled, you can enable the grad of your model layers if required.

.. code-block:: python

    class LitModel(LightningModule):
        def validation_step(self, batch, batch_idx):
            preds = self.layer1(batch)
            with torch.enable_grad():
                grad_preds = preds.requires_grad_()
                preds2 = self.layer2(grad_preds)


    model = LitModel()
    trainer = Trainer(inference_mode=False)
    trainer.validate(model)

-----

Trainer class API
-----------------

Methods
^^^^^^^

init
****

.. automethod:: lightning.pytorch.trainer.Trainer.__init__
    :noindex:

fit
***

.. automethod:: lightning.pytorch.trainer.Trainer.fit
    :noindex:

validate
********

.. automethod:: lightning.pytorch.trainer.Trainer.validate
    :noindex:

test
****

.. automethod:: lightning.pytorch.trainer.Trainer.test
    :noindex:

predict
*******

.. automethod:: lightning.pytorch.trainer.Trainer.predict
    :noindex:


Properties
^^^^^^^^^^

callback_metrics
****************

The metrics available to callbacks.

This includes metrics logged via :meth:`~lightning.pytorch.core.LightningModule.log`.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        self.log("a_val", 2.0)


    callback_metrics = trainer.callback_metrics
    assert callback_metrics["a_val"] == 2.0

logged_metrics
**************

The metrics sent to the loggers.

This includes metrics logged via :meth:`~lightning.pytorch.core.LightningModule.log` with the
:paramref:`~lightning.pytorch.core.LightningModule.log.logger` argument set.

progress_bar_metrics
********************

The metrics sent to the progress bar.

This includes metrics logged via :meth:`~lightning.pytorch.core.LightningModule.log` with the
:paramref:`~lightning.pytorch.core.LightningModule.log.prog_bar` argument set.

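As a short sketch of how these two buckets get populated (the metric names are illustrative):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = ...
        # appears in trainer.logged_metrics
        self.log("train_loss", loss, logger=True)
        # appears in trainer.progress_bar_metrics
        self.log("train_acc", 0.9, prog_bar=True)
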
current_epoch
*************

The current epoch, updated after the epoch end hooks are run.

datamodule
**********

The current datamodule, which is used by the trainer.

.. code-block:: python

    used_datamodule = trainer.datamodule

is_last_batch
*************

Whether trainer is executing the last batch.

global_step
***********

The number of optimizer steps taken (does not reset each epoch).
This includes multiple optimizers (if enabled).

logger
******

The first :class:`~lightning.pytorch.loggers.logger.Logger` being used.

loggers
*******

The list of :class:`~lightning.pytorch.loggers.logger.Logger` used.

.. code-block:: python

    for logger in trainer.loggers:
        logger.log_metrics({"foo": 1.0})

log_dir
*******

The directory for the current experiment. Use this to save images to, etc.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        img = ...
        save_img(img, self.trainer.log_dir)

is_global_zero
**************

Whether this process is the global zero in multi-node training.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        if self.trainer.is_global_zero:
            print("in node 0, accelerator 0")

estimated_stepping_batches
**************************

The estimated number of batches that will ``optimizer.step()`` during training.

This accounts for gradient accumulation and the current trainer configuration. This might set up your training
dataloader if it hadn't been set up already.

.. code-block:: python

    def configure_optimizers(self):
        optimizer = ...
        stepping_batches = self.trainer.estimated_stepping_batches
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=stepping_batches)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

state
*****

The current state of the Trainer, including the current function that is running, the stage of
execution within that function, and the status of the Trainer.

.. code-block:: python

    # fn in ("fit", "validate", "test", "predict")
    trainer.state.fn

    # status in ("initializing", "running", "finished", "interrupted")
    trainer.state.status

    # stage in ("train", "sanity_check", "validate", "test", "predict")
    trainer.state.stage

should_stop
***********

If you want to terminate the training during ``.fit``, you can set ``trainer.should_stop=True`` to terminate the training
as soon as possible. Note that it will respect the arguments ``min_steps`` and ``min_epochs`` to check whether to stop. If these
arguments are set and the ``current_epoch`` or ``global_step`` don't meet these minimum conditions, training will continue until
both conditions are met. If any of these arguments is not set, it won't be considered for the final decision.

.. code-block:: python

    # setting `trainer.should_stop` at any point of training will terminate it
    class LitModel(LightningModule):
        def training_step(self, *args, **kwargs):
            self.trainer.should_stop = True


    trainer = Trainer()
    model = LitModel()
    trainer.fit(model)

.. code-block:: python

    # setting `trainer.should_stop` will stop training only after at least 5 epochs have run
    class LitModel(LightningModule):
        def training_step(self, *args, **kwargs):
            if self.current_epoch == 2:
                self.trainer.should_stop = True


    trainer = Trainer(min_epochs=5, max_epochs=100)
    model = LitModel()
    trainer.fit(model)

.. code-block:: python

    # setting `trainer.should_stop` will stop training only after at least 5 steps have run
    class LitModel(LightningModule):
        def training_step(self, *args, **kwargs):
            if self.global_step == 2:
                self.trainer.should_stop = True


    trainer = Trainer(min_steps=5, max_epochs=100)
    model = LitModel()
    trainer.fit(model)

.. code-block:: python

    # setting `trainer.should_stop` will stop training only once both min_steps and min_epochs are satisfied
    class LitModel(LightningModule):
        def training_step(self, *args, **kwargs):
            if self.global_step == 7:
                self.trainer.should_stop = True


    trainer = Trainer(min_steps=5, min_epochs=5, max_epochs=100)
    model = LitModel()
    trainer.fit(model)

sanity_checking
***************

Indicates if the trainer is currently running sanity checking. This property can be useful to disable some hooks,
logging or callbacks during the sanity checking.

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        ...

        if not self.trainer.sanity_checking:
            self.log("value", value)

num_training_batches
********************

The number of training batches that will be used during ``trainer.fit()``.

num_sanity_val_batches
**********************

The number of validation batches that will be used during the sanity-checking part of ``trainer.fit()``.

num_val_batches
***************

The number of validation batches that will be used during ``trainer.fit()`` or ``trainer.validate()``.

num_test_batches
****************

The number of test batches that will be used during ``trainer.test()``.

num_predict_batches
*******************

The number of prediction batches that will be used during ``trainer.predict()``.

train_dataloader
****************

The training dataloader(s) used during ``trainer.fit()``.

val_dataloaders
***************

The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``.

test_dataloaders
****************

The test dataloader(s) used during ``trainer.test()``.

predict_dataloaders
*******************

The prediction dataloader(s) used during ``trainer.predict()``.
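
As a short sketch of inspecting a few of these properties (the ``model`` and its dataloaders are assumed to exist):

.. code-block:: python

    trainer = Trainer(limit_train_batches=10)
    trainer.fit(model)

    # 10, assuming the training dataset yields at least 10 batches
    print(trainer.num_training_batches)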