This commit is contained in:
William Falcon 2020-10-12 16:48:07 -04:00 committed by GitHub
parent 42a4fe06b0
commit 8e83ac5aa9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 70 additions and 57 deletions

View File

@ -156,16 +156,32 @@ of the 20+ hooks found in :ref:`hooks`
def backward(self, loss, optimizer, optimizer_idx):
loss.backward()
In Lightning, training_step defines the train loop and is independent of forward. Use forward to define
what happens during inference/predictions
**FORWARD vs TRAINING_STEP**
In Lightning we separate training from inference. The training_step defines
the full training loop. We encourage users to use the forward to define inference
actions.
For example, in this case we could define the autoencoder to act as an embedding extractor:
.. code-block:: python
def forward(...):
# how you want your model to do inference/predictions
def forward(self, x):
embeddings = self.encoder(x)
return embeddings
def training_step(...):
# the train loop INDEPENDENT of forward.
Of course, nothing is stopping you from using forward from within the training_step
.. code-block:: python
def training_step(self, batch, batch_idx):
...
z = self(x)
It really comes down to your application. We do however, recommend that you keep both intents separate.
* Use forward for inference (predicting).
* Use training_step for training.
More details in :ref:`lightning_module` docs.
@ -222,6 +238,52 @@ features of the Trainer or LightningModule.
Basic features
**************
Manual vs automatic optimization
================================
Automatic optimization
----------------------
With Lightning you don't need to worry about when to enable/disable grads, do a backward pass, or update optimizers
as long as you return a loss with an attached graph from the `training_step`, Lightning will automate the optimization.
.. code-block:: python
def training_step(self, batch, batch_idx):
loss = self.encoder(batch[0])
return loss
.. _manual_opt:
Manual optimization
-------------------
However, for certain research like GANs, reinforcement learning or something with multiple optimizers
or an inner loop, you can turn off automatic optimization and fully control the training loop yourself.
First, turn off automatic optimization:
.. code-block:: python
trainer = Trainer(automatic_optimization=False)
Now you own the train loop!
.. code-block:: python
def training_step(self, batch, batch_idx, opt_idx):
(opt_a, opt_b, opt_c) = self.optimizers()
loss_a = self.generator(batch[0])
# use this instead of loss.backward so we can automate half precision, etc...
self.manual_backward(loss_a, opt_a, retain_graph=True)
self.manual_backward(loss_a, opt_a)
opt_a.step()
opt_a.zero_grad()
loss_b = self.discriminator(batch[0])
self.manual_backward(loss_b, opt_b)
...
Predict or Deploy
=================
@ -671,58 +733,9 @@ Lightning has many tools for debugging. Here is an example of just a few of them
---------------
*****************
Advanced features
*****************
Manual vs automatic optimization
================================
Automatic optimization
----------------------
With Lightning you don't need to worry about when to enable/disable grads, do a backward pass, or update optimizers
as long as you return a loss with an attached graph from the `training_step`, Lightning will automate the optimization.
.. code-block:: python
def training_step(self, batch, batch_idx):
loss = self.encoder(batch[0])
return loss
.. _manual_opt:
Manual optimization
-------------------
However, for certain research like GANs, reinforcement learning or something with multiple optimizers
or an inner loop, you can turn off automatic optimization and fully control the training loop yourself.
First, turn off automatic optimization:
.. code-block:: python
trainer = Trainer(automatic_optimization=False)
Now you own the train loop!
.. code-block:: python
def training_step(self, batch, batch_idx, opt_idx):
(opt_a, opt_b, opt_c) = self.optimizers()
loss_a = self.generator(batch[0])
# use this instead of loss.backward so we can automate half precision, etc...
self.manual_backward(loss_a, opt_a, retain_graph=True)
self.manual_backward(loss_a, opt_a)
opt_a.step()
opt_a.zero_grad()
loss_b = self.discriminator(batch[0])
self.manual_backward(loss_b, opt_b)
...
********************
Other coool features
====================
********************
Once you define and train your first Lightning model, you might want to try other cool features like