* training_end renamed to training_step_end

* training_end to training_step_end

* fix lost model reference
This commit is contained in:
William Falcon 2020-03-05 12:32:45 -05:00 committed by GitHub
parent 969e929a48
commit 29faea1862
12 changed files with 1419 additions and 664 deletions

(New binary image added; 1.0 MiB, not shown in the diff.)

(New binary image added; 968 KiB, not shown in the diff.)


@ -34,11 +34,11 @@ Log metrics
To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, etc...)
1. Training_end, validation_end, test_end will all log anything in the "log" key of the return dict.
1. training_epoch_end, validation_epoch_end, test_epoch_end will all log anything in the "log" key of the return dict.
.. code-block:: python
def training_end(self, outputs):
def training_epoch_end(self, outputs):
loss = some_loss()
...
@ -46,7 +46,7 @@ To plot metrics into whatever logger you passed in (tensorboard, comet, neptune,
results = {'log': logs}
return results
def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
loss = some_loss()
...
@ -54,7 +54,7 @@ To plot metrics into whatever logger you passed in (tensorboard, comet, neptune,
results = {'log': logs}
return results
def test_end(self, outputs):
def test_epoch_end(self, outputs):
loss = some_loss()
...
@ -62,19 +62,7 @@ To plot metrics into whatever logger you passed in (tensorboard, comet, neptune,
results = {'log': logs}
return results
2. Most of the time, you only need training_step and not training_end. You can also return logs from here:
.. code-block:: python
def training_step(self, batch, batch_idx):
loss = some_loss()
...
logs = {'train_loss': loss}
results = {'log': logs}
return results
3. In addition, you can also use any arbitrary functionality from a particular logger from within your LightningModule.
2. In addition, you can also use any arbitrary functionality from a particular logger from within your LightningModule.
For instance, here we log images using tensorboard.
.. code-block:: python
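# a minimal sketch, not the exact block from the docs (the diff truncates it):
# log a grid of example images with the default TensorBoardLogger, assuming
# torchvision is already imported; self.logger.experiment is the underlying
# TensorBoard SummaryWriter
def training_step(self, batch, batch_idx):
    x, y = batch
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)
    ...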


@ -26,7 +26,7 @@ Training loop
- on_batch_start
- tbptt_split_batch
- training_step
- training_end (optional)
- training_step_end (optional)
- backward
- on_after_backward
- optimizer.step()


@ -165,12 +165,13 @@ you will only be operating on one of those pieces.
y_0 = batch
For most metrics, this doesn't really matter. However, if you want
full batch statistics or want to use the outputs of the training_step
to do something like a softmax, you can use the `training_end` step.
to add something to your computational graph (like softmax)
using all batch parts you can use the `training_step_end` step.
.. code-block:: python
def training_end(self, outputs):
def training_step_end(self, outputs):
# only use when on dp
outputs = torch.cat(outputs, dim=1)
softmax = softmax(outputs, dim=1)
out = softmax.mean()
@ -195,9 +196,43 @@ In pseudocode, the full sequence is:
out = gpu_model(batch_split)
all_results.append(out)
# calculate statistics for all parts of the batch
full_out = model.training_end(all_results)
# use the full batch for something like softmax
full_out = model.training_step_end(all_results)
to illustrate why this is needed, let's look at dataparallel
.. code-block:: python
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.forward(batch)
# on dp or ddp2 if we did softmax now it would be wrong
# because batch is actually a piece of the full batch
return y_hat
def training_step_end(self, batch_parts_outputs):
# batch_parts_outputs has outputs of each part of the batch
# do softmax here
outputs = torch.cat(batch_parts_outputs, dim=1)
softmax = softmax(outputs, dim=1)
out = softmax.mean()
return out
If `training_step_end` is defined it will be called regardless of tpu, dp, ddp, etc... which means
it will behave the same no matter the backend.
Validation and test step also have the same option when using dp
.. code-block:: python
def validation_step_end(self, batch_parts_outputs):
...
def test_step_end(self, batch_parts_outputs):
...
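For example, mirroring the training case above (a sketch; ``nce_loss`` stands in for your own loss function):

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self.forward(x)
        # on dp or ddp2 this is only a piece of the full batch
        return {'out': out}

    def validation_step_end(self, batch_parts_outputs):
        # operate on the full batch here
        out = batch_parts_outputs['out']
        softmax = self.softmax(out)
        loss = nce_loss(softmax)
        return {'val_loss': loss}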
Implement Your Own Distributed (DDP) training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


@ -1,16 +1,73 @@
"""
A LightningModule is a strict superclass of torch.nn.Module but provides an interface to standardize
the "ingredients" for a research or production system.
A LightningModule organizes your PyTorch code into the following sections:
- The model/system definition (__init__)
- The model/system computations (forward)
- What happens in the training loop (training_step, training_step_end, training_epoch_end)
- What happens in the validation loop (validation_step, validation_step_end, validation_epoch_end)
- What happens in the test loop (test_step, test_step_end, test_epoch_end)
- What optimizers to use (configure_optimizers)
- What data to use (train_dataloader, val_dataloader, test_dataloader)
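As a rough sketch (the class name is just illustrative and the bodies are omitted), those sections map onto methods of a LightningModule subclass like this:

.. code-block:: python

    import pytorch_lightning as pl

    class LitSystem(pl.LightningModule):

        def __init__(self):
            # model/system definition
            super().__init__()

        def forward(self, x):
            # model/system computations
            ...

        def training_step(self, batch, batch_idx):
            # what happens in the training loop
            ...

        def validation_step(self, batch, batch_idx):
            # what happens in the validation loop
            ...

        def test_step(self, batch, batch_idx):
            # what happens in the test loop
            ...

        def configure_optimizers(self):
            # what optimizers to use
            ...

        def train_dataloader(self):
            # what data to use
            ...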
.. figure:: /_images/lightning_module/pt_to_pl.png
:alt: Convert from PyTorch to Lightning
Most methods are optional. Here's a minimal example.
Notice a few things.
1. It's the SAME code.
2. The PyTorch code IS NOT abstracted - just organized.
3. All the other code that didn't go in the LightningModule has been automated
for you by the trainer
.. code-block:: python
net = Net()
trainer = Trainer()
trainer.fit(net)
4. There are no .cuda() or .to() calls... Lightning does these for you.
.. code-block:: python
# don't do in lightning
x = torch.Tensor(2, 3)
x = x.cuda()
x = x.to(device)
# do this instead
x = x # leave it alone!
# or to init a new tensor
new_x = torch.Tensor(2, 3)
new_x = new_x.type_as(x)
5. There are no samplers for distributed, Lightning also does this for you.
.. code-block:: python
# Don't do in Lightning...
data = MNIST(...)
sampler = DistributedSampler(data)
DataLoader(data, sampler=sampler)
# do this instead
data = MNIST(...)
DataLoader(data)
6. A LightningModule is a torch.nn.Module but with added functionality. Use it as such!
.. code-block:: python
net = Net.load_from_checkpoint(PATH)
net.freeze()
out = net(x)
Thus, to use Lightning, you just need to organize your code, which takes about 30 minutes
(and let's be real, you probably should do that anyhow).
------------
Minimal Example
---------------
Here are the only required methods.
.. code-block:: python
@ -37,13 +94,13 @@ Most methods are optional. Here's a minimal example.
y_hat = self.forward(x)
return {'loss': F.cross_entropy(y_hat, y)}
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
def train_dataloader(self):
return DataLoader(MNIST(os.getcwd(), train=True, download=True,
transform=transforms.ToTensor()), batch_size=32)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
Which you can train by doing:
.. code-block:: python
@ -53,7 +110,35 @@ Which you can train by doing:
trainer.fit(model)
If you wanted to add a validation loop
----------
Training loop structure
-----------------------
The general pattern is that each loop (training, validation, test loop)
has 2 methods:
- ``` ___step ```
- ``` ___epoch_end```
To show how lightning calls these, let's use the validation loop as an example
.. code-block:: python
val_outs = []
for val_batch in val_data:
# do something with each batch
out = validation_step(val_batch)
val_outs.append(out)
# do something with the outputs for all batches
# like calculate validation set accuracy or loss
validation_epoch_end(val_outs)
Add validation loop
^^^^^^^^^^^^^^^^^^^
Thus, if you wanted to add a validation loop, you would add this to your LightningModule
.. code-block:: python
@ -63,43 +148,153 @@ If you wanted to add a validation loop
y_hat = self.forward(x)
return {'val_loss': F.cross_entropy(y_hat, y)}
def validation_end(self, outputs):
def validation_epoch_end(self, outputs):
val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
return {'val_loss': val_loss_mean}
def val_dataloader(self):
# can also return a list of val dataloaders
return DataLoader(MNIST(os.getcwd(), train=True, download=True,
transform=transforms.ToTensor()), batch_size=32)
return DataLoader(...)
Or add a test loop
Add test loop
^^^^^^^^^^^^^
.. code_block:: python
.. code-block:: python
class CoolModel(pl.LightningModule):
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self.forward(x)
return {'test_loss': F.cross_entropy(y_hat, y)}
def test_end(self, outputs):
def test_epoch_end(self, outputs):
test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
return {'test_loss': test_loss_mean}
def test_dataloader(self):
# OPTIONAL
# can also return a list of test dataloaders
return DataLoader(MNIST(os.getcwd(), train=False, download=True,
transform=transforms.ToTensor()), batch_size=32)
return DataLoader(...)
However, the test loop won't ever be called automatically to make sure you
don't run your test data by accident. Instead you have to explicitly call:
.. code-block:: python
# call after training
trainer = Trainer()
trainer.fit(model)
trainer.test()
# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model)
-------------
Training_step_end method
------------------------
When using DataParallel (dp) or DistributedDataParallel2 (ddp2), the training_step
will operate on a portion of the batch. This is normally ok but in special
cases like calculating NCE loss using negative samples, we might want to
perform a softmax across all samples in the batch.
For these types of situations, each loop has an additional ```__step_end``` method
which allows you to operate on the pieces of the batch
.. code-block:: python
training_outs = []
for train_batch in train_data:
# dp, ddp2 splits the batch
sub_batches = split_batches_for_dp(train_batch)
# run training_step on each piece of the batch
batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
# do softmax with all pieces
out = training_step_end(batch_parts_outputs)
training_outs.append(out)
# do something with the outputs for all batches
# like calculate training set accuracy or loss
training_epoch_end(training_outs)
-------------
Remove cuda calls
-----------------
In a LightningModule, all calls to ```.cuda()```
and ```.to(device)``` should be removed. Lightning will do these
automatically. This will allow your code to work on CPUs, TPUs and GPUs.
When you init a new tensor in your code, just use type_as
.. code-block:: python
def training_step(self, batch, batch_idx):
x, y = batch
# put the z on the appropriate gpu or tpu core
z = sample_noise()
z = z.type_as(x)
-------------
Data preparation
----------------
Data preparation in PyTorch follows 5 steps:
1. Download
2. Clean and (maybe) save to disk
3. Load inside dataset
4. Apply transforms (rotate, tokenize, etc...)
5. Wrap inside a dataloader
When working in distributed settings, steps 1 and 2 have to be done
from a single GPU, otherwise you will overwrite these files from
every GPU. The LightningModule has the ```prepare_data``` method to
allow for this.
.. code-block:: python
def prepare_data(self):
# do stuff that writes to disk or should be done once
# this will only happen from the master GPU or TPU core
.. note:: ```prepare_data``` is called once.
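For example, a minimal sketch reusing the MNIST setup from the examples above (the exact split is up to you):

.. code-block:: python

    def prepare_data(self):
        # download once; called only on the master GPU or TPU core
        MNIST(os.getcwd(), train=True, download=True)

    def train_dataloader(self):
        # no downloading here; every process builds its own DataLoader
        dataset = MNIST(os.getcwd(), train=True, download=False,
                        transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=32)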
Lifecycle
---------
The methods in the LightningModule are called in this order:
1. ```__init__```
2. ```prepare_data```
3. ```configure_optimizers```
4. ```train_dataloader```
If you define a validation loop then
5. ```val_dataloader```
And if you define a test loop:
6. ```test_dataloader```
.. note:: test_dataloader is only called with .test()
In every epoch, the loop methods are called in this frequency:
1. ```validation_step``` called every batch
2. ```validation_epoch_end``` called every epoch
Live demo
---------
Check out this
`COLAB <https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg>`_
for a live demo.
.. note:: Remove all .cuda() or .to() calls from LightningModules. See:
`the multi-gpu training guide for details <multi_gpu.rst>`_.
"""
from .decorators import data_loader


@ -222,26 +222,40 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
"""
def training_end(self, *args, **kwargs):
"""return loss, dict with metrics for tqdm
"""
.. warning:: Deprecated in v0.7.0. use training_step_end instead
"""
:param outputs: What you return in `training_step`.
def training_step_end(self, *args, **kwargs):
"""
Use this when training with dp or ddp2 because training_step will operate
on only part of the batch. However, this is still optional
and only needed for things like softmax or NCE loss.
.. note:: If you later switch to ddp or some other mode, this will still be called
so that you don't have to change your code
.. code-block:: python
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
training_step_end(batch_parts_outputs)
:param batch_parts_outputs: What you return in `training_step` for each batch part.
:return dict: dictionary with loss key and optional log, progress keys:
- loss -> tensor scalar [REQUIRED]
- progress_bar -> Dict for progress bar display. Must have only tensors
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
In certain cases (dp, ddp2), you might want to use all outputs of every process to do something.
For instance, if using negative samples, you could run a batch via dp and use ALL the outputs
for a single softmax across the full batch (ie: the denominator would use the full batch).
In this case you should define training_end to perform those calculations.
In this case you should define training_step_end to perform those calculations.
Example
-------
.. code-block:: python
# WITHOUT training_end
# WITHOUT training_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def training_step(self, batch, batch_idx):
# batch is 1/num_gpus big
@ -253,7 +267,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
return {'loss': loss}
# --------------
# with training_end to do softmax over the full batch
# with training_step_end to do softmax over the full batch
def training_step(self, batch, batch_idx):
# batch is 1/num_gpus big
x, y = batch
@ -261,48 +275,32 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
out = self.forward(x)
return {'out': out}
def training_end(self, outputs):
def training_step_end(self, outputs):
# this out is now the full size of the batch
out = outputs['out']
# this softmax now uses the full batch size
loss = self.softmax(out)
loss = nce_loss(loss)
return {'loss': loss}
.. note:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
If you define multiple optimizers, this step will also be called with an additional `optimizer_idx` param.
.. code-block:: python
# Multiple optimizers (ie: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
if optimizer_idx == 0:
# do training_step with encoder
if optimizer_idx == 1:
# do training_step with decoder
If you add truncated back propagation through time you will also get an additional argument
with the hidden states of the previous step.
.. code-block:: python
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
# hiddens are the hiddens from the previous truncated backprop step
You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to
break out of the current training epoch early.
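For example (a sketch; the stopping condition is hypothetical):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        ...
        if some_early_exit_condition:  # hypothetical condition
            return -1
        return {'loss': loss}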
"""
def validation_step(self, *args, **kwargs):
r"""
This is the validation loop. It is called for each batch of the validation set.
Whatever is returned from here will be passed in as a list on validation_end.
Operate on a single batch of data from the validation set
In this step you'd normally generate examples or calculate anything of interest such as accuracy.
.. code-block:: python
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
out = validation_step(val_batch)
val_outs.append(out)
validation_epoch_end(val_outs)
Args:
batch (torch.nn.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your dataloader.
A tensor, tuple or list
@ -311,7 +309,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
val datasets used)
Return:
Dict or OrderedDict - passed to the validation_end step
Dict or OrderedDict - passed to the validation_epoch_end
.. code-block:: python
@ -319,7 +317,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
def validation_step(self, batch, batch_idx)
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idxdx)
def validation_step(self, batch, batch_idx, dataloader_idx)
Example
-------
@ -368,12 +366,175 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
have been disabled. At the end of validation, model goes back to training mode and gradients are enabled.
"""
def validation_step_end(self, *args, **kwargs):
"""
Use this when validating with dp or ddp2 because validation_step will operate
on only part of the batch. However, this is still optional
and only needed for things like softmax or NCE loss.
.. note:: If you later switch to ddp or some other mode, this will still be called
so that you don't have to change your code
.. code-block:: python
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
validation_step_end(batch_parts_outputs)
:param batch_parts_outputs: What you return in `validation_step` for each batch part.
:return dict: dictionary with loss key and optional log, progress keys:
- loss -> tensor scalar [REQUIRED]
- progress_bar -> Dict for progress bar display. Must have only tensors
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
In this case you should define validation_step_end to perform those calculations.
Example
-------
.. code-block:: python
# WITHOUT validation_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def validation_step(self, batch, batch_idx):
# batch is 1/num_gpus big
x, y = batch
out = self.forward(x)
loss = self.softmax(out)
loss = nce_loss(loss)
return {'loss': loss}
# --------------
# with validation_step_end to do softmax over the full batch
def validation_step(self, batch, batch_idx):
# batch is 1/num_gpus big
x, y = batch
out = self.forward(x)
return {'out': out}
def validation_step_end(self, outputs):
# this out is now the full size of the batch
out = outputs['out']
# this softmax now uses the full batch size
loss = self.softmax(out)
loss = nce_loss(loss)
return {'loss': loss}
.. note:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
"""
def validation_end(self, outputs):
"""
.. warning:: Deprecated in v0.7.0. use validation_epoch_end instead.
Will be removed 1.0.0
:param outputs:
:return:
"""
def validation_epoch_end(self, outputs):
"""
Called at end of validation epoch with the output of all validation_steps
.. code-block:: python
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
out = validation_step(val_batch)
val_outs.append(out)
validation_epoch_end(val_outs)
Args:
outputs (list): List of outputs you defined in validation_step, or if there are multiple dataloaders,
a list containing a list of outputs for each dataloader
Return:
Dict or OrderedDict (dict): Dict has the following optional keys:
progress_bar -> Dict for progress bar display. Must have only tensors
log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
.. note:: If you didn't define a validation_step, this won't be called.
- The outputs here are strictly for logging or progress bar.
- If you don't need to display anything, don't return anything.
- If you want to manually set current step, you can specify it with the 'step' key in the 'log' Dict.
Example
-------
With a single dataloader
.. code-block:: python
def validation_epoch_end(self, outputs):
val_acc_mean = 0
for output in outputs:
val_acc_mean += output['val_acc']
val_acc_mean /= len(outputs)
tqdm_dict = {'val_acc': val_acc_mean.item()}
# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_acc': val_acc_mean.item()}
}
return results
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each validation step for that dataloader.
.. code-block:: python
def validation_epoch_end(self, outputs):
val_acc_mean = 0
i = 0
for dataloader_outputs in outputs:
for output in dataloader_outputs:
val_acc_mean += output['val_acc']
i += 1
val_acc_mean /= i
tqdm_dict = {'val_acc': val_acc_mean.item()}
# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_acc': val_acc_mean.item(), 'step': self.current_epoch}
}
return results
"""
def test_step(self, *args, **kwargs):
"""return whatever outputs will need to be aggregated in test_end
:param batch: The output of your dataloader. A tensor, tuple or list
:param int batch_idx: Integer displaying which batch this is
:param int dataloader_idx: Integer displaying which dataloader this is (only if multiple test datasets used)
:return dict: Dict or OrderedDict with metrics to display in progress bar. All keys must be tensors.
r"""
Operate on a single batch of data from the test set
In this step you'd normally generate examples or calculate anything of interest such as accuracy.
.. code-block:: python
# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
out = test_step(test_batch)
test_outs.append(out)
test_epoch_end(test_outs)
Args:
batch (torch.nn.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your dataloader.
A tensor, tuple or list
batch_idx (int): The index of this batch
dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple
test datasets used)
Return:
Dict or OrderedDict - passed to the test_epoch_end
.. code-block:: python
@ -381,21 +542,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
def test_step(self, batch, batch_idx)
# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idxdx)
**OPTIONAL**
If you don't need to test you don't need to implement this method.
In this step you'd normally generate examples or
calculate anything of interest such as accuracy.
When the validation_step is called, the model has been put in eval mode
and PyTorch gradients have been disabled.
At the end of validation, model goes back to training mode and gradients are enabled.
The dict you return here will be available in the `test_end` method.
This function is used when you execute `trainer.test()`.
def test_step(self, batch, batch_idx, dataloader_idx)
Example
-------
@ -410,50 +557,136 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
out = self.forward(x)
loss = self.loss(out, y)
# log 6 example images
# or generated text... or whatever
sample_imgs = x[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('example_images', grid, 0)
# calculate acc
labels_hat = torch.argmax(out, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
# all optional...
# return whatever you need for the collation function test_end
# return whatever you need for the collation function validation_end
output = OrderedDict({
'test_loss': loss_test,
'test_acc': torch.tensor(test_acc), # everything must be a tensor
'val_loss': loss_val,
'val_acc': torch.tensor(val_acc), # everything must be a tensor
})
# return an optional dict
return output
If you pass in multiple test datasets, `test_step` will have an additional argument.
If you pass in multiple validation datasets, validation_step will have an additional argument.
.. code-block:: python
# CASE 2: multiple test datasets
# CASE 2: multiple validation datasets
def test_step(self, batch, batch_idx, dataset_idx):
# dataset_idx tells you which dataset this is.
.. note:: If you don't need to validate you don't need to implement this method.
The `dataset_idx` corresponds to the order of datasets returned in `test_dataloader`.
.. note:: When the validation_step is called, the model has been put in eval mode and PyTorch gradients
have been disabled. At the end of validation, model goes back to training mode and gradients are enabled.
"""
def validation_end(self, outputs):
"""Outputs has the appended output after each validation step.
def test_step_end(self, *args, **kwargs):
"""
Use this when testing with dp or ddp2 because test_step will operate
on only part of the batch. However, this is still optional
and only needed for things like softmax or NCE loss.
:param outputs: List of outputs you defined in validation_step, or if there are multiple dataloaders,
a list containing a list of outputs for each dataloader
:return dict: Dictionary or OrderedDict with optional:
.. note:: If you later switch to ddp or some other mode, this will still be called
so that you don't have to change your code
.. code-block:: python
# pseudocode
sub_batches = split_batches_for_dp(batch)
batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
test_step_end(batch_parts_outputs)
:param batch_parts_outputs: What you return in `test_step` for each batch part.
:return dict: dictionary with loss key and optional log, progress keys:
- loss -> tensor scalar [REQUIRED]
- progress_bar -> Dict for progress bar display. Must have only tensors
- log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
In this case you should define test_step_end to perform those calculations.
Example
-------
.. code-block:: python
# WITHOUT test_step_end
# if used in DP or DDP2, this batch is 1/num_gpus large
def test_step(self, batch, batch_idx):
# batch is 1/num_gpus big
x, y = batch
out = self.forward(x)
loss = self.softmax(out)
loss = nce_loss(loss)
return {'loss': loss}
# --------------
# with test_step_end to do softmax over the full batch
def test_step(self, batch, batch_idx):
# batch is 1/num_gpus big
x, y = batch
out = self.forward(x)
return {'out': out}
def test_step_end(self, outputs):
# this out is now the full size of the batch
out = outputs['out']
# this softmax now uses the full batch size
loss = self.softmax(out)
loss = nce_loss(loss)
return {'loss': loss}
.. note:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
"""
def test_end(self, outputs):
"""
.. warning:: Deprecated in v0.7.0. use test_epoch_end instead. Will be removed 1.0.0
:param outputs:
:return:
"""
def test_epoch_end(self, outputs):
"""
Called at end of test epoch with the output of all test_steps
.. code-block:: python
# the pseudocode for these calls
test_outs = []
for test_batch in test_data:
out = test_step(test_batch)
test_outs.append(out)
test_epoch_end(test_outs)
Args:
outputs (list): List of outputs you defined in test_step, or if there are multiple dataloaders,
a list containing a list of outputs for each dataloader
Return:
Dict or OrderedDict (dict): Dict has the following optional keys:
progress_bar -> Dict for progress bar display. Must have only tensors
log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
If you didn't define a validation_step, this won't be called.
Called at the end of the validation loop with the outputs of validation_step.
.. note:: If you didn't define a test_step, this won't be called.
The outputs here are strictly for the progress bar.
If you don't need to display anything, don't return anything.
Any keys present in 'log', 'progress_bar' or the rest of the dictionary
are available for callbacks to access. If you want to manually set current step, you can specify it with
'step' key in the 'log' Dict.
- The outputs here are strictly for logging or progress bar.
- If you don't need to display anything, don't return anything.
- If you want to manually set current step, you can specify it with the 'step' key in the 'log' Dict.
Example
-------
@ -462,120 +695,48 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
.. code-block:: python
def validation_end(self, outputs):
val_loss_mean = 0
val_acc_mean = 0
for output in outputs:
val_loss_mean += output['val_loss']
val_acc_mean += output['val_acc']
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_loss': val_loss_mean.item()}
}
return results
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each validation step for that dataloader.
.. code-block:: python
def validation_end(self, outputs):
val_loss_mean = 0
val_acc_mean = 0
i = 0
for dataloader_outputs in outputs:
for output in dataloader_outputs:
val_loss_mean += output['val_loss']
val_acc_mean += output['val_acc']
i += 1
val_loss_mean /= i
val_acc_mean /= i
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_loss': val_loss_mean.item(), 'step': self.current_epoch}
}
return results
"""
def test_end(self, outputs):
"""Outputs has the appended output after each test step.
:param outputs: List of outputs you defined in test_step, or if there are multiple dataloaders,
a list containing a list of outputs for each dataloader
:return dict: Dict of OrderedDict with metrics to display in progress bar
If you didn't define a test_step, this won't be called.
Called at the end of the test step with the output of each test_step.
The outputs here are strictly for the progress bar.
If you don't need to display anything, don't return anything.
Example
-------
.. code-block:: python
def test_end(self, outputs):
test_loss_mean = 0
def test_epoch_end(self, outputs):
test_acc_mean = 0
for output in outputs:
test_loss_mean += output['test_loss']
test_acc_mean += output['test_acc']
test_loss_mean /= len(outputs)
test_acc_mean /= len(outputs)
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
tqdm_dict = {'test_acc': test_acc_mean.item()}
# show test_loss and test_acc in progress bar but only log test_loss
results = {
'progress_bar': tqdm_dict,
'log': {'test_loss': val_loss_mean.item()}
'log': {'test_acc': test_acc_mean.item()}
}
return results
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each validation step for that dataloader.
each test step for that dataloader.
.. code-block:: python
def test_end(self, outputs):
test_loss_mean = 0
def test_epoch_end(self, outputs):
test_acc_mean = 0
i = 0
for dataloader_outputs in outputs:
for output in dataloader_outputs:
test_loss_mean += output['test_loss']
test_acc_mean += output['test_acc']
i += 1
test_loss_mean /= i
test_acc_mean /= i
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
tqdm_dict = {'test_acc': test_acc_mean.item()}
# show test_loss and test_acc in progress bar but only log test_loss
results = {
'progress_bar': tqdm_dict,
'log': {'test_loss': val_loss_mean.item()}
'log': {'test_acc': test_acc_mean.item(), 'step': self.current_epoch}
}
return results
"""
def configure_ddp(self, model, device_ids):
r"""
Override to init DDP in your own way or with your own wrapper.
The only requirements are that:
@ -980,14 +1141,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
return None
@data_loader
def tng_dataloader(self): # todo: remove in v0.8.0
def tng_dataloader(self): # todo: remove in v1.0.0
"""Implement a PyTorch DataLoader.
.. warning:: Deprecated in v0.5.0. use train_dataloader instead.
.. warning:: Deprecated in v0.5.0. use train_dataloader instead. Will be removed 1.0.0
"""
output = self.train_dataloader()
warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0."
" and this method will be removed in v0.8.0", DeprecationWarning)
" and this method will be removed in v1.0.0", DeprecationWarning)
return output
def test_dataloader(self):


@ -16,7 +16,7 @@ PyTorch Lightning supports profiling standard actions in the training loop out o
- on_after_backward
- optimizer_step
- on_batch_end
- training_end
- training_step_end
- on_training_end
Enable simple profiling
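The simplest way to turn it on (this mirrors the profiler flag shown in the trainer docs later in this diff):

.. code-block:: python

    # profile standard training events
    trainer = Trainer(profiler=True)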


@ -1,17 +1,24 @@
"""
Once you've organized your PyTorch code into a LightningModule,
the Trainer automates everything else.
The trainer de-couples the engineering code (16-bit, early stopping, GPU distribution, etc...) from the
science code (GAN, BERT, your project, etc...). It uses many assumptions which are best practices in
AI research today.
.. figure:: /_images/lightning_module/pt_trainer.png
:alt: Convert from PyTorch to Lightning
The trainer automates all parts of training except:
This abstraction achieves the following:
- what happens in training, test, val loop
- where the data come from
- which optimizers to use
- how to do the computations
1. You maintain control over all aspects via PyTorch
code without an added abstraction.
The Trainer delegates those calls to your LightningModule which defines how to do those parts.
2. The trainer uses best practices embedded by contributors and users
from top AI labs such as Facebook AI Research, NYU, MIT, Stanford, etc...
3. The trainer allows overriding any key part that you don't want automated.
-----------
Basic use
---------
This is the basic use of the trainer:
@ -23,6 +30,763 @@ This is the basic use of the trainer:
trainer = Trainer()
trainer.fit(model)
--------
Best Practices
--------------
For cluster computing, it's recommended you structure your
main.py file this way
.. code-block:: python
from argparse import ArgumentParser
def main(hparams):
model = LightningModule()
trainer = Trainer(gpus=hparams.gpus)
trainer.fit(model)
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--gpus', default=None)
args = parser.parse_args()
main(args)
So you can run it like so:
.. code-block:: bash
$ python main.py --gpus 2
------------
Testing
-------
Once you're done training, feel free to run the test set!
(Only right before publishing your paper or pushing to production)
.. code-block:: python
trainer.test()
------------
Deployment / prediction
-----------------------
You just trained a LightningModule which is also just a torch.nn.Module.
Use it to do whatever!
.. code-block:: python
# load model
pretrained_model = LightningModule.load_from_checkpoint(PATH)
pretrained_model.freeze()
# use it for finetuning
def forward(self, x):
features = pretrained_model(x)
classes = classifier(features)
# or for prediction
out = pretrained_model(x)
api_write({'response': out})
-------
Trainer flags
-------------
logger
^^^^^^
Logger (or iterable collection of loggers) for experiment tracking.
.. code-block:: python
Trainer(logger=logger)
Example::
from pytorch_lightning.loggers import TensorBoardLogger
# default logger used by trainer
logger = TensorBoardLogger(
save_dir=os.getcwd(),
version=self.slurm_job_id,
name='lightning_logs'
)
checkpoint_callback
^^^^^^^^^^^^^^^^^^^
Callback for checkpointing.
.. code-block:: python
trainer = Trainer(checkpoint_callback=checkpoint_callback)
Example::
from pytorch_lightning.callbacks import ModelCheckpoint
# default used by the Trainer
checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(),
save_best_only=True,
verbose=True,
monitor='val_loss',
mode='min',
prefix=''
)
early_stop_callback
^^^^^^^^^^^^^^^^^^^
Callback for early stopping.
early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`)
- If set to ``True``, then a default callback monitoring ``'val_loss'`` is created.
- Will raise an error if ``'val_loss'`` is not found.
- If set to ``False``, then early stopping will be disabled.
- If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
- If ``'val_loss'`` is not found, it will work as if early stopping is disabled.
- Default: ``None``.
.. code-block:: python
trainer = Trainer(early_stop_callback=early_stop_callback)
Example::
from pytorch_lightning.callbacks import EarlyStopping
# default used by the Trainer
early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=3,
strict=False,
verbose=False,
mode='min'
)
callbacks
^^^^^^^^^
callbacks: Add a list of callbacks.
.. code-block:: python
# a list of callbacks
callbacks = [PrintCallback()]
trainer = Trainer(callbacks=callbacks)
Example::
from pytorch_lightning.callbacks import Callback
class PrintCallback(Callback):
def on_train_start(self):
print("Training is started!")
def on_train_end(self):
print(f"Training is done. The logs are: {self.trainer.logs}")
default_save_path
^^^^^^^^^^^^^^^^^
Default path for logs and weights when no logger/ckpt_callback passed
Example::
# default used by the Trainer
trainer = Trainer(default_save_path=os.getcwd())
gradient_clip_val
^^^^^^^^^^^^^^^^^
Gradient clipping value
- 0 means don't clip.
Example::
# default used by the Trainer
trainer = Trainer(gradient_clip_val=0.0)
gradient_clip
.. warning: .. deprecated:: 0.5.0
Use `gradient_clip_val` instead. Will remove 0.8.0.
process_position
^^^^^^^^^^^^^^^^
orders the tqdm bar when running multiple models on same machine.
Example::
# default used by the Trainer
trainer = Trainer(process_position=0)
num_nodes
^^^^^^^^^
Number of GPU nodes for distributed training.
Example::
# default used by the Trainer
trainer = Trainer(num_nodes=1)
# to train on 8 nodes
trainer = Trainer(num_nodes=8)
nb_gpu_nodes
..warning:: .. deprecated:: 0.5.0
Use `num_nodes` instead. Will remove 0.8.0.
gpus
^^^^
- Number of GPUs to train on
- or Which GPUs to train on
- can handle strings
Example::
# default used by the Trainer (ie: train on CPU)
trainer = Trainer(gpus=None)
# int: train on 2 gpus
trainer = Trainer(gpus=2)
# list: train on GPUs 1, 4 (by bus ordering)
trainer = Trainer(gpus=[1, 4])
trainer = Trainer(gpus='1, 4') # equivalent
# -1: train on all gpus
trainer = Trainer(gpus=-1)
trainer = Trainer(gpus='-1') # equivalent
# combine with num_nodes to train on multiple GPUs across nodes
trainer = Trainer(gpus=2, num_nodes=4) # uses 8 gpus in total
num_tpu_cores
^^^^^^^^^^^^^
How many TPU cores to train on (1 or 8).
A single TPU v2 or v3 has 8 cores. A TPU pod has
up to 2048 cores. A slice of a POD means you get as many cores
as you request.
You MUST use DistributedDataSampler with your dataloader for this
to work. Your effective batch size is batch_size * total tpu cores.
This parameter can be either 1 or 8.
Example::
# your_trainer_file.py
# default used by the Trainer (ie: train on CPU)
trainer = Trainer(num_tpu_cores=None)
# int: train on a single core
trainer = Trainer(num_tpu_cores=1)
# int: train on all 8 cores
trainer = Trainer(num_tpu_cores=8)
# for 8+ cores must submit via xla script with
# a max of 8 cores specified. The XLA script
# will duplicate script onto each TPU in the POD
trainer = Trainer(num_tpu_cores=8)
# -1: train on all available TPUs
trainer = Trainer(num_tpu_cores=-1)
To train on more than 8 cores (ie: a POD),
submit this script using the xla_dist script.
Example::
$ python -m torch_xla.distributed.xla_dist
--tpu=$TPU_POD_NAME
--conda-env=torch-xla-nightly
--env=XLA_USE_BF16=1
-- python your_trainer_file.py
log_gpu_memory
^^^^^^^^^^^^^^
Options:
- None
- 'min_max'
- 'all'
.. note:: Might slow performance because it uses the output of nvidia-smi.
Example::
# default used by the Trainer
trainer = Trainer(log_gpu_memory=None)
# log all the GPUs (on master node only)
trainer = Trainer(log_gpu_memory='all')
# log only the min and max memory on the master node
trainer = Trainer(log_gpu_memory='min_max')
show_progress_bar
^^^^^^^^^^^^^^^^^
If true shows tqdm progress bar
Example::
# default used by the Trainer
trainer = Trainer(show_progress_bar=True)
progress_bar_refresh_rate
^^^^^^^^^^^^^^^^^^^^^^^^^
How often to refresh progress bar (in steps)
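For instance (the value shown is just illustrative):

Example::

    # refresh the progress bar every 50 steps
    trainer = Trainer(progress_bar_refresh_rate=50)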
overfit_pct
^^^^^^^^^^^
uses this much data of all datasets.
Example::
# default used by the Trainer
trainer = Trainer(overfit_pct=0.0)
# use only 1% of the train, test, val datasets
trainer = Trainer(overfit_pct=0.01)
track_grad_norm
^^^^^^^^^^^^^^^
- no tracking (-1)
- Otherwise tracks that norm (2 for 2-norm)
Example::
# default used by the Trainer
trainer = Trainer(track_grad_norm=-1)
# track the 2-norm
trainer = Trainer(track_grad_norm=2)
check_val_every_n_epoch
^^^^^^^^^^^^^^^^^^^^^^^
Check val every n train epochs.
Example::
# default used by the Trainer
trainer = Trainer(check_val_every_n_epoch=1)
# run val loop every 10 training epochs
trainer = Trainer(check_val_every_n_epoch=10)
fast_dev_run
^^^^^^^^^^^^
Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
Under the hood the pseudocode looks like this:
.. code-block:: python
# loading
__init__()
prepare_data
# test training step
training_batch = next(train_dataloader)
training_step(training_batch)
# test val step
val_batch = next(val_dataloader)
out = validation_step(val_batch)
validation_epoch_end([out])
Example::
# default used by the Trainer
trainer = Trainer(fast_dev_run=False)
# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)
accumulate_grad_batches
^^^^^^^^^^^^^^^^^^^^^^^
Accumulates grads every k batches or as set up in the dict.
Example::
# default used by the Trainer (no accumulation)
trainer = Trainer(accumulate_grad_batches=1)
# accumulate every 4 batches (effective batch size is batch*4)
trainer = Trainer(accumulate_grad_batches=4)
# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})
max_epochs
^^^^^^^^^^
Stop training once this number of epochs is reached
Example::
# default used by the Trainer
trainer = Trainer(max_epochs=1000)
max_nb_epochs
.. warning:: .. deprecated:: 0.5.0
Use `max_epochs` instead. Will remove 0.8.0.
min_epochs
^^^^^^^^^^
Force training for at least these many epochs
Example::
# default used by the Trainer
trainer = Trainer(min_epochs=1)
min_nb_epochs:
.. warning:: .. deprecated:: 0.5.0
Use `min_epochs` instead. Will remove 0.8.0.
max_steps
^^^^^^^^^
Stop training after this number of steps. Disabled by default (None).
Training will stop if max_steps or max_epochs have reached (earliest).
Example::
# Stop after 100 steps
trainer = Trainer(max_steps=100)
min_steps
^^^^^^^^^
Force training for at least these number of steps. Disabled by default (None).
Trainer will train model for at least min_steps or min_epochs (latest).
Example::
# Run at least for 100 steps (disable min_epochs)
trainer = Trainer(min_steps=100, min_epochs=0)
train_percent_check
^^^^^^^^^^^^^^^^^^^
How much of training dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Example::
# default used by the Trainer
trainer = Trainer(train_percent_check=1.0)
# run through only 25% of the training set each epoch
trainer = Trainer(train_percent_check=0.25)
val_percent_check
^^^^^^^^^^^^^^^^^
How much of validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Example::
# default used by the Trainer
trainer = Trainer(val_percent_check=1.0)
# run through only 25% of the validation set each epoch
trainer = Trainer(val_percent_check=0.25)
test_percent_check
^^^^^^^^^^^^^^^^^^
How much of test dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Example::
# default used by the Trainer
trainer = Trainer(test_percent_check=1.0)
# run through only 25% of the test set each epoch
trainer = Trainer(test_percent_check=0.25)
val_check_interval:
How often within one training epoch to check the validation set.
Can specify as float or int.
- use (float) to check within a training epoch
- use (int) to check every n steps (batches)
Example::
# default used by the Trainer
trainer = Trainer(val_check_interval=1.0)
# check validation set 4 times during a training epoch
trainer = Trainer(val_check_interval=0.25)
# check validation set every 1000 training batches
# use this when using iterableDataset and your dataset has no length
# (ie: production cases with streaming data)
trainer = Trainer(val_check_interval=1000)
log_save_interval
^^^^^^^^^^^^^^^^^
Writes logs to disk this often
Example::
# default used by the Trainer
trainer = Trainer(log_save_interval=100)
row_log_interval
^^^^^^^^^^^^^^^^
How often to add logging rows (does not write to disk)
Example::
# default used by the Trainer
trainer = Trainer(row_log_interval=10)
add_row_log_interval
.. warning:: .. deprecated:: 0.5.0
Use `row_log_interval` instead. Will remove 0.8.0.
distributed_backend
^^^^^^^^^^^^^^^^^^^
The distributed backend to use.
- ('dp') is DataParallel (split batch among GPUs of same machine)
- ('ddp') is DistributedDataParallel (each gpu on each node trains, and syncs grads)
- ('ddp2') dp on node, ddp across nodes
Example::
# default used by the Trainer
trainer = Trainer(distributed_backend=None)
# dp = DataParallel (split a batch onto k gpus on same machine).
trainer = Trainer(gpus=2, distributed_backend='dp')
# ddp = DistributedDataParallel
# Each gpu trains by itself on a subset of the data.
# Gradients sync across all gpus and all machines.
trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')
# ddp2 = DistributedDataParallel + dp
# behaves like dp on every node
# syncs gradients across nodes like ddp
# useful for things like increasing the number of negative samples
trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')
use_amp:
.. warning:: .. deprecated:: 0.6.1
Use `precision` instead. Will remove 0.8.0.
precision
^^^^^^^^^
Full precision (32), half precision (16).
Can be used on CPU, GPU or TPUs.
If used on TPU will use torch.bfloat16 but tensor printing
will still show torch.float32.
Example::
# default used by the Trainer
trainer = Trainer(precision=32)
# 16-bit precision
trainer = Trainer(precision=16)
# one day
trainer = Trainer(precision=8|4|2)
print_nan_grads
^^^^^^^^^^^^^^^
Prints gradients with nan values
Example::
# default used by the Trainer
trainer = Trainer(print_nan_grads=False)
weights_summary
^^^^^^^^^^^^^^^
Prints a summary of the weights when training begins.
Options: 'full', 'top', None.
Example::
# default used by the Trainer (ie: print all weights)
trainer = Trainer(weights_summary='full')
# print only the top level modules
trainer = Trainer(weights_summary='top')
# don't print a summary
trainer = Trainer(weights_summary=None)
weights_save_path
^^^^^^^^^^^^^^^^^
Where to save weights if specified.
Example::
# default used by the Trainer
trainer = Trainer(weights_save_path=os.getcwd())
# save to your custom path
trainer = Trainer(weights_save_path='my/path')
# if checkpoint callback used, then overrides the weights path
# **NOTE: this saves weights to some/path NOT my/path
checkpoint_callback = ModelCheckpoint(filepath='some/path')
trainer = Trainer(
checkpoint_callback=checkpoint_callback,
weights_save_path='my/path'
)
amp_level
^^^^^^^^^
The optimization level to use (O1, O2, etc...)
for 16-bit GPU precision (using NVIDIA apex under the hood).
Check nvidia docs for level (https://nvidia.github.io/apex/amp.html#opt-levels)
Example::
# default used by the Trainer
trainer = Trainer(amp_level='O1')
num_sanity_val_steps
^^^^^^^^^^^^^^^^^^^^
Sanity check runs n batches of val before starting the training routine.
This catches any bugs in your validation without having to wait for the first validation check.
The Trainer uses 5 steps by default. Turn it off or modify it here.
Example::
# default used by the Trainer
trainer = Trainer(num_sanity_val_steps=5)
# turn it off
trainer = Trainer(num_sanity_val_steps=0)
nb_sanity_val_steps:
.. warning:: .. deprecated:: 0.5.0
Use `num_sanity_val_steps` instead. Will remove 0.8.0.
truncated_bptt_steps
^^^^^^^^^^^^^^^^^^^^
Truncated back prop performs backprop every k steps of
a much longer sequence. If this is enabled, your batches will automatically get truncated
and the trainer will apply Truncated Backprop to it. Make sure your batches have a sequence
dimension. (`Williams et al. "An efficient gradient-based algorithm for on-line training of
recurrent network trajectories."
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
Example::
# default used by the Trainer (ie: disabled)
trainer = Trainer(truncated_bptt_steps=None)
# backprop every 5 steps in a batch
trainer = Trainer(truncated_bptt_steps=5)
Lightning takes care to split your batch along the time-dimension.
.. note:: If you need to modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
.. note:: Using this feature requires updating your LightningModule's
:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
resume_from_checkpoint
^^^^^^^^^^^^^^^^^^^^^^
To resume training from a specific checkpoint pass in the path here.
Example::
# default used by the Trainer
trainer = Trainer(resume_from_checkpoint=None)
# resume from a specific checkpoint
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
profiler
^^^^^^^^
To profile individual steps during training and assist in identifying bottlenecks.
Example::
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
# default used by the Trainer
trainer = Trainer(profiler=None)
# to profile standard training events
trainer = Trainer(profiler=True)
# equivalent to profiler=True
profiler = Profiler()
trainer = Trainer(profiler=profiler)
# advanced profiler for function-level stats
profiler = AdvancedProfiler()
trainer = Trainer(profiler=profiler)
reload_dataloaders_every_epoch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Set to True to reload dataloaders every epoch
.. code-block:: python
# if False (default)
train_loader = model.train_dataloader()
for epoch in epochs:
for batch in train_loader:
...
# if True
for epoch in epochs:
train_loader = model.train_dataloader()
for batch in train_loader:
benchmark
^^^^^^^^^
If true enables cudnn.benchmark.
This flag is likely to increase the speed of your system if your
input sizes don't change. However, if they do change, it will likely
make your system slower.
The speedup comes from allowing the cudnn auto-tuner to find the best
algorithm for the hardware `[see discussion here]
<https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`_.
"""
from .trainer import Trainer


@ -131,6 +131,7 @@ from abc import ABC, abstractmethod
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import warnings
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.debugging import MisconfigurationException
@ -260,6 +261,18 @@ class TrainerEvaluationLoopMixin(ABC):
# -----------------
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
# on dp / ddp2 might still want to do something with the batch parts
if test_mode:
if self.is_overriden('test_step_end'):
model_ref = self.get_model()
with self.profiler.profile('test_step_end'):
output = model_ref.test_step_end(output)
else:
if self.is_overriden('validation_step_end'):
model_ref = self.get_model()
with self.profiler.profile('validation_step_end'):
output = model_ref.validation_step_end(output)
# track outputs for collation
dl_outputs.append(output)
@ -280,10 +293,23 @@ class TrainerEvaluationLoopMixin(ABC):
# give model a chance to do something with the outputs (and method defined)
model = self.get_model()
if test_mode and self.is_overriden('test_epoch_end'):
eval_results = model.test_epoch_end(outputs)
elif self.is_overriden('validation_epoch_end'):
eval_results = model.validation_epoch_end(outputs)
# TODO: remove in v 1.0.0
if test_mode and self.is_overriden('test_end'):
eval_results = model.test_end(outputs)
m = 'test_end was deprecated in 0.7.0 and will be removed 1.0.0. ' \
'Use test_epoch_end instead.'
warnings.warn(m, DeprecationWarning)
elif self.is_overriden('validation_end'):
eval_results = model.validation_end(outputs)
m = 'validation_end was deprecated in 0.7.0 and will be removed 1.0.0. ' \
'Use validation_epoch_end instead.'
warnings.warn(m, DeprecationWarning)
# enable train mode again
model.train()
@ -392,7 +418,7 @@ class TrainerEvaluationLoopMixin(ABC):
output = model(*args)
return output
# single GPU
# single GPU data transfer
if self.single_gpu:
# for single GPU put inputs on gpu manually
root_gpu = 0
@ -401,12 +427,12 @@ class TrainerEvaluationLoopMixin(ABC):
batch = self.transfer_batch_to_gpu(batch, root_gpu)
args[0] = batch
# TPU
# TPU data transfer
if self.use_tpu:
batch = self.transfer_batch_to_tpu(batch)
args[0] = batch
# CPU
# CPU, TPU or gpu step
if test_mode:
output = model.test_step(*args)
else:


@ -126,497 +126,72 @@ class Trainer(TrainerIOMixin,
Args:
logger: Logger (or iterable collection of loggers) for experiment tracking.
Example::
from pytorch_lightning.loggers import TensorBoardLogger
# default logger used by trainer
logger = TensorBoardLogger(
save_dir=os.getcwd(),
version=self.slurm_job_id,
name='lightning_logs'
)
Trainer(logger=logger)
checkpoint_callback: Callback for checkpointing.
Example::
from pytorch_lightning.callbacks import ModelCheckpoint
# default used by the Trainer
checkpoint_callback = ModelCheckpoint(
filepath=os.getcwd(),
save_best_only=True,
verbose=True,
monitor='val_loss',
mode='min',
prefix=''
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`):
Callback for early stopping.
If set to ``True``, then a default callback monitoring ``'val_loss'`` is created.
Will raise an error if ``'val_loss'`` is not found.
If set to ``False``, then early stopping will be disabled.
If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
If ``'val_loss'`` is not found, it will work as if early stopping is disabled.
Default: ``None``.
Example::
from pytorch_lightning.callbacks import EarlyStopping
# default used by the Trainer
early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=3,
strict=False,
verbose=False,
mode='min'
)
trainer = Trainer(early_stop_callback=early_stop_callback)
callbacks: Add a list of callbacks.
Example::
from pytorch_lightning.callbacks import Callback
class PrintCallback(Callback):
def on_train_start(self):
print("Training is started!")
def on_train_end(self):
print(f"Training is done. The logs are: {self.trainer.logs}")
# a list of callbacks
callbacks = [PrintCallback()]
trainer = Trainer(callbacks=callbacks)
default_save_path: Default path for logs and weights when no logger/ckpt_callback passed
Example::
# default used by the Trainer
trainer = Trainer(default_save_path=os.getcwd())
gradient_clip_val: 0 means don't clip.
Example::
# default used by the Trainer
trainer = Trainer(gradient_clip_val=0.0)
gradient_clip:
.. warning: .. deprecated:: 0.5.0
Use `gradient_clip_val` instead. Will remove 0.8.0.
.. warning:: .. deprecated:: 0.6.1
Use `gradient_clip_val` instead. Will remove 0.8.0.
process_position: orders the tqdm bar when running multiple models on same machine.
Example::
# default used by the Trainer
trainer = Trainer(process_position=0)
num_nodes: number of GPU nodes for distributed training.
Example::
# default used by the Trainer
trainer = Trainer(num_nodes=1)
# to train on 8 nodes
trainer = Trainer(num_nodes=8)
nb_gpu_nodes:
.. warning:: .. deprecated:: 0.5.0
Use `num_nodes` instead. Will remove 0.8.0.
.. warning:: .. deprecated:: 0.6.1
Use `num_nodes` instead. Will remove 0.8.0.
gpus: Which GPUs to train on.
Example::
# default used by the Trainer (ie: train on CPU)
trainer = Trainer(gpus=None)
# int: train on 2 gpus
trainer = Trainer(gpus=2)
# list: train on GPUs 1, 4 (by bus ordering)
trainer = Trainer(gpus=[1, 4])
trainer = Trainer(gpus='1, 4') # equivalent
# -1: train on all gpus
trainer = Trainer(gpus=-1)
trainer = Trainer(gpus='-1') # equivalent
# combine with num_nodes to train on multiple GPUs across nodes
trainer = Trainer(gpus=2, num_nodes=4) # uses 8 gpus in total
num_tpu_cores: How many TPU cores to train on (1 or 8).
A single TPU v2 or v3 has 8 cores. A TPU pod has
up to 2048 cores. A slice of a POD means you get as many cores
as you request.
You MUST use a DistributedSampler with your dataloader for this
to work. Your effective batch size is batch_size * total tpu cores.
This parameter can be either 1 or 8.
Example::
# your_trainer_file.py
# default used by the Trainer (ie: train on CPU)
trainer = Trainer(num_tpu_cores=None)
# int: train on a single core
trainer = Trainer(num_tpu_cores=1)
# int: train on all 8 cores
trainer = Trainer(num_tpu_cores=8)
# for 8+ cores you must submit via an xla script with
# a max of 8 cores specified. The XLA script
# will duplicate this script onto each TPU in the POD
trainer = Trainer(num_tpu_cores=8)
# -1: train on all available TPUs
trainer = Trainer(num_tpu_cores=-1)
To train on more than 8 cores (ie: a POD),
submit this script using the xla_dist script.
Example::
$ python -m torch_xla.distributed.xla_dist
--tpu=$TPU_POD_NAME
--conda-env=torch-xla-nightly
--env=XLA_USE_BF16=1
-- python your_trainer_file.py
log_gpu_memory: None, 'min_max', 'all'. Might slow performance
because it uses the output of nvidia-smi.
Example::
# default used by the Trainer
trainer = Trainer(log_gpu_memory=None)
# log all the GPUs (on master node only)
trainer = Trainer(log_gpu_memory='all')
# log only the min and max memory on the master node
trainer = Trainer(log_gpu_memory='min_max')
show_progress_bar: If True, shows the tqdm progress bar.
Example::
# default used by the Trainer
trainer = Trainer(show_progress_bar=True)
progress_bar_refresh_rate: How often to refresh progress bar (in steps)
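Example (the refresh value shown is illustrative, not necessarily the default)::
# refresh the progress bar every 50 steps
trainer = Trainer(progress_bar_refresh_rate=50)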
overfit_pct: Uses only this fraction of each of the train, val and test datasets.
Example::
# default used by the Trainer
trainer = Trainer(overfit_pct=0.0)
# use only 1% of the train, test, val datasets
trainer = Trainer(overfit_pct=0.01)
track_grad_norm: -1 means no tracking. Otherwise tracks the gradient norm of this order.
Example::
# default used by the Trainer
trainer = Trainer(track_grad_norm=-1)
# track the 2-norm
trainer = Trainer(track_grad_norm=2)
check_val_every_n_epoch: Check val every n train epochs.
Example::
# default used by the Trainer
trainer = Trainer(check_val_every_n_epoch=1)
# run val loop every 10 training epochs
trainer = Trainer(check_val_every_n_epoch=10)
fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
Example::
# default used by the Trainer
trainer = Trainer(fast_dev_run=False)
# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)
accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
Example::
# default used by the Trainer (no accumulation)
trainer = Trainer(accumulate_grad_batches=1)
# accumulate every 4 batches (effective batch size is batch*4)
trainer = Trainer(accumulate_grad_batches=4)
# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})
max_epochs: Stop training once this number of epochs is reached.
Example::
# default used by the Trainer
trainer = Trainer(max_epochs=1000)
max_nb_epochs:
.. warning:: .. deprecated:: 0.5.0
Use `max_epochs` instead. Will remove 0.8.0.
.. warning:: .. deprecated:: 0.6.1
Use `max_epochs` instead. Will remove 0.8.0.
min_epochs: Force training for at least this many epochs.
Example::
# default used by the Trainer
trainer = Trainer(min_epochs=1)
min_nb_epochs:
.. warning:: .. deprecated:: 0.5.0
Use `min_epochs` instead. Will remove 0.8.0.
.. warning:: .. deprecated:: 0.6.1
Use `min_epochs` instead. Will remove 0.8.0.
max_steps: Stop training after this number of steps. Disabled by default (None).
Training will stop when either max_steps or max_epochs is reached (whichever comes first).
Example::
# Stop after 100 steps
trainer = Trainer(max_steps=100)
min_steps: Force training for at least this number of steps. Disabled by default (None).
Trainer will train the model for at least min_steps or min_epochs (whichever happens later).
Example::
# Run at least for 100 steps (disable min_epochs)
trainer = Trainer(min_steps=100, min_epochs=0)
train_percent_check: How much of training dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Example::
# default used by the Trainer
trainer = Trainer(train_percent_check=1.0)
# run through only 25% of the training set each epoch
trainer = Trainer(train_percent_check=0.25)
val_percent_check: How much of validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Example::
# default used by the Trainer
trainer = Trainer(val_percent_check=1.0)
# run through only 25% of the validation set each epoch
trainer = Trainer(val_percent_check=0.25)
test_percent_check: How much of test dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
Example::
# default used by the Trainer
trainer = Trainer(test_percent_check=1.0)
# run through only 25% of the test set each epoch
trainer = Trainer(test_percent_check=0.25)
val_check_interval: How often within one training epoch to check the validation set
If float, fraction of the training epoch. If int, check every n batches.
Example::
# default used by the Trainer
trainer = Trainer(val_check_interval=1.0)
# check validation set 4 times during a training epoch
trainer = Trainer(val_check_interval=0.25)
# check validation set every 1000 training batches
# use this when using an IterableDataset and your dataset has no length
# (ie: production cases with streaming data)
trainer = Trainer(val_check_interval=1000)
log_save_interval: Writes logs to disk this often
Example::
# default used by the Trainer
trainer = Trainer(log_save_interval=100)
row_log_interval: How often to add logging rows (does not write to disk)
Example::
# default used by the Trainer
trainer = Trainer(row_log_interval=10)
add_row_log_interval:
.. warning:: .. deprecated:: 0.5.0
Use `row_log_interval` instead. Will remove 0.8.0.
.. warning:: .. deprecated:: 0.6.1
Use `row_log_interval` instead. Will remove 0.8.0.
distributed_backend: The distributed backend to use.
Options: 'dp', 'ddp', 'ddp2'.
Example::
# default used by the Trainer
trainer = Trainer(distributed_backend=None)
# dp = DataParallel (split a batch onto k gpus on same machine).
trainer = Trainer(gpus=2, distributed_backend='dp')
# ddp = DistributedDataParallel
# Each gpu trains by itself on a subset of the data.
# Gradients sync across all gpus and all machines.
trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')
# ddp2 = DistributedDataParallel + dp
# behaves like dp on every node
# syncs gradients across nodes like ddp
# useful for things like increasing the number of negative samples
trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')
use_amp:
.. warning:: .. deprecated:: 0.6.1
Use `precision` instead. Will remove 0.8.0.
.. warning:: .. deprecated:: 0.7.0
Use `precision` instead. Will remove 0.8.0.
precision: Full precision (32), half precision (16).
Can be used on CPU, GPU or TPUs.
If used on TPU will use torch.bfloat16 but tensor printing
will still show torch.float32.
Example::
# default used by the Trainer
trainer = Trainer(precision=32)
# 16-bit precision
trainer = Trainer(precision=16)
# one day
trainer = Trainer(precision=8|4|2)
print_nan_grads: Prints gradients with nan values
Example::
# default used by the Trainer
trainer = Trainer(print_nan_grads=False)
weights_summary: Prints a summary of the weights when training begins.
Options: 'full', 'top', None.
Example::
# default used by the Trainer (ie: print all weights)
trainer = Trainer(weights_summary='full')
# print only the top level modules
trainer = Trainer(weights_summary='top')
# don't print a summary
trainer = Trainer(weights_summary=None)
weights_save_path: Where to save weights if specified.
Example::
# default used by the Trainer
trainer = Trainer(weights_save_path=os.getcwd())
# save to your custom path
trainer = Trainer(weights_save_path='my/path')
# if checkpoint callback used, then overrides the weights path
# **NOTE: this saves weights to some/path NOT my/path
checkpoint_callback = ModelCheckpoint(filepath='some/path')
trainer = Trainer(
checkpoint_callback=checkpoint_callback,
weights_save_path='my/path'
)
amp_level: The optimization level to use (O1, O2, etc...).
Check nvidia docs for level (https://nvidia.github.io/apex/amp.html#opt-levels)
Example::
# default used by the Trainer
trainer = Trainer(amp_level='O1')
num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
This catches any bugs in your validation without having to wait for the first validation check.
The Trainer uses 5 steps by default. Turn it off or modify it here.
Example::
# default used by the Trainer
trainer = Trainer(num_sanity_val_steps=5)
# turn it off
trainer = Trainer(num_sanity_val_steps=0)
nb_sanity_val_steps:
.. warning:: .. deprecated:: 0.5.0
Use `num_sanity_val_steps` instead. Will remove 0.8.0.
.. warning:: .. deprecated:: 0.7.0
Use `num_sanity_val_steps` instead. Will remove 0.8.0.
truncated_bptt_steps: Truncated backpropagation through time performs backprop every k steps of
a much longer sequence. If this is enabled, your batches will automatically get truncated
and the trainer will apply truncated backprop to them. Make sure your batches have a sequence
dimension. (`Williams et al. "An efficient gradient-based algorithm for on-line training of
recurrent network trajectories."
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
Example::
# default used by the Trainer (ie: disabled)
trainer = Trainer(truncated_bptt_steps=None)
# backprop every 5 steps in a batch
trainer = Trainer(truncated_bptt_steps=5)
Lightning takes care to split your batch along the time-dimension.
.. note:: If you need to modify how the batch is split,
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
.. note:: Using this feature requires updating your LightningModule's
:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
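A minimal sketch of such a `training_step` (the recurrent layer and loss function are illustrative)::
# assumes self.rnn is a recurrent module (e.g. nn.LSTM) and batches carry a time dimension
def training_step(self, batch, batch_idx, hiddens):
    x, y = batch
    out, hiddens = self.rnn(x, hiddens)
    loss = some_loss(out, y)
    # returning hiddens lets Lightning carry state across the truncated splits
    return {'loss': loss, 'hiddens': hiddens}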
resume_from_checkpoint: To resume training from a specific checkpoint, pass in the path here.
Example::
# default used by the Trainer
trainer = Trainer(resume_from_checkpoint=None)
# resume from a specific checkpoint
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
profiler: To profile individual steps during training and assist in
identifying bottlenecks.
Example::
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
# default used by the Trainer
trainer = Trainer(profiler=None)
# to profile standard training events
trainer = Trainer(profiler=True)
# equivalent to profiler=True
profiler = Profiler()
trainer = Trainer(profiler=profiler)
# advanced profiler for function-level stats
profiler = AdvancedProfiler()
trainer = Trainer(profiler=profiler)
reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch
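Example (assuming the flag defaults to ``False``)::
# default used by the Trainer
trainer = Trainer(reload_dataloaders_every_epoch=False)
# reload dataloaders at the start of every epoch
trainer = Trainer(reload_dataloaders_every_epoch=True)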
benchmark (bool): If true enables cudnn.benchmark.
This flag is likely to increase the speed of your system if your
input sizes don't change. However, if they do, then it will likely
make your system slower.
The speedup comes from allowing the cudnn auto-tuner to find the best
algorithm for the hardware `[see discussion here]
<https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`_.
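Example (assuming the flag defaults to ``False``)::
# default used by the Trainer
trainer = Trainer(benchmark=False)
# let the cudnn auto-tuner pick kernels when input sizes are fixed
trainer = Trainer(benchmark=True)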
.. warning:: The following arguments are deprecated and will be removed in v0.8.0:
- `nb_sanity_val_steps`
"""
# Init callbacks


@@ -698,12 +698,24 @@ class TrainerTrainLoopMixin(ABC):
else:
output = self.model.training_step(*args)
# allow any mode to define training_step_end
# do something with all the dp outputs (like softmax)
if self.is_overriden('training_step_end'):
model_ref = self.get_model()
with self.profiler.profile('training_step_end'):
output = model_ref.training_step_end(output)
# allow any mode to define training_end
# TODO: remove in 1.0.0
if self.is_overriden('training_end'):
model_ref = self.get_model()
with self.profiler.profile('training_end'):
output = model_ref.training_end(output)
m = 'training_end was deprecated in 0.7.0 and will be removed in 1.0.0. ' \
'Use training_epoch_end instead'
warnings.warn(m, DeprecationWarning)
# format and reduce outputs accordingly
output = self.process_output(output, train=True)
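
For reference, a minimal user-side sketch of the `training_step_end` hook this branch calls (most useful with dp/ddp2, where `training_step` runs on each GPU and the hook sees the gathered outputs; the model and loss computation are illustrative, not part of this diff):

import torch.nn.functional as F
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # per-GPU forward pass
        return {'logits': logits, 'y': y}

    # called once per training batch, after the per-GPU outputs are gathered
    def training_step_end(self, outputs):
        loss = F.cross_entropy(outputs['logits'], outputs['y'])
        return {'loss': loss, 'log': {'train_loss': loss}}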