Steps (#1051)
* training_end renamed to training_step_end
* fix lost model reference
This commit is contained in:
parent
969e929a48
commit
29faea1862
Binary file not shown. (added image, 1.0 MiB)
Binary file not shown. (added image, 968 KiB)
@@ -34,11 +34,11 @@ Log metrics

To plot metrics into whatever logger you passed in (tensorboard, comet, neptune, etc...)

1. Training_end, validation_end, test_end will all log anything in the "log" key of the return dict.
1. training_epoch_end, validation_epoch_end, test_epoch_end will all log anything in the "log" key of the return dict.

.. code-block:: python

    def training_end(self, outputs):
    def training_epoch_end(self, outputs):
        loss = some_loss()
        ...

@@ -46,7 +46,7 @@ To plot metrics into whatever logger you passed in (tensorboard, comet, neptune,

        results = {'log': logs}
        return results

    def validation_end(self, outputs):
    def validation_epoch_end(self, outputs):
        loss = some_loss()
        ...

@@ -54,7 +54,7 @@ To plot metrics into whatever logger you passed in (tensorboard, comet, neptune,

        results = {'log': logs}
        return results

    def test_end(self, outputs):
    def test_epoch_end(self, outputs):
        loss = some_loss()
        ...

@@ -62,19 +62,7 @@ To plot metrics into whatever logger you passed in (tensorboard, comet, neptune,

        results = {'log': logs}
        return results

2. Most of the time, you only need training_step and not training_end. You can also return logs from here:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = some_loss()
        ...

        logs = {'train_loss': loss}
        results = {'log': logs}
        return results

3. In addition, you can also use any arbitrary functionality from a particular logger from within your LightningModule.
2. In addition, you can also use any arbitrary functionality from a particular logger from within your LightningModule.
For instance, here we log images using tensorboard.

.. code-block:: python
@@ -26,7 +26,7 @@ Training loop

- on_batch_start
- tbptt_split_batch
- training_step
- training_end (optional)
- training_step_end (optional)
- backward
- on_after_backward
- optimizer.step()
@@ -165,12 +165,13 @@ you will only be operating on one of those pieces.

    y_0 = batch

For most metrics, this doesn't really matter. However, if you want
full batch statistics or want to use the outputs of the training_step
to do something like a softmax, you can use the `training_end` step.
to add something to your computational graph (like softmax)
using all batch parts, you can use the `training_step_end` step.

.. code-block:: python

    def training_end(self, outputs):
    def training_step_end(self, outputs):
        # only use when on dp
        outputs = torch.cat(outputs, dim=1)
        softmax = torch.softmax(outputs, dim=1)
        out = softmax.mean()

@@ -195,9 +196,43 @@ In pseudocode, the full sequence is:

        out = gpu_model(batch_split)
        all_results.append(out)

    # calculate statistics for all parts of the batch
    full_out = model.training_end(all_results)
    # use the full batch for something like softmax
    full_out = model.training_step_end(all_results)

To illustrate why this is needed, let's look at DataParallel:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)

        # on dp or ddp2, if we did softmax now it would be wrong
        # because batch is actually a piece of the full batch
        return y_hat

    def training_step_end(self, batch_parts_outputs):
        # batch_parts_outputs has outputs of each part of the batch

        # do softmax here
        outputs = torch.cat(batch_parts_outputs, dim=1)
        softmax = torch.softmax(outputs, dim=1)
        out = softmax.mean()

        return out

If `training_step_end` is defined it will be called regardless of tpu, dp, ddp, etc... which means
it will behave the same no matter the backend.

Validation and test steps also have the same option when using dp:

.. code-block:: python

    def validation_step_end(self, batch_parts_outputs):
        ...

    def test_step_end(self, batch_parts_outputs):
        ...

Implement Your Own Distributed (DDP) training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -1,16 +1,73 @@

"""
A LightningModule is a strict superclass of torch.nn.Module but provides an interface to standardize
the "ingredients" for a research or production system.
A LightningModule organizes your PyTorch code into the following sections:

- The model/system definition (__init__)
- The model/system computations (forward)
- What happens in the training loop (training_step, training_end)
- What happens in the validation loop (validation_step, validation_end)
- What happens in the test loop (test_step, test_end)
- What optimizers to use (configure_optimizers)
- What data to use (train_dataloader, val_dataloader, test_dataloader)

.. figure:: /_images/lightning_module/pt_to_pl.png
    :alt: Convert from PyTorch to Lightning

Most methods are optional. Here's a minimal example.

Notice a few things.

1. It's the SAME code.

2. The PyTorch code IS NOT abstracted - just organized.

3. All the other code that didn't go in the LightningModule has been automated
   for you by the trainer.

.. code-block:: python

    net = Net()
    trainer = Trainer()
    trainer.fit(net)

4. There are no .cuda() or .to() calls... Lightning does these for you.

.. code-block:: python

    # don't do this in Lightning
    x = torch.Tensor(2, 3)
    x = x.cuda()
    x = x.to(device)

    # do this instead
    x = x  # leave it alone!

    # or to init a new tensor
    new_x = torch.Tensor(2, 3)
    new_x = new_x.type_as(x)

5. There are no samplers for distributed, Lightning also does this for you.

.. code-block:: python

    # don't do this in Lightning...
    data = MNIST(...)
    sampler = DistributedSampler(data)
    DataLoader(data, sampler=sampler)

    # do this instead
    data = MNIST(...)
    DataLoader(data)

6. A LightningModule is a torch.nn.Module but with added functionality. Use it as such!

.. code-block:: python

    net = Net.load_from_checkpoint(PATH)
    net.freeze()
    out = net(x)

Thus, to use Lightning, you just need to organize your code, which takes about 30 minutes
(and let's be real, you probably should do it anyhow).

------------

Minimal Example
---------------

Here are the only required methods.

.. code-block:: python
@@ -37,13 +94,13 @@ Most methods are optional. Here's a minimal example.

        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                          transform=transforms.ToTensor()), batch_size=32)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

Which you can train by doing:

.. code-block:: python
@@ -53,7 +110,35 @@ Which you can train by doing:

    trainer.fit(model)

If you wanted to add a validation loop
----------

Training loop structure
-----------------------

The general pattern is that each loop (training, validation, test loop)
has 2 methods:

- ``___step``
- ``___epoch_end``

To show how Lightning calls these, let's use the validation loop as an example:

.. code-block:: python

    val_outs = []
    for val_batch in val_data:
        # do something with each batch
        out = validation_step(val_batch)
        val_outs.append(out)

    # do something with the outputs for all batches
    # like calculate validation set accuracy or loss
    validation_epoch_end(val_outs)

Add validation loop
^^^^^^^^^^^^^^^^^^^

Thus, if you wanted to add a validation loop, you would add this to your LightningModule:

.. code-block:: python
@@ -63,43 +148,153 @@ If you wanted to add a validation loop

        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
    def validation_epoch_end(self, outputs):
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': val_loss_mean}

    def val_dataloader(self):
        # can also return a list of val dataloaders
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                          transform=transforms.ToTensor()), batch_size=32)
        return DataLoader(...)

Or add a test loop
Add test loop
^^^^^^^^^^^^^

.. code_block:: python
.. code-block:: python

    class CoolModel(pl.LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.forward(x)
            return {'test_loss': F.cross_entropy(y_hat, y)}

        def test_end(self, outputs):
        def test_epoch_end(self, outputs):
            test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean()
            return {'test_loss': test_loss_mean}

        def test_dataloader(self):
            # OPTIONAL
            # can also return a list of test dataloaders
            return DataLoader(MNIST(os.getcwd(), train=False, download=True,
                              transform=transforms.ToTensor()), batch_size=32)
            return DataLoader(...)

However, the test loop won't ever be called automatically, to make sure you
don't run your test data by accident. Instead, you have to explicitly call:

.. code-block:: python

    # call after training
    trainer = Trainer()
    trainer.fit(model)
    trainer.test()

    # or call with pretrained model
    model = MyLightningModule.load_from_checkpoint(PATH)
    trainer = Trainer()
    trainer.test(model)

-------------

Training_step_end method
------------------------
When using DataParallel (dp) or ddp2, the training_step
will be operating on a portion of the batch. This is normally ok but in special
cases like calculating NCE loss using negative samples, we might want to
perform a softmax across all samples in the batch.

For these types of situations, each loop has an additional ``___step_end`` method
which allows you to operate on the pieces of the batch:

.. code-block:: python

    training_outs = []
    for train_batch in train_data:
        # dp, ddp2 splits the batch
        sub_batches = split_batches_for_dp(train_batch)

        # run training_step on each piece of the batch
        batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]

        # do softmax with all pieces
        out = training_step_end(batch_parts_outputs)
        training_outs.append(out)

    # do something with the outputs for all batches
    # like calculate the training set accuracy or loss
    training_epoch_end(training_outs)

-------------

Remove cuda calls
-----------------
In a LightningModule, all calls to ``.cuda()``
and ``.to(device)`` should be removed. Lightning will do these
automatically. This will allow your code to work on CPUs, TPUs and GPUs.

When you init a new tensor in your code, just use ``type_as``:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch

        # put the z on the appropriate gpu or tpu core
        z = sample_noise()
        z = z.type_as(x)

-------------

Data preparation
----------------
Data preparation in PyTorch follows 5 steps:

1. Download
2. Clean and (maybe) save to disk
3. Load inside dataset
4. Apply transforms (rotate, tokenize, etc...)
5. Wrap inside a dataloader

When working in distributed settings, steps 1 and 2 have to be done
from a single GPU, otherwise you will overwrite these files from
every GPU. The LightningModule has the ``prepare_data`` method to
allow for this:

.. code-block:: python

    def prepare_data(self):
        # do stuff that writes to disk or should be done once
        # this will only happen from the master GPU or TPU core

.. note:: ``prepare_data`` is called once.
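
For example, a download-only sketch reusing the MNIST dataset from the minimal example above (illustrative only, not a required API):

.. code-block:: python

    def prepare_data(self):
        # download once; transforms and DataLoader wrapping happen elsewhere
        MNIST(os.getcwd(), train=True, download=True)
        MNIST(os.getcwd(), train=False, download=True)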

Lifecycle
---------
The methods in the LightningModule are called in this order:

1. ``__init__``
2. ``prepare_data``
3. ``configure_optimizers``
4. ``train_dataloader``

If you define a validation loop then

5. ``val_dataloader``

And if you define a test loop:

6. ``test_dataloader``

.. note:: ``test_dataloader`` is only called with ``.test()``

In every epoch, the loop methods are called in this frequency:

1. ``validation_step`` called every batch
2. ``validation_epoch_end`` called every epoch
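
To make the ordering concrete, here is a minimal sketch (the class, layer size and print statements are illustrative only; ``forward`` and ``training_step`` are omitted for brevity):

.. code-block:: python

    import os

    import pytorch_lightning as pl
    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST

    class HookOrderDemo(pl.LightningModule):

        def __init__(self):
            super().__init__()
            print('1. __init__')
            self.l1 = torch.nn.Linear(28 * 28, 10)

        def prepare_data(self):
            # called once, after __init__
            print('2. prepare_data')
            MNIST(os.getcwd(), train=True, download=True)

        def configure_optimizers(self):
            print('3. configure_optimizers')
            return torch.optim.Adam(self.parameters(), lr=0.02)

        def train_dataloader(self):
            print('4. train_dataloader')
            return DataLoader(MNIST(os.getcwd(), train=True,
                              transform=transforms.ToTensor()), batch_size=32)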

Live demo
---------
Check out this
`COLAB <https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg>`_
for a live demo.

.. note:: Remove all .cuda() or .to() calls from LightningModules. See:
    `the multi-gpu training guide for details <multi_gpu.rst>`_.

"""

from .decorators import data_loader
@@ -222,26 +222,40 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

        """

    def training_end(self, *args, **kwargs):
        """return loss, dict with metrics for tqdm
        """
        .. warning:: Deprecated in v0.7.0. use training_step_end instead
        """

        :param outputs: What you return in `training_step`.
    def training_step_end(self, *args, **kwargs):
        """
        Use this when training with dp or ddp2 because training_step will operate
        on only part of the batch. However, this is still optional
        and only needed for things like softmax or NCE loss.

        .. note:: If you later switch to ddp or some other mode, this will still be called
            so that you don't have to change your code

        .. code-block:: python

            # pseudocode
            sub_batches = split_batches_for_dp(batch)
            batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
            training_step_end(batch_parts_outputs)

        :param batch_parts_outputs: What you return in `training_step` for each batch part.
        :return dict: dictionary with loss key and optional log, progress keys:
            - loss -> tensor scalar [REQUIRED]
            - progress_bar -> Dict for progress bar display. Must have only tensors
            - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

        In certain cases (dp, ddp2), you might want to use all outputs of every process to do something.
        For instance, if using negative samples, you could run a batch via dp and use ALL the outputs
        for a single softmax across the full batch (ie: the denominator would use the full batch).

        In this case you should define training_end to perform those calculations.
        In this case you should define training_step_end to perform those calculations.

        Example
        -------

        .. code-block:: python

            # WITHOUT training_end
            # WITHOUT training_step_end
            # if used in DP or DDP2, this batch is 1/num_gpus large
            def training_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
@@ -253,7 +267,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

                return {'loss': loss}

            # --------------
            # with training_end to do softmax over the full batch
            # with training_step_end to do softmax over the full batch
            def training_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch
@@ -261,48 +275,32 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

                out = self.forward(x)
                return {'out': out}

            def training_end(self, outputs):
            def training_step_end(self, outputs):
                # this out is now the full size of the batch
                out = outputs['out']

                # this softmax now uses the full batch size
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

        .. note:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.

        If you define multiple optimizers, this step will also be called with an additional `optimizer_idx` param.

        .. code-block:: python

            # Multiple optimizers (ie: GANs)
            def training_step(self, batch, batch_idx, optimizer_idx):
                if optimizer_idx == 0:
                    # do training_step with encoder
                if optimizer_idx == 1:
                    # do training_step with decoder

        If you add truncated back propagation through time you will also get an additional argument
        with the hidden states of the previous step.

        .. code-block:: python

            # Truncated back-propagation through time
            def training_step(self, batch, batch_idx, hiddens):
                # hiddens are the hiddens from the previous truncated backprop step

        You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to
        break out of the current training epoch early.
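
        For example (the NaN check and the ``compute_loss`` helper are illustrative, not part of the API):

        .. code-block:: python

            def training_step(self, batch, batch_idx):
                loss = self.compute_loss(batch)  # hypothetical helper

                # break out of the current epoch early
                if torch.isnan(loss):
                    return -1

                return {'loss': loss}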
        """

    def validation_step(self, *args, **kwargs):
        r"""

        This is the validation loop. It is called for each batch of the validation set.
        Whatever is returned from here will be passed in as a list on validation_end.
        Operate on a single batch of data from the validation set.
        In this step you'd normally generate examples or calculate anything of interest such as accuracy.

        .. code-block:: python

            # the pseudocode for these calls

            val_outs = []
            for val_batch in val_data:
                out = validation_step(val_batch)
                val_outs.append(out)
            validation_epoch_end(val_outs)

        Args:
            batch (torch.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your dataloader.
                A tensor, tuple or list
@@ -311,7 +309,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

                val datasets used)

        Return:
            Dict or OrderedDict - passed to the validation_end step
            Dict or OrderedDict - passed to the validation_epoch_end

        .. code-block:: python
@@ -319,7 +317,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

            def validation_step(self, batch, batch_idx)

            # if you have multiple val dataloaders:
            def validation_step(self, batch, batch_idx, dataloader_idxdx)
            def validation_step(self, batch, batch_idx, dataloader_idx)

        Example
        -------
@@ -368,12 +366,175 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

        have been disabled. At the end of validation, model goes back to training mode and gradients are enabled.
        """

    def validation_step_end(self, *args, **kwargs):
        """
        Use this when validating with dp or ddp2 because validation_step will operate
        on only part of the batch. However, this is still optional
        and only needed for things like softmax or NCE loss.

        .. note:: If you later switch to ddp or some other mode, this will still be called
            so that you don't have to change your code

        .. code-block:: python

            # pseudocode
            sub_batches = split_batches_for_dp(batch)
            batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
            validation_step_end(batch_parts_outputs)

        :param batch_parts_outputs: What you return in `validation_step` for each batch part.
        :return dict: dictionary with loss key and optional log, progress keys:
            - loss -> tensor scalar [REQUIRED]
            - progress_bar -> Dict for progress bar display. Must have only tensors
            - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

        In this case you should define validation_step_end to perform those calculations.

        Example
        -------

        .. code-block:: python

            # WITHOUT validation_step_end
            # if used in DP or DDP2, this batch is 1/num_gpus large
            def validation_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

            # --------------
            # with validation_step_end to do softmax over the full batch
            def validation_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                return {'out': out}

            def validation_step_end(self, outputs):
                # this out is now the full size of the batch
                out = outputs['out']

                # this softmax now uses the full batch size
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

        .. note:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
        """

    def validation_end(self, outputs):
        """
        .. warning:: Deprecated in v0.7.0. use validation_epoch_end instead.
            Will be removed in 1.0.0
        :param outputs:
        :return:
        """

    def validation_epoch_end(self, outputs):
        """
        Called at end of validation epoch with the output of all validation_steps

        .. code-block:: python

            # the pseudocode for these calls

            val_outs = []
            for val_batch in val_data:
                out = validation_step(val_batch)
                val_outs.append(out)
            validation_epoch_end(val_outs)

        Args:

            outputs (list): List of outputs you defined in validation_step, or if there are multiple dataloaders,
                a list containing a list of outputs for each dataloader

        Return:
            Dict or OrderedDict (dict): Dict has the following optional keys:
                progress_bar -> Dict for progress bar display. Must have only tensors
                log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

        .. note:: If you didn't define a validation_step, this won't be called.

        - The outputs here are strictly for logging or progress bar.
        - If you don't need to display anything, don't return anything.
        - If you want to manually set current step, you can specify it with the 'step' key in the 'log' Dict.

        Example
        -------

        With a single dataloader

        .. code-block:: python

            def validation_epoch_end(self, outputs):
                val_acc_mean = 0
                for output in outputs:
                    val_acc_mean += output['val_acc']

                val_acc_mean /= len(outputs)
                tqdm_dict = {'val_acc': val_acc_mean.item()}

                # show val_acc in progress bar and log it
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'val_acc': val_acc_mean.item()}
                }
                return results

        With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
        one entry per dataloader, while the inner list contains the individual outputs of
        each validation step for that dataloader.

        .. code-block:: python

            def validation_epoch_end(self, outputs):
                val_acc_mean = 0
                i = 0
                for dataloader_outputs in outputs:
                    for output in dataloader_outputs:
                        val_acc_mean += output['val_acc']
                        i += 1

                val_acc_mean /= i
                tqdm_dict = {'val_acc': val_acc_mean.item()}

                # show val_acc in progress bar and log it
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'val_acc': val_acc_mean.item(), 'step': self.current_epoch}
                }
                return results
        """

    def test_step(self, *args, **kwargs):
        """return whatever outputs will need to be aggregated in test_end
        :param batch: The output of your dataloader. A tensor, tuple or list
        :param int batch_idx: Integer displaying which batch this is
        :param int dataloader_idx: Integer displaying which dataloader this is (only if multiple test datasets used)
        :return dict: Dict or OrderedDict with metrics to display in progress bar. All keys must be tensors.
        r"""
        Operate on a single batch of data from the test set.
        In this step you'd normally generate examples or calculate anything of interest such as accuracy.

        .. code-block:: python

            # the pseudocode for these calls

            test_outs = []
            for test_batch in test_data:
                out = test_step(test_batch)
                test_outs.append(out)
            test_epoch_end(test_outs)

        Args:
            batch (torch.Tensor | (Tensor, Tensor) | [Tensor, Tensor]): The output of your dataloader.
                A tensor, tuple or list
            batch_idx (int): The index of this batch
            dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple
                test datasets used)

        Return:
            Dict or OrderedDict - passed to the test_epoch_end

        .. code-block:: python
@@ -381,21 +542,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

            def test_step(self, batch, batch_idx)

            # if you have multiple test dataloaders:
            def test_step(self, batch, batch_idx, dataloader_idxdx)

            **OPTIONAL**
            If you don't need to test you don't need to implement this method.
            In this step you'd normally generate examples or
            calculate anything of interest such as accuracy.

            When the validation_step is called, the model has been put in eval mode
            and PyTorch gradients have been disabled.
            At the end of validation, model goes back to training mode and gradients are enabled.

            The dict you return here will be available in the `test_end` method.

            This function is used when you execute `trainer.test()`.
            def test_step(self, batch, batch_idx, dataloader_idx)

        Example
        -------
@@ -410,50 +557,136 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

                out = self.forward(x)
                loss = self.loss(out, y)

                # log 6 example images
                # or generated text... or whatever
                sample_imgs = x[:6]
                grid = torchvision.utils.make_grid(sample_imgs)
                self.logger.experiment.add_image('example_images', grid, 0)

                # calculate acc
                labels_hat = torch.argmax(out, dim=1)
                test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
                val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

                # all optional...
                # return whatever you need for the collation function test_end
                # return whatever you need for the collation function validation_end
                output = OrderedDict({
                    'test_loss': loss_test,
                    'test_acc': torch.tensor(test_acc),  # everything must be a tensor
                    'val_loss': loss_val,
                    'val_acc': torch.tensor(val_acc),  # everything must be a tensor
                })

                # return an optional dict
                return output

        If you pass in multiple test datasets, `test_step` will have an additional argument.
        If you pass in multiple validation datasets, validation_step will have an additional argument.

        .. code-block:: python

            # CASE 2: multiple test datasets
            # CASE 2: multiple validation datasets
            def test_step(self, batch, batch_idx, dataset_idx):
                # dataset_idx tells you which dataset this is.

        .. note:: If you don't need to validate you don't need to implement this method.

        The `dataset_idx` corresponds to the order of datasets returned in `test_dataloader`.
        .. note:: When the validation_step is called, the model has been put in eval mode and PyTorch gradients
            have been disabled. At the end of validation, model goes back to training mode and gradients are enabled.
        """

    def validation_end(self, outputs):
        """Outputs has the appended output after each validation step.
    def test_step_end(self, *args, **kwargs):
        """
        Use this when testing with dp or ddp2 because test_step will operate
        on only part of the batch. However, this is still optional
        and only needed for things like softmax or NCE loss.

        :param outputs: List of outputs you defined in validation_step, or if there are multiple dataloaders,
            a list containing a list of outputs for each dataloader
        :return dict: Dictionary or OrderedDict with optional:
        .. note:: If you later switch to ddp or some other mode, this will still be called
            so that you don't have to change your code

        .. code-block:: python

            # pseudocode
            sub_batches = split_batches_for_dp(batch)
            batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
            test_step_end(batch_parts_outputs)

        :param batch_parts_outputs: What you return in `test_step` for each batch part.
        :return dict: dictionary with loss key and optional log, progress keys:
            - loss -> tensor scalar [REQUIRED]
            - progress_bar -> Dict for progress bar display. Must have only tensors
            - log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

        In this case you should define test_step_end to perform those calculations.

        Example
        -------

        .. code-block:: python

            # WITHOUT test_step_end
            # if used in DP or DDP2, this batch is 1/num_gpus large
            def test_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

            # --------------
            # with test_step_end to do softmax over the full batch
            def test_step(self, batch, batch_idx):
                # batch is 1/num_gpus big
                x, y = batch

                out = self.forward(x)
                return {'out': out}

            def test_step_end(self, outputs):
                # this out is now the full size of the batch
                out = outputs['out']

                # this softmax now uses the full batch size
                loss = self.softmax(out)
                loss = nce_loss(loss)
                return {'loss': loss}

        .. note:: see the `multi-gpu guide for more details <multi_gpu.rst#caveats>`_.
        """

    def test_end(self, outputs):
        """
        .. warning:: Deprecated in v0.7.0. use test_epoch_end instead. Will be removed in 1.0.0
        :param outputs:
        :return:
        """

    def test_epoch_end(self, outputs):
        """
        Called at end of test epoch with the output of all test_steps

        .. code-block:: python

            # the pseudocode for these calls

            test_outs = []
            for test_batch in test_data:
                out = test_step(test_batch)
                test_outs.append(out)
            test_epoch_end(test_outs)

        Args:

            outputs (list): List of outputs you defined in test_step, or if there are multiple dataloaders,
                a list containing a list of outputs for each dataloader

        Return:
            Dict or OrderedDict (dict): Dict has the following optional keys:
                progress_bar -> Dict for progress bar display. Must have only tensors
                log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)

        If you didn't define a validation_step, this won't be called.
        Called at the end of the validation loop with the outputs of validation_step.
        .. note:: If you didn't define a test_step, this won't be called.

        The outputs here are strictly for the progress bar.
        If you don't need to display anything, don't return anything.
        Any keys present in 'log', 'progress_bar' or the rest of the dictionary
        are available for callbacks to access. If you want to manually set current step, you can specify it with
        the 'step' key in the 'log' Dict.
        - The outputs here are strictly for logging or progress bar.
        - If you don't need to display anything, don't return anything.
        - If you want to manually set current step, you can specify it with the 'step' key in the 'log' Dict.

        Example
        -------
@@ -462,120 +695,48 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

        .. code-block:: python

            def validation_end(self, outputs):
                val_loss_mean = 0
                val_acc_mean = 0
                for output in outputs:
                    val_loss_mean += output['val_loss']
                    val_acc_mean += output['val_acc']

                val_loss_mean /= len(outputs)
                val_acc_mean /= len(outputs)
                tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}

                # show val_loss and val_acc in progress bar but only log val_loss
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'val_loss': val_loss_mean.item()}
                }
                return results

        With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
        one entry per dataloader, while the inner list contains the individual outputs of
        each validation step for that dataloader.

        .. code-block:: python

            def validation_end(self, outputs):
                val_loss_mean = 0
                val_acc_mean = 0
                i = 0
                for dataloader_outputs in outputs:
                    for output in dataloader_outputs:
                        val_loss_mean += output['val_loss']
                        val_acc_mean += output['val_acc']
                        i += 1

                val_loss_mean /= i
                val_acc_mean /= i
                tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}

                # show val_loss and val_acc in progress bar but only log val_loss
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'val_loss': val_loss_mean.item(), 'step': self.current_epoch}
                }
                return results

        """

    def test_end(self, outputs):
        """Outputs has the appended output after each test step.

        :param outputs: List of outputs you defined in test_step, or if there are multiple dataloaders,
            a list containing a list of outputs for each dataloader
        :return dict: Dict of OrderedDict with metrics to display in progress bar

        If you didn't define a test_step, this won't be called.
        Called at the end of the test step with the output of each test_step.
        The outputs here are strictly for the progress bar.
        If you don't need to display anything, don't return anything.

        Example
        -------

        .. code-block:: python

            def test_end(self, outputs):
                test_loss_mean = 0
            def test_epoch_end(self, outputs):
                test_acc_mean = 0
                for output in outputs:
                    test_loss_mean += output['test_loss']
                    test_acc_mean += output['test_acc']

                test_loss_mean /= len(outputs)
                test_acc_mean /= len(outputs)
                tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
                tqdm_dict = {'test_acc': test_acc_mean.item()}

                # show test_acc in progress bar and log it
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'test_loss': val_loss_mean.item()}
                    'log': {'test_acc': test_acc_mean.item()}
                }
                return results

        With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
        one entry per dataloader, while the inner list contains the individual outputs of
        each validation step for that dataloader.
        each test step for that dataloader.

        .. code-block:: python

            def test_end(self, outputs):
                test_loss_mean = 0
            def test_epoch_end(self, outputs):
                test_acc_mean = 0
                i = 0
                for dataloader_outputs in outputs:
                    for output in dataloader_outputs:
                        test_loss_mean += output['test_loss']
                        test_acc_mean += output['test_acc']
                        i += 1

                test_loss_mean /= i
                test_acc_mean /= i
                tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
                tqdm_dict = {'test_acc': test_acc_mean.item()}

                # show test_acc in progress bar and log it
                results = {
                    'progress_bar': tqdm_dict,
                    'log': {'test_loss': val_loss_mean.item()}
                    'log': {'test_acc': test_acc_mean.item(), 'step': self.current_epoch}
                }
                return results

        """

    def configure_ddp(self, model, device_ids):
        r"""

        Override to init DDP in your own way or with your own wrapper.
        The only requirements are that:
@@ -980,14 +1141,13 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

        return None

    @data_loader
    def tng_dataloader(self):  # todo: remove in v0.8.0
    def tng_dataloader(self):  # todo: remove in v1.0.0
        """Implement a PyTorch DataLoader.

        .. warning:: Deprecated in v0.5.0. use train_dataloader instead.
        .. warning:: Deprecated in v0.5.0. use train_dataloader instead. Will be removed in 1.0.0
        """
        output = self.train_dataloader()
        warnings.warn("`tng_dataloader` has been renamed to `train_dataloader` since v0.5.0."
                      " and this method will be removed in v0.8.0", DeprecationWarning)
                      " and this method will be removed in v1.0.0", DeprecationWarning)
        return output

    def test_dataloader(self):
@@ -16,7 +16,7 @@ PyTorch Lightning supports profiling standard actions in the training loop out o

- on_after_backward
- optimizer_step
- on_batch_end
- training_end
- training_step_end
- on_training_end

Enable simple profiling
@@ -1,17 +1,24 @@

"""
Once you've organized your PyTorch code into a LightningModule,
the Trainer automates everything else.

The trainer de-couples the engineering code (16-bit, early stopping, GPU distribution, etc...) from the
science code (GAN, BERT, your project, etc...). It uses many assumptions which are best practices in
AI research today.
.. figure:: /_images/lightning_module/pt_trainer.png
    :alt: Convert from PyTorch to Lightning

The trainer automates all parts of training except:
This abstraction achieves the following:

- what happens in training, test, val loop
- where the data come from
- which optimizers to use
- how to do the computations
1. You maintain control over all aspects via PyTorch
   code without an added abstraction.

The Trainer delegates those calls to your LightningModule which defines how to do those parts.
2. The trainer uses best practices embedded by contributors and users
   from top AI labs such as Facebook AI Research, NYU, MIT, Stanford, etc...

3. The trainer allows overriding any key part that you don't want automated.

-----------

Basic use
---------

This is the basic use of the trainer:
@@ -23,6 +30,763 @@ This is the basic use of the trainer:

    trainer = Trainer()
    trainer.fit(model)

--------

Best Practices
--------------
For cluster computing, it's recommended you structure your
main.py file this way:

.. code-block:: python

    from argparse import ArgumentParser

    def main(hparams):
        model = LightningModule()
        trainer = Trainer(gpus=hparams.gpus)
        trainer.fit(model)

    if __name__ == '__main__':
        parser = ArgumentParser()
        parser.add_argument('--gpus', default=None)
        args = parser.parse_args()

        main(args)

So you can run it like so:

.. code-block:: bash

    $ python main.py --gpus 2

------------

Testing
-------
Once you're done training, feel free to run the test set!
(Only right before publishing your paper or pushing to production)

.. code-block:: python

    trainer.test()

------------

Deployment / prediction
-----------------------
You just trained a LightningModule which is also just a torch.nn.Module.
Use it to do whatever!

.. code-block:: python

    # load model
    pretrained_model = LightningModule.load_from_checkpoint(PATH)
    pretrained_model.freeze()

    # use it for finetuning
    def forward(self, x):
        features = pretrained_model(x)
        classes = classifier(features)
        return classes

    # or for prediction
    out = pretrained_model(x)
    api_write({'response': out})

-------

Trainer flags
-------------

logger
^^^^^^

Logger (or iterable collection of loggers) for experiment tracking.

.. code-block:: python

    Trainer(logger=logger)

Example::

    from pytorch_lightning.loggers import TensorBoardLogger

    # default logger used by trainer
    logger = TensorBoardLogger(
        save_dir=os.getcwd(),
        version=self.slurm_job_id,
        name='lightning_logs'
    )

checkpoint_callback
^^^^^^^^^^^^^^^^^^^
Callback for checkpointing.

.. code-block:: python

    trainer = Trainer(checkpoint_callback=checkpoint_callback)

Example::

    from pytorch_lightning.callbacks import ModelCheckpoint

    # default used by the Trainer
    checkpoint_callback = ModelCheckpoint(
        filepath=os.getcwd(),
        save_best_only=True,
        verbose=True,
        monitor='val_loss',
        mode='min',
        prefix=''
    )

early_stop_callback
^^^^^^^^^^^^^^^^^^^

Callback for early stopping.
early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`)

- If set to ``True``, then a default callback monitoring ``'val_loss'`` is created.
- Will raise an error if ``'val_loss'`` is not found.
- If set to ``False``, then early stopping will be disabled.
- If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
- If ``'val_loss'`` is not found will work as if early stopping is disabled.
- Default: ``None``.

.. code-block:: python

    trainer = Trainer(early_stop_callback=early_stop_callback)

Example::

    from pytorch_lightning.callbacks import EarlyStopping

    # default used by the Trainer
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=3,
        strict=False,
        verbose=False,
        mode='min'
    )
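
The ``True``/``False``/``None`` settings described in the list above can also be passed directly (a short sketch of those semantics):

Example::

    # None (the default): create the default callback monitoring 'val_loss'
    trainer = Trainer(early_stop_callback=None)

    # disable early stopping
    trainer = Trainer(early_stop_callback=False)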

callbacks
^^^^^^^^^

callbacks: Add a list of callbacks.

.. code-block:: python

    # a list of callbacks
    callbacks = [PrintCallback()]
    trainer = Trainer(callbacks=callbacks)

Example::

    from pytorch_lightning.callbacks import Callback

    class PrintCallback(Callback):
        def on_train_start(self):
            print("Training is started!")
        def on_train_end(self):
            print(f"Training is done. The logs are: {self.trainer.logs}")

default_save_path
^^^^^^^^^^^^^^^^^

Default path for logs and weights when no logger/ckpt_callback passed

Example::

    # default used by the Trainer
    trainer = Trainer(default_save_path=os.getcwd())

gradient_clip_val
^^^^^^^^^^^^^^^^^
Gradient clipping value

- 0 means don't clip.

Example::

    # default used by the Trainer
    trainer = Trainer(gradient_clip_val=0.0)
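
To enable clipping, pass a positive value (0.5 here is purely illustrative):

Example::

    # clip gradients whose norm exceeds 0.5
    trainer = Trainer(gradient_clip_val=0.5)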

gradient_clip
^^^^^^^^^^^^^

.. warning:: .. deprecated:: 0.5.0
    Use `gradient_clip_val` instead. Will be removed in 0.8.0.

process_position
^^^^^^^^^^^^^^^^
Orders the tqdm bar when running multiple models on the same machine.

Example::

    # default used by the Trainer
    trainer = Trainer(process_position=0)

num_nodes
^^^^^^^^^

Number of GPU nodes for distributed training.

Example::

    # default used by the Trainer
    trainer = Trainer(num_nodes=1)

    # to train on 8 nodes
    trainer = Trainer(num_nodes=8)

nb_gpu_nodes
^^^^^^^^^^^^

.. warning:: .. deprecated:: 0.5.0
    Use `num_nodes` instead. Will be removed in 0.8.0.

gpus
^^^^

- Number of GPUs to train on
- or which GPUs to train on
- can handle strings

Example::

    # default used by the Trainer (ie: train on CPU)
    trainer = Trainer(gpus=None)

    # int: train on 2 gpus
    trainer = Trainer(gpus=2)

    # list: train on GPUs 1, 4 (by bus ordering)
    trainer = Trainer(gpus=[1, 4])
    trainer = Trainer(gpus='1, 4')  # equivalent

    # -1: train on all gpus
    trainer = Trainer(gpus=-1)
    trainer = Trainer(gpus='-1')  # equivalent

    # combine with num_nodes to train on multiple GPUs across nodes
    trainer = Trainer(gpus=2, num_nodes=4)  # uses 8 gpus in total

num_tpu_cores
^^^^^^^^^^^^^
How many TPU cores to train on (1 or 8).

A single TPU v2 or v3 has 8 cores. A TPU pod has
up to 2048 cores. A slice of a POD means you get as many cores
as you request.

You MUST use a DistributedSampler with your dataloader for this
to work. Your effective batch size is batch_size * total tpu cores.

This parameter can be either 1 or 8.

Example::

    # your_trainer_file.py

    # default used by the Trainer (ie: train on CPU)
    trainer = Trainer(num_tpu_cores=None)

    # int: train on a single core
    trainer = Trainer(num_tpu_cores=1)

    # int: train on all 8 cores
    trainer = Trainer(num_tpu_cores=8)

    # for 8+ cores must submit via xla script with
    # a max of 8 cores specified. The XLA script
    # will duplicate the script onto each TPU in the POD
    trainer = Trainer(num_tpu_cores=8)

    # -1: train on all available TPUs
    trainer = Trainer(num_tpu_cores=-1)

To train on more than 8 cores (ie: a POD),
submit this script using the xla_dist script.

Example::

    $ python -m torch_xla.distributed.xla_dist
    --tpu=$TPU_POD_NAME
    --conda-env=torch-xla-nightly
    --env=XLA_USE_BF16=1
    -- python your_trainer_file.py

log_gpu_memory
^^^^^^^^^^^^^^
Options:

- None
- 'min_max'
- 'all'

.. note:: Might slow performance because it uses the output of nvidia-smi.

Example::

    # default used by the Trainer
    trainer = Trainer(log_gpu_memory=None)

    # log all the GPUs (on master node only)
    trainer = Trainer(log_gpu_memory='all')

    # log only the min and max memory on the master node
    trainer = Trainer(log_gpu_memory='min_max')

show_progress_bar
^^^^^^^^^^^^^^^^^

If True, shows the tqdm progress bar.

Example::

    # default used by the Trainer
    trainer = Trainer(show_progress_bar=True)

progress_bar_refresh_rate
^^^^^^^^^^^^^^^^^^^^^^^^^

How often to refresh the progress bar (in steps).
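
Example (the value is illustrative; check the Trainer signature for the actual default)::

    # refresh the progress bar every 50 steps
    trainer = Trainer(progress_bar_refresh_rate=50)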

overfit_pct
^^^^^^^^^^^
Uses this fraction of the data from all datasets (train, val, test).

Example::

    # default used by the Trainer
    trainer = Trainer(overfit_pct=0.0)

    # use only 1% of the train, test, val datasets
    trainer = Trainer(overfit_pct=0.01)

track_grad_norm
^^^^^^^^^^^^^^^

- no tracking (-1)
- Otherwise tracks that norm (2 for 2-norm)

Example::

    # default used by the Trainer
    trainer = Trainer(track_grad_norm=-1)

    # track the 2-norm
    trainer = Trainer(track_grad_norm=2)

check_val_every_n_epoch
^^^^^^^^^^^^^^^^^^^^^^^

Check val every n train epochs.

Example::

    # default used by the Trainer
    trainer = Trainer(check_val_every_n_epoch=1)

    # run val loop every 10 training epochs
    trainer = Trainer(check_val_every_n_epoch=10)

fast_dev_run
^^^^^^^^^^^^

Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).

Under the hood the pseudocode looks like this:

.. code-block:: python

    # loading
    __init__()
    prepare_data

    # test training step
    training_batch = next(train_dataloader)
    training_step(training_batch)

    # test val step
    val_batch = next(val_dataloader)
    out = validation_step(val_batch)
    validation_epoch_end([out])

Example::

    # default used by the Trainer
    trainer = Trainer(fast_dev_run=False)

    # runs 1 train, val, test batch and program ends
    trainer = Trainer(fast_dev_run=True)

accumulate_grad_batches
^^^^^^^^^^^^^^^^^^^^^^^
Accumulates grads every k batches or as set up in the dict.

Example::

    # default used by the Trainer (no accumulation)
    trainer = Trainer(accumulate_grad_batches=1)

    # accumulate every 4 batches (effective batch size is batch*4)
    trainer = Trainer(accumulate_grad_batches=4)

    # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
    trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})

max_epochs
^^^^^^^^^^
Stop training once this number of epochs is reached.

Example::

    # default used by the Trainer
    trainer = Trainer(max_epochs=1000)

max_nb_epochs
^^^^^^^^^^^^^

.. warning:: .. deprecated:: 0.5.0
    Use `max_epochs` instead. Will be removed in 0.8.0.

min_epochs
^^^^^^^^^^
Force training for at least this many epochs.

Example::

    # default used by the Trainer
    trainer = Trainer(min_epochs=1)

min_nb_epochs
^^^^^^^^^^^^^

.. warning:: .. deprecated:: 0.5.0
    Use `min_epochs` instead. Will be removed in 0.8.0.

max_steps
^^^^^^^^^
Stop training after this number of steps. Disabled by default (None).
Training will stop if max_steps or max_epochs is reached (whichever comes first).

Example::

    # Stop after 100 steps
    trainer = Trainer(max_steps=100)
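
Combined with ``max_epochs``, whichever limit is reached first stops training:

Example::

    # stop after 100 steps or 5 epochs, whichever comes first
    trainer = Trainer(max_steps=100, max_epochs=5)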

min_steps
^^^^^^^^^

Force training for at least this number of steps. Disabled by default (None).
Trainer will train the model for at least min_steps or min_epochs (whichever comes last).

Example::

    # Run at least for 100 steps (disable min_epochs)
    trainer = Trainer(min_steps=100, min_epochs=0)

train_percent_check
^^^^^^^^^^^^^^^^^^^

How much of the training dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.

Example::

    # default used by the Trainer
    trainer = Trainer(train_percent_check=1.0)

    # run through only 25% of the training set each epoch
    trainer = Trainer(train_percent_check=0.25)

val_percent_check
^^^^^^^^^^^^^^^^^

How much of the validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.

Example::

    # default used by the Trainer
    trainer = Trainer(val_percent_check=1.0)

    # run through only 25% of the validation set each epoch
    trainer = Trainer(val_percent_check=0.25)

test_percent_check
^^^^^^^^^^^^^^^^^^

How much of the test dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.

Example::

    # default used by the Trainer
    trainer = Trainer(test_percent_check=1.0)

    # run through only 25% of the test set each epoch
    trainer = Trainer(test_percent_check=0.25)

val_check_interval
^^^^^^^^^^^^^^^^^^

How often within one training epoch to check the validation set.
Can specify as float or int.

- use (float) to check within a training epoch
- use (int) to check every n steps (batches)

Example::

    # default used by the Trainer
    trainer = Trainer(val_check_interval=1.0)

    # check validation set 4 times during a training epoch
    trainer = Trainer(val_check_interval=0.25)

    # check validation set every 1000 training batches
    # use this when using iterableDataset and your dataset has no length
    # (ie: production cases with streaming data)
    trainer = Trainer(val_check_interval=1000)

log_save_interval
^^^^^^^^^^^^^^^^^

Writes logs to disk this often.

Example::

    # default used by the Trainer
    trainer = Trainer(log_save_interval=100)

row_log_interval
^^^^^^^^^^^^^^^^

How often to add logging rows (does not write to disk).

Example::

    # default used by the Trainer
    trainer = Trainer(row_log_interval=10)

add_row_log_interval
^^^^^^^^^^^^^^^^^^^^

.. warning:: .. deprecated:: 0.5.0
    Use `row_log_interval` instead. Will be removed in 0.8.0.

distributed_backend
^^^^^^^^^^^^^^^^^^^
The distributed backend to use.

- ('dp') is DataParallel (split batch among GPUs of same machine)
- ('ddp') is DistributedDataParallel (each gpu on each node trains, and syncs grads)
- ('ddp2') dp on node, ddp across nodes

Example::

    # default used by the Trainer
    trainer = Trainer(distributed_backend=None)

    # dp = DataParallel (split a batch onto k gpus on same machine).
    trainer = Trainer(gpus=2, distributed_backend='dp')

    # ddp = DistributedDataParallel
    # Each gpu trains by itself on a subset of the data.
    # Gradients sync across all gpus and all machines.
    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')

    # ddp2 = DistributedDataParallel + dp
    # behaves like dp on every node
    # syncs gradients across nodes like ddp
    # useful for things like increasing the number of negative samples
    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')

use_amp
^^^^^^^

.. warning:: .. deprecated:: 0.6.1
    Use `precision` instead. Will be removed in 0.8.0.

precision
^^^^^^^^^
Full precision (32), half precision (16).
Can be used on CPU, GPU or TPUs.

If used on TPU will use torch.bfloat16 but tensor printing
will still show torch.float32.

Example::

    # default used by the Trainer
    trainer = Trainer(precision=32)

    # 16-bit precision
    trainer = Trainer(precision=16)

    # one day
    trainer = Trainer(precision=8|4|2)

print_nan_grads
^^^^^^^^^^^^^^^

Prints gradients with nan values.

Example::

    # default used by the Trainer
    trainer = Trainer(print_nan_grads=False)

weights_summary
^^^^^^^^^^^^^^^
Prints a summary of the weights when training begins.
Options: 'full', 'top', None.

Example::

    # default used by the Trainer (ie: print all weights)
    trainer = Trainer(weights_summary='full')

    # print only the top level modules
    trainer = Trainer(weights_summary='top')

    # don't print a summary
    trainer = Trainer(weights_summary=None)

weights_save_path
^^^^^^^^^^^^^^^^^
Where to save weights if specified.

Example::

    # default used by the Trainer
    trainer = Trainer(weights_save_path=os.getcwd())

    # save to your custom path
    trainer = Trainer(weights_save_path='my/path')

    # if checkpoint callback used, then overrides the weights path
|
||||
# **NOTE: this saves weights to some/path NOT my/path
|
||||
checkpoint_callback = ModelCheckpoint(filepath='some/path')
|
||||
trainer = Trainer(
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
weights_save_path='my/path'
|
||||
)
|
||||
|
||||
amp_level
|
||||
^^^^^^^^^
|
||||
The optimization level to use (O1, O2, etc...)
|
||||
for 16-bit GPU precision (using NVIDIA apex under the hood).
|
||||
|
||||
Check nvidia docs for level (https://nvidia.github.io/apex/amp.html#opt-levels)
|
||||
|
||||
Example::
|
||||
|
||||
# default used by the Trainer
|
||||
trainer = Trainer(amp_level='O1')
|
||||
|
||||
num_sanity_val_steps
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Sanity check runs n batches of val before starting the training routine.
|
||||
This catches any bugs in your validation without having to wait for the first validation check.
|
||||
The Trainer uses 5 steps by default. Turn it off or modify it here.
|
||||
|
||||
Example::
|
||||
|
||||
# default used by the Trainer
|
||||
trainer = Trainer(num_sanity_val_steps=5)
|
||||
|
||||
# turn it off
|
||||
trainer = Trainer(num_sanity_val_steps=0)
|
||||
|
||||
nb_sanity_val_steps:
|
||||
.. warning:: .. deprecated:: 0.5.0
|
||||
Use `num_sanity_val_steps` instead. Will remove 0.8.0.
|
||||
|
||||
truncated_bptt_steps
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Truncated back prop breaks performs backprop every k steps of
|
||||
a much longer sequence If this is enabled, your batches will automatically get truncated
|
||||
and the trainer will apply Truncated Backprop to it. Make sure your batches have a sequence
|
||||
dimension. (`Williams et al. "An efficient gradient-based algorithm for on-line training of
|
||||
recurrent network trajectories."
|
||||
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
|
||||
|
||||
Example::
|
||||
|
||||
# default used by the Trainer (ie: disabled)
|
||||
trainer = Trainer(truncated_bptt_steps=None)
|
||||
|
||||
# backprop every 5 steps in a batch
|
||||
trainer = Trainer(truncated_bptt_steps=5)
|
||||
|
||||
|
||||
Lightning takes care to split your batch along the time-dimension.
|
||||
|
||||
.. note:: If you need to modify how the batch is split,
|
||||
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
|
||||
|
||||
.. note:: Using this feature requires updating your LightningModule's
|
||||
:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
|
||||
|
||||
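For reference, a minimal sketch of such a ``training_step`` (the ``self.rnn`` and ``self.loss``
attributes here are hypothetical, used only to illustrate the ``hiddens`` contract):

.. code-block:: python

    def training_step(self, batch, batch_idx, hiddens):
        # each call receives one truncated_bptt_steps-long slice of the sequence
        x, y = batch
        out, hiddens = self.rnn(x, hiddens)  # assumes self.rnn was built in __init__
        loss = self.loss(out, y)             # assumes self.loss was built in __init__
        # return the hidden state so the next slice continues from it
        return {'loss': loss, 'hiddens': hiddens}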
resume_from_checkpoint
^^^^^^^^^^^^^^^^^^^^^^
To resume training from a specific checkpoint pass in the path here.

Example::

    # default used by the Trainer
    trainer = Trainer(resume_from_checkpoint=None)

    # resume from a specific checkpoint
    trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

profiler
^^^^^^^^
To profile individual steps during training and assist in identifying bottlenecks.

Example::

    from pytorch_lightning.profiler import Profiler, AdvancedProfiler

    # default used by the Trainer
    trainer = Trainer(profiler=None)

    # to profile standard training events
    trainer = Trainer(profiler=True)

    # equivalent to profiler=True
    profiler = Profiler()
    trainer = Trainer(profiler=profiler)

    # advanced profiler for function-level stats
    profiler = AdvancedProfiler()
    trainer = Trainer(profiler=profiler)


reload_dataloaders_every_epoch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Set to True to reload dataloaders every epoch.

.. code-block:: python

    # if False (default)
    train_loader = model.train_dataloader()
    for epoch in epochs:
        for batch in train_loader:
            ...

    # if True
    for epoch in epochs:
        train_loader = model.train_dataloader()
        for batch in train_loader:
            ...
benchmark
^^^^^^^^^

If true enables cudnn.benchmark.
This flag is likely to increase the speed of your system if your
input sizes don't change. However, if they do, then it will likely
make your system slower.

The speedup comes from allowing the cudnn auto-tuner to find the best
algorithm for the hardware `[see discussion here]
<https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`_.
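Under the hood this maps onto the corresponding cuDNN switch; roughly the
following (a sketch of the equivalent plain-PyTorch setting, not the Trainer's
exact internals):

Example::

    import torch

    # approximately what benchmark=True turns on
    torch.backends.cudnn.benchmark = True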
"""
|
||||
|
||||
from .trainer import Trainer
|
||||
|
|
|
@@ -131,6 +131,7 @@ from abc import ABC, abstractmethod

import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import warnings

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.debugging import MisconfigurationException

@@ -260,6 +261,18 @@ class TrainerEvaluationLoopMixin(ABC):

                # -----------------
                output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

                # on dp / ddp2 might still want to do something with the batch parts
                if test_mode:
                    if self.is_overriden('test_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('test_step_end'):
                            output = model_ref.test_step_end(output)
                else:
                    if self.is_overriden('validation_step_end'):
                        model_ref = self.get_model()
                        with self.profiler.profile('validation_step_end'):
                            output = model_ref.validation_step_end(output)

                # track outputs for collation
                dl_outputs.append(output)
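For orientation, here is a sketch of the kind of ``validation_step`` /
``validation_step_end`` pair this new hook call supports under dp/ddp2
(the model's forward and the metric are hypothetical):

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        # runs on each GPU with its shard of the split batch
        x, y = batch
        out = self(x)
        return {'out': out, 'y': y}

    def validation_step_end(self, outputs):
        # runs once with the gathered parts; compute metrics over the full batch
        acc = (outputs['out'].argmax(dim=1) == outputs['y']).float().mean()
        return {'val_acc': acc}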
@@ -280,10 +293,23 @@ class TrainerEvaluationLoopMixin(ABC):

        # give model a chance to do something with the outputs (and method defined)
        model = self.get_model()

        if test_mode and self.is_overriden('test_epoch_end'):
            eval_results = model.test_epoch_end(outputs)
        elif self.is_overriden('validation_epoch_end'):
            eval_results = model.validation_epoch_end(outputs)

        # TODO: remove in v1.0.0
        if test_mode and self.is_overriden('test_end'):
            eval_results = model.test_end(outputs)
            m = 'test_end was deprecated in 0.7.0 and will be removed in 1.0.0. ' \
                'Use test_epoch_end instead.'
            warnings.warn(m, DeprecationWarning)
        elif self.is_overriden('validation_end'):
            eval_results = model.validation_end(outputs)
            m = 'validation_end was deprecated in 0.7.0 and will be removed in 1.0.0. ' \
                'Use validation_epoch_end instead.'
            warnings.warn(m, DeprecationWarning)

        # enable train mode again
        model.train()
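For users, migrating off the deprecated hooks is a pure rename; a sketch
(assuming each ``validation_step`` returned a ``'val_loss'`` entry):

.. code-block:: python

    # before (deprecated in 0.7.0)
    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}

    # after
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}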
@@ -392,7 +418,7 @@ class TrainerEvaluationLoopMixin(ABC):
            output = model(*args)
            return output

-       # single GPU
+       # single GPU data transfer
        if self.single_gpu:
            # for single GPU put inputs on gpu manually
            root_gpu = 0

@@ -401,12 +427,12 @@ class TrainerEvaluationLoopMixin(ABC):
            batch = self.transfer_batch_to_gpu(batch, root_gpu)
            args[0] = batch

-       # TPU
+       # TPU data transfer
        if self.use_tpu:
            batch = self.transfer_batch_to_tpu(batch)
            args[0] = batch

-       # CPU
+       # CPU, TPU or gpu step
        if test_mode:
            output = model.test_step(*args)
        else:
@@ -126,497 +126,72 @@ class Trainer(TrainerIOMixin,

    Args:
        logger: Logger (or iterable collection of loggers) for experiment tracking.
            Example::

                from pytorch_lightning.loggers import TensorBoardLogger

                # default logger used by trainer
                logger = TensorBoardLogger(
                    save_dir=os.getcwd(),
                    version=self.slurm_job_id,
                    name='lightning_logs'
                )

                Trainer(logger=logger)

        checkpoint_callback: Callback for checkpointing.
            Example::

                from pytorch_lightning.callbacks import ModelCheckpoint

                # default used by the Trainer
                checkpoint_callback = ModelCheckpoint(
                    filepath=os.getcwd(),
                    save_best_only=True,
                    verbose=True,
                    monitor='val_loss',
                    mode='min',
                    prefix=''
                )

                trainer = Trainer(checkpoint_callback=checkpoint_callback)

        early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`):
            Callback for early stopping.
            If set to ``True``, then a default callback monitoring ``'val_loss'`` is created.
            Will raise an error if ``'val_loss'`` is not found.
            If set to ``False``, then early stopping will be disabled.
            If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
            If ``'val_loss'`` is not found, it will work as if early stopping is disabled.
            Default: ``None``.
            Example::

                from pytorch_lightning.callbacks import EarlyStopping

                # default used by the Trainer
                early_stop_callback = EarlyStopping(
                    monitor='val_loss',
                    patience=3,
                    strict=False,
                    verbose=False,
                    mode='min'
                )

                trainer = Trainer(early_stop_callback=early_stop_callback)

        callbacks: Add a list of callbacks.
            Example::

                from pytorch_lightning.callbacks import Callback

                class PrintCallback(Callback):
                    def on_train_start(self):
                        print("Training is started!")

                    def on_train_end(self):
                        print(f"Training is done. The logs are: {self.trainer.logs}")

                # a list of callbacks
                callbacks = [PrintCallback()]
                trainer = Trainer(callbacks=callbacks)

        default_save_path: Default path for logs and weights when no logger/ckpt_callback passed.
            Example::

                # default used by the Trainer
                trainer = Trainer(default_save_path=os.getcwd())

        gradient_clip_val: 0 means don't clip.
            Example::

                # default used by the Trainer
                trainer = Trainer(gradient_clip_val=0.0)

        gradient_clip:
            .. warning:: .. deprecated:: 0.5.0
                Use `gradient_clip_val` instead. Will remove 0.8.0.

            .. warning:: .. deprecated:: 0.6.1
                Use `gradient_clip_val` instead. Will remove 0.8.0.

        process_position: orders the tqdm bar when running multiple models on the same machine.
            Example::

                # default used by the Trainer
                trainer = Trainer(process_position=0)

        num_nodes: number of GPU nodes for distributed training.
            Example::

                # default used by the Trainer
                trainer = Trainer(num_nodes=1)

                # to train on 8 nodes
                trainer = Trainer(num_nodes=8)

        nb_gpu_nodes:
            .. warning:: .. deprecated:: 0.5.0
                Use `num_nodes` instead. Will remove 0.8.0.

            .. warning:: .. deprecated:: 0.6.1
                Use `num_nodes` instead. Will remove 0.8.0.

        gpus: Which GPUs to train on.
            Example::

                # default used by the Trainer (ie: train on CPU)
                trainer = Trainer(gpus=None)

                # int: train on 2 gpus
                trainer = Trainer(gpus=2)

                # list: train on GPUs 1, 4 (by bus ordering)
                trainer = Trainer(gpus=[1, 4])
                trainer = Trainer(gpus='1, 4')  # equivalent

                # -1: train on all gpus
                trainer = Trainer(gpus=-1)
                trainer = Trainer(gpus='-1')  # equivalent

                # combine with num_nodes to train on multiple GPUs across nodes
                trainer = Trainer(gpus=2, num_nodes=4)  # uses 8 gpus in total

        num_tpu_cores: How many TPU cores to train on (1 or 8).
            A single TPU v2 or v3 has 8 cores. A TPU pod has
            up to 2048 cores. A slice of a pod gives you as many cores
            as you request.

            You MUST use a DistributedSampler with your dataloader for this
            to work. Your effective batch size is batch_size * total tpu cores.

            This parameter can be either 1 or 8.

            Example::

                # your_trainer_file.py

                # default used by the Trainer (ie: train on CPU)
                trainer = Trainer(num_tpu_cores=None)

                # int: train on a single core
                trainer = Trainer(num_tpu_cores=1)

                # int: train on all 8 cores
                trainer = Trainer(num_tpu_cores=8)

                # for 8+ cores must submit via xla script with
                # a max of 8 cores specified. The XLA script
                # will duplicate the script onto each TPU in the POD
                trainer = Trainer(num_tpu_cores=8)

                # -1: train on all available TPUs
                trainer = Trainer(num_tpu_cores=-1)

            To train on more than 8 cores (ie: a POD),
            submit this script using the xla_dist script.

            Example::

                $ python -m torch_xla.distributed.xla_dist
                --tpu=$TPU_POD_NAME
                --conda-env=torch-xla-nightly
                --env=XLA_USE_BF16=1
                -- python your_trainer_file.py
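            A sketch of such a TPU-ready dataloader (``self.dataset`` and the
            batch size are assumptions for illustration):

            Example::

                import torch_xla.core.xla_model as xm
                from torch.utils.data import DataLoader
                from torch.utils.data.distributed import DistributedSampler

                def train_dataloader(self):
                    # shard the dataset so each TPU core sees a distinct slice
                    sampler = DistributedSampler(
                        self.dataset,
                        num_replicas=xm.xrt_world_size(),
                        rank=xm.get_ordinal(),
                    )
                    return DataLoader(self.dataset, sampler=sampler, batch_size=32)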
        log_gpu_memory: None, 'min_max', 'all'. Might slow performance
            because it uses the output of nvidia-smi.
            Example::

                # default used by the Trainer
                trainer = Trainer(log_gpu_memory=None)

                # log all the GPUs (on master node only)
                trainer = Trainer(log_gpu_memory='all')

                # log only the min and max memory on the master node
                trainer = Trainer(log_gpu_memory='min_max')

        show_progress_bar: If true shows tqdm progress bar.
            Example::

                # default used by the Trainer
                trainer = Trainer(show_progress_bar=True)

        progress_bar_refresh_rate: How often to refresh progress bar (in steps)

        overfit_pct: uses this much data of all datasets.
            Example::

                # default used by the Trainer
                trainer = Trainer(overfit_pct=0.0)

                # use only 1% of the train, test, val datasets
                trainer = Trainer(overfit_pct=0.01)

        track_grad_norm: -1 means no tracking. Otherwise tracks that p-norm.
            Example::

                # default used by the Trainer
                trainer = Trainer(track_grad_norm=-1)

                # track the 2-norm
                trainer = Trainer(track_grad_norm=2)

        check_val_every_n_epoch: Check val every n train epochs.
            Example::

                # default used by the Trainer
                trainer = Trainer(check_val_every_n_epoch=1)

                # run val loop every 10 training epochs
                trainer = Trainer(check_val_every_n_epoch=10)

        fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
            Example::

                # default used by the Trainer
                trainer = Trainer(fast_dev_run=False)

                # runs 1 train, val, test batch and program ends
                trainer = Trainer(fast_dev_run=True)

        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
            Example::

                # default used by the Trainer (no accumulation)
                trainer = Trainer(accumulate_grad_batches=1)

                # accumulate every 4 batches (effective batch size is batch*4)
                trainer = Trainer(accumulate_grad_batches=4)

                # no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
                trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})

        max_epochs: Stop training once this number of epochs is reached.
            Example::

                # default used by the Trainer
                trainer = Trainer(max_epochs=1000)

        max_nb_epochs:
            .. warning:: .. deprecated:: 0.5.0
                Use `max_epochs` instead. Will remove 0.8.0.

            .. warning:: .. deprecated:: 0.6.1
                Use `max_epochs` instead. Will remove 0.8.0.

        min_epochs: Force training for at least these many epochs.
            Example::

                # default used by the Trainer
                trainer = Trainer(min_epochs=1)

        min_nb_epochs:
            .. warning:: .. deprecated:: 0.5.0
                Use `min_epochs` instead. Will remove 0.8.0.

            .. warning:: .. deprecated:: 0.6.1
                Use `min_epochs` instead. Will remove 0.8.0.

        max_steps: Stop training after this number of steps. Disabled by default (None).
            Training will stop if max_steps or max_epochs is reached (whichever comes first).
            Example::

                # Stop after 100 steps
                trainer = Trainer(max_steps=100)

        min_steps: Force training for at least this number of steps. Disabled by default (None).
            Trainer will train the model for at least min_steps or min_epochs (whichever comes last).
            Example::

                # Run at least for 100 steps (disable min_epochs)
                trainer = Trainer(min_steps=100, min_epochs=0)

        train_percent_check: How much of the training dataset to check.
            Useful when debugging or testing something that happens at the end of an epoch.
            Example::

                # default used by the Trainer
                trainer = Trainer(train_percent_check=1.0)

                # run through only 25% of the training set each epoch
                trainer = Trainer(train_percent_check=0.25)

        val_percent_check: How much of the validation dataset to check.
            Useful when debugging or testing something that happens at the end of an epoch.
            Example::

                # default used by the Trainer
                trainer = Trainer(val_percent_check=1.0)

                # run through only 25% of the validation set each epoch
                trainer = Trainer(val_percent_check=0.25)

        test_percent_check: How much of the test dataset to check.
            Useful when debugging or testing something that happens at the end of an epoch.
            Example::

                # default used by the Trainer
                trainer = Trainer(test_percent_check=1.0)

                # run through only 25% of the test set each epoch
                trainer = Trainer(test_percent_check=0.25)

        val_check_interval: How often within one training epoch to check the validation set.
            If float, fraction of the training epoch. If int, check every n batches.
            Example::

                # default used by the Trainer
                trainer = Trainer(val_check_interval=1.0)

                # check validation set 4 times during a training epoch
                trainer = Trainer(val_check_interval=0.25)

                # check validation set every 1000 training batches
                # use this when using an IterableDataset and your dataset has no length
                # (ie: production cases with streaming data)
                trainer = Trainer(val_check_interval=1000)

        log_save_interval: Writes logs to disk this often.
            Example::

                # default used by the Trainer
                trainer = Trainer(log_save_interval=100)

        row_log_interval: How often to add logging rows (does not write to disk).
            Example::

                # default used by the Trainer
                trainer = Trainer(row_log_interval=10)

        add_row_log_interval:
            .. warning:: .. deprecated:: 0.5.0
                Use `row_log_interval` instead. Will remove 0.8.0.

            .. warning:: .. deprecated:: 0.6.1
                Use `row_log_interval` instead. Will remove 0.8.0.

        distributed_backend: The distributed backend to use.
            Options: 'dp', 'ddp', 'ddp2'.
            Example::

                # default used by the Trainer
                trainer = Trainer(distributed_backend=None)

                # dp = DataParallel (split a batch onto k gpus on same machine).
                trainer = Trainer(gpus=2, distributed_backend='dp')

                # ddp = DistributedDataParallel
                # Each gpu trains by itself on a subset of the data.
                # Gradients sync across all gpus and all machines.
                trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')

                # ddp2 = DistributedDataParallel + dp
                # behaves like dp on every node
                # syncs gradients across nodes like ddp
                # useful for things like increasing the number of negative samples
                trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')

        use_amp:
            .. warning:: .. deprecated:: 0.6.1
                Use `precision` instead. Will remove 0.8.0.

            .. warning:: .. deprecated:: 0.7.0
                Use `precision` instead. Will remove 0.8.0.

        precision: Full precision (32), half precision (16).
            Can be used on CPU, GPU or TPUs.

            If used on TPU, will use torch.bfloat16 but tensor printing
            will still show torch.float32.

            Example::

                # default used by the Trainer
                trainer = Trainer(precision=32)

                # 16-bit precision
                trainer = Trainer(precision=16)

                # one day
                trainer = Trainer(precision=8|4|2)

        print_nan_grads: Prints gradients with nan values.
            Example::

                # default used by the Trainer
                trainer = Trainer(print_nan_grads=False)

        weights_summary: Prints a summary of the weights when training begins.
            Options: 'full', 'top', None.
            Example::

                # default used by the Trainer (ie: print all weights)
                trainer = Trainer(weights_summary='full')

                # print only the top level modules
                trainer = Trainer(weights_summary='top')

                # don't print a summary
                trainer = Trainer(weights_summary=None)

        weights_save_path: Where to save weights if specified.
            Example::

                # default used by the Trainer
                trainer = Trainer(weights_save_path=os.getcwd())

                # save to your custom path
                trainer = Trainer(weights_save_path='my/path')

                # if checkpoint callback used, then overrides the weights path
                # **NOTE: this saves weights to some/path NOT my/path
                checkpoint_callback = ModelCheckpoint(filepath='some/path')
                trainer = Trainer(
                    checkpoint_callback=checkpoint_callback,
                    weights_save_path='my/path'
                )

        amp_level: The optimization level to use (O1, O2, etc...).
            Check nvidia docs for level (https://nvidia.github.io/apex/amp.html#opt-levels)
            Example::

                # default used by the Trainer
                trainer = Trainer(amp_level='O1')

        num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
            This catches any bugs in your validation without having to wait for the first validation check.
            The Trainer uses 5 steps by default. Turn it off or modify it here.
            Example::

                # default used by the Trainer
                trainer = Trainer(num_sanity_val_steps=5)

                # turn it off
                trainer = Trainer(num_sanity_val_steps=0)

        nb_sanity_val_steps:
            .. warning:: .. deprecated:: 0.5.0
                Use `num_sanity_val_steps` instead. Will remove 0.8.0.

            .. warning:: .. deprecated:: 0.7.0
                Use `num_sanity_val_steps` instead. Will remove 0.8.0.

        truncated_bptt_steps: Truncated backpropagation through time performs backprop every k steps of
            a much longer sequence. If this is enabled, your batches will automatically get truncated
            and the trainer will apply truncated backprop to them. Make sure your batches have a sequence
            dimension. (`Williams et al. "An efficient gradient-based algorithm for on-line training of
            recurrent network trajectories."
            <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
            Example::

                # default used by the Trainer (ie: disabled)
                trainer = Trainer(truncated_bptt_steps=None)

                # backprop every 5 steps in a batch
                trainer = Trainer(truncated_bptt_steps=5)

            Lightning takes care of splitting your batch along the time dimension.

            .. note:: If you need to modify how the batch is split,
                override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.

            .. note:: Using this feature requires updating your LightningModule's
                :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.

        resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
            Example::

                # default used by the Trainer
                trainer = Trainer(resume_from_checkpoint=None)

                # resume from a specific checkpoint
                trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

        profiler: To profile individual steps during training and assist in
            identifying bottlenecks.
            Example::

                from pytorch_lightning.profiler import Profiler, AdvancedProfiler

                # default used by the Trainer
                trainer = Trainer(profiler=None)

                # to profile standard training events
                trainer = Trainer(profiler=True)

                # equivalent to profiler=True
                profiler = Profiler()
                trainer = Trainer(profiler=profiler)

                # advanced profiler for function-level stats
                profiler = AdvancedProfiler()
                trainer = Trainer(profiler=profiler)

        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

        benchmark (bool): If true enables cudnn.benchmark.
            This flag is likely to increase the speed of your system if your
            input sizes don't change. However, if they do, then it will likely
            make your system slower.

            The speedup comes from allowing the cudnn auto-tuner to find the best
            algorithm for the hardware `[see discussion here]
            <https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936>`_.

    .. warning:: The following arguments are deprecated and will be removed in v0.8.0:

        - `nb_sanity_val_steps`

    """

    # Init callbacks
@@ -698,12 +698,24 @@ class TrainerTrainLoopMixin(ABC):
        else:
            output = self.model.training_step(*args)

        # allow any mode to define training_step_end
        # do something with all the dp outputs (like softmax)
        if self.is_overriden('training_step_end'):
            model_ref = self.get_model()
            with self.profiler.profile('training_step_end'):
                output = model_ref.training_step_end(output)

        # allow any mode to define training_end
        # TODO: remove in 1.0.0
        if self.is_overriden('training_end'):
            model_ref = self.get_model()
            with self.profiler.profile('training_end'):
                output = model_ref.training_end(output)

            m = 'training_end was deprecated in 0.7.0 and will be removed in 1.0.0. ' \
                'Use training_epoch_end instead'
            warnings.warn(m, DeprecationWarning)

        # format and reduce outputs accordingly
        output = self.process_output(output, train=True)
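To close the loop, the renamed hook pairs with ``training_step`` as in this
sketch (a hypothetical module; ``self.loss`` is an assumed attribute; under
dp/ddp2, ``training_step`` runs per GPU and ``training_step_end`` once on the
gathered parts):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        # runs on each GPU with its shard of the split batch (dp/ddp2)
        x, y = batch
        out = self(x)
        return {'out': out, 'y': y}

    def training_step_end(self, outputs):
        # runs once with all the dp/ddp2 parts gathered; compute the loss here
        loss = self.loss(outputs['out'], outputs['y'])
        return {'loss': loss}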