.. testsetup:: *

    import torch
    from torch.nn import functional as F
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.trainer.trainer import Trainer

.. _converting:

**************************************
How to organize PyTorch into Lightning
**************************************

To enable your code to work with Lightning, here's how to organize PyTorch into Lightning.

--------

1. Move your computational code
===============================

Move the model architecture and forward pass to your :class:`~pytorch_lightning.core.LightningModule`.

.. testcode::

    class LitModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x

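
Because :class:`~pytorch_lightning.core.LightningModule` subclasses ``torch.nn.Module``, the forward pass above runs just as it would in plain PyTorch. As a quick sanity check, the sketch below builds the same two layers as a plain ``torch.nn.Module`` (the class name ``PlainModel`` is hypothetical, chosen so the snippet runs with only ``torch`` installed) and confirms that a batch of 28x28 inputs produces 10 logits per sample. Swapping the base class for ``LightningModule`` should not change this behavior.

.. code-block:: python

    import torch
    from torch.nn import functional as F


    class PlainModel(torch.nn.Module):
        # Hypothetical plain-PyTorch twin of LitModel, used only for this check
        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)  # flatten (N, 1, 28, 28) -> (N, 784)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x


    model = PlainModel()
    images = torch.randn(4, 1, 28, 28)  # fake batch of four single-channel images
    logits = model(images)
    print(logits.shape)  # torch.Size([4, 10])
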
--------

2. Move the optimizer(s) and schedulers
=======================================

Move your optimizers, and any learning rate schedulers, to the :func:`~pytorch_lightning.core.LightningModule.configure_optimizers` hook.

.. testcode::

    class LitModel(LightningModule):

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

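
If you also use a learning rate scheduler, ``configure_optimizers`` can return it together with the optimizer, for example as two lists: ``return [optimizer], [scheduler]``. The sketch below shows that return shape as a standalone function so it can be exercised with plain PyTorch. The function name ``configure_optimizers_sketch``, the choice of ``StepLR`` and its ``step_size``/``gamma`` values are illustrative assumptions. Inside ``LitModel`` the same body would go in ``configure_optimizers``, with ``self.parameters()`` in place of ``model.parameters()``.

.. code-block:: python

    import torch


    def configure_optimizers_sketch(model):
        # Hypothetical standalone version of LitModel.configure_optimizers
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # Illustrative scheduler: multiply the learning rate by 0.1 after every epoch
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        # Two lists: optimizers first, then schedulers
        return [optimizer], [scheduler]


    optimizers, schedulers = configure_optimizers_sketch(torch.nn.Linear(4, 2))
    optimizers[0].step()  # a (no-op) optimizer step, as during a training epoch
    schedulers[0].step()  # end of the "epoch": learning rate decays from 1e-3 to 1e-4
    print(optimizers[0].param_groups[0]["lr"])
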
--------

3. Find the train loop "meat"
=============================

Lightning automates most of the training for you, including the epoch and batch iterations; all you need to keep is the training step logic.
This should go into the :func:`~pytorch_lightning.core.LightningModule.training_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case):

.. testcode::

    class LitModel(LightningModule):

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

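
For comparison, the sketch below is roughly the plain PyTorch loop that ``training_step`` lets you delete: iterating over epochs and batches, zeroing gradients, calling ``backward()`` and stepping the optimizer. It is a simplified illustration, not Lightning's actual implementation (the real trainer also handles logging, callbacks, device placement and more). The random dataset, single linear layer, batch size and epoch count are made-up assumptions so the snippet runs on its own.

.. code-block:: python

    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(0)

    # Made-up data: 64 random "images" already flattened to 784 features, 10 classes
    x_all = torch.randn(64, 28 * 28)
    y_all = torch.randint(0, 10, (64,))
    train_loader = DataLoader(TensorDataset(x_all, y_all), batch_size=16)

    model = torch.nn.Linear(28 * 28, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


    def training_step(batch, batch_idx):
        # The "meat": the same logic as LitModel.training_step
        x, y = batch
        y_hat = model(x)
        return F.cross_entropy(y_hat, y)


    epoch_losses = []
    for epoch in range(5):  # epoch loop (automated by Lightning)
        total = 0.0
        for batch_idx, batch in enumerate(train_loader):  # batch loop (automated)
            loss = training_step(batch, batch_idx)
            optimizer.zero_grad()  # optimization boilerplate the Trainer runs for you
            loss.backward()
            optimizer.step()
            total += loss.item()
        epoch_losses.append(total / len(train_loader))

    print(epoch_losses)
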
--------

4. Find the val loop "meat"
===========================

To add an (optional) validation loop, add logic to the
:func:`~pytorch_lightning.core.LightningModule.validation_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case).

.. testcode::

    class LitModel(LightningModule):

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            val_loss = F.cross_entropy(y_hat, y)
            return val_loss

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.

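
In plain PyTorch, the evaluation boilerplate that Lightning takes over usually looks something like the sketch below: switch to eval mode (which changes layers such as dropout and batch norm), disable gradient tracking, then switch back to training mode afterwards. The helper name ``run_validation``, the toy model and the random batches are hypothetical, used only to illustrate roughly what happens around ``validation_step`` on your behalf.

.. code-block:: python

    import torch
    from torch.nn import functional as F


    def run_validation(model, val_batches):
        # Hypothetical helper: roughly the boilerplate run around validation_step
        model.eval()  # e.g. dropout is disabled, batch norm uses running statistics
        losses = []
        with torch.no_grad():  # no autograd graph is built while validating
            for x, y in val_batches:
                y_hat = model(x)
                losses.append(F.cross_entropy(y_hat, y))  # the validation_step "meat"
        model.train()  # switch back to training mode before training resumes
        return torch.stack(losses).mean()


    torch.manual_seed(0)
    model = torch.nn.Sequential(
        torch.nn.Linear(28 * 28, 128),
        torch.nn.ReLU(),
        torch.nn.Dropout(0.5),
        torch.nn.Linear(128, 10),
    )
    val_batches = [(torch.randn(8, 28 * 28), torch.randint(0, 10, (8,))) for _ in range(3)]
    val_loss = run_validation(model, val_batches)
    print(val_loss.requires_grad, model.training)  # False True
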
--------

5. Find the test loop "meat"
============================

To add an (optional) test loop, add logic to the
:func:`~pytorch_lightning.core.LightningModule.test_step` hook (make sure to use the hook parameters, ``batch`` and ``batch_idx`` in this case).

.. testcode::

    class LitModel(LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

The test loop will not be used until you call:

.. code-block:: python

    trainer.test()

.. tip:: ``.test()`` loads the best checkpoint automatically.

--------

6. Remove any .cuda() or .to(device) calls
==========================================

Your :class:`~pytorch_lightning.core.LightningModule` can automatically run on any hardware!
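
A common place where device calls hide is inside the module itself, when new tensors are created during ``forward`` or a step hook. A device-agnostic alternative is to create them on the same device as an existing tensor instead of calling ``.cuda()``, for example with the ``device=`` argument or ``type_as``. The sketch below uses plain PyTorch and two hypothetical helpers, ``add_noise`` and ``ones_mask_like``; inside a LightningModule you could also pass ``device=self.device``.

.. code-block:: python

    import torch


    def add_noise(x, scale=0.1):
        # Hypothetical helper. Avoid: torch.randn(x.shape).cuda()
        # Instead, allocate the new tensor wherever x already lives (CPU, GPU, ...)
        noise = torch.randn(x.shape, device=x.device, dtype=x.dtype)
        return x + scale * noise


    def ones_mask_like(x):
        # Hypothetical helper: type_as casts the new tensor to x's dtype and device
        return torch.ones(x.size(0), 1).type_as(x)


    x = torch.randn(4, 3, dtype=torch.float64)  # stands in for a batch on any device
    noisy = add_noise(x)
    mask = ones_mask_like(x)
    print(noisy.device == x.device, mask.dtype)  # True torch.float64
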