diff --git a/docs/source/new-project.rst b/docs/source/new-project.rst index ccea9d9106..829717b8d4 100644 --- a/docs/source/new-project.rst +++ b/docs/source/new-project.rst @@ -156,16 +156,32 @@ of the 20+ hooks found in :ref:`hooks` def backward(self, loss, optimizer, optimizer_idx): loss.backward() -In Lightning, training_step defines the train loop and is independent of forward. Use forward to define -what happens during inference/predictions +**FORWARD vs TRAINING_STEP** + +In Lightning we separate training from inference. The training_step defines +the full training loop. We encourage users to use the forward to define inference +actions. + +For example, in this case we could define the autoencoder to act as an embedding extractor: .. code-block:: python - def forward(...): - # how you want your model to do inference/predictions + def forward(self, x): + embeddings = self.encoder(x) + return embeddings - def training_step(...): - # the train loop INDEPENDENT of forward. +Of course, nothing is stopping you from using forward from within the training_step + +.. code-block:: python + + def training_step(self, batch, batch_idx): + ... + z = self(x) + +It really comes down to your application. We do however, recommend that you keep both intents separate. + +* Use forward for inference (predicting). +* Use training_step for training. More details in :ref:`lightning_module` docs. @@ -222,6 +238,52 @@ features of the Trainer or LightningModule. Basic features ************** +Manual vs automatic optimization +================================ + +Automatic optimization +---------------------- +With Lightning you don't need to worry about when to enable/disable grads, do a backward pass, or update optimizers +as long as you return a loss with an attached graph from the `training_step`, Lightning will automate the optimization. + +.. code-block:: python + + def training_step(self, batch, batch_idx): + loss = self.encoder(batch[0]) + return loss + +.. _manual_opt: + +Manual optimization +------------------- +However, for certain research like GANs, reinforcement learning or something with multiple optimizers +or an inner loop, you can turn off automatic optimization and fully control the training loop yourself. + +First, turn off automatic optimization: + +.. code-block:: python + + trainer = Trainer(automatic_optimization=False) + +Now you own the train loop! + +.. code-block:: python + + def training_step(self, batch, batch_idx, opt_idx): + (opt_a, opt_b, opt_c) = self.optimizers() + + loss_a = self.generator(batch[0]) + + # use this instead of loss.backward so we can automate half precision, etc... + self.manual_backward(loss_a, opt_a, retain_graph=True) + self.manual_backward(loss_a, opt_a) + opt_a.step() + opt_a.zero_grad() + + loss_b = self.discriminator(batch[0]) + self.manual_backward(loss_b, opt_b) + ... + Predict or Deploy ================= @@ -671,58 +733,9 @@ Lightning has many tools for debugging. Here is an example of just a few of them --------------- -***************** -Advanced features -***************** - -Manual vs automatic optimization -================================ - -Automatic optimization ----------------------- -With Lightning you don't need to worry about when to enable/disable grads, do a backward pass, or update optimizers -as long as you return a loss with an attached graph from the `training_step`, Lightning will automate the optimization. - -.. code-block:: python - - def training_step(self, batch, batch_idx): - loss = self.encoder(batch[0]) - return loss - -.. _manual_opt: - -Manual optimization -------------------- -However, for certain research like GANs, reinforcement learning or something with multiple optimizers -or an inner loop, you can turn off automatic optimization and fully control the training loop yourself. - -First, turn off automatic optimization: - -.. code-block:: python - - trainer = Trainer(automatic_optimization=False) - -Now you own the train loop! - -.. code-block:: python - - def training_step(self, batch, batch_idx, opt_idx): - (opt_a, opt_b, opt_c) = self.optimizers() - - loss_a = self.generator(batch[0]) - - # use this instead of loss.backward so we can automate half precision, etc... - self.manual_backward(loss_a, opt_a, retain_graph=True) - self.manual_backward(loss_a, opt_a) - opt_a.step() - opt_a.zero_grad() - - loss_b = self.discriminator(batch[0]) - self.manual_backward(loss_b, opt_b) - ... - +******************** Other coool features -==================== +******************** Once you define and train your first Lightning model, you might want to try other cool features like