.. testsetup:: *

    import torch
    import torch.nn.functional as F
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.trainer.trainer import Trainer


.. _converting:

**************************************
How to organize PyTorch into Lightning
**************************************

To make your PyTorch code work with Lightning, reorganize it as described in the steps below.


1. Move your computational code
===============================
Move the model architecture and forward pass to your :class:`~pytorch_lightning.core.LightningModule`.

.. testcode::

    class LitModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x


2. Move the optimizer(s) and schedulers
=======================================
Move your optimizers and any learning-rate schedulers to the :func:`pytorch_lightning.core.LightningModule.configure_optimizers` hook. Make sure to use the hook parameters (``self`` in this case).

.. testcode::

    class LitModel(LightningModule):

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

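
The heading mentions schedulers, but the example above returns only an optimizer. Below is a minimal sketch of one way to return a scheduler as well, assuming the two-list return form (``[optimizers], [schedulers]``) that ``configure_optimizers`` accepts. The ``StepLR`` settings, the single-layer model, and the ``import pytorch_lightning as pl`` alias are illustrative choices, not values from this guide.

.. code-block:: python

    import torch
    from torch.optim.lr_scheduler import StepLR
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 10)

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            # Illustrative schedule: multiply the learning rate by 0.1 every 10 epochs.
            scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
            # The first list holds optimizers, the second holds schedulers.
            return [optimizer], [scheduler]


    # Calling the hook directly just to inspect what it returns.
    optimizers, schedulers = LitModel().configure_optimizers()
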

3. Find the train loop "meat"
=============================
Lightning automates most of the training for you, including the epoch and batch iterations. All you need to keep is the training step logic, which goes into the :func:`pytorch_lightning.core.LightningModule.training_step` hook (make sure to use the hook parameters, ``self`` in this case):

.. testcode::

    class LitModel(LightningModule):

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

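
For orientation, here is a rough sketch of a plain PyTorch training loop, with comments marking which lines move into ``training_step`` and which ones Lightning takes over. The model, the random batches, and the epoch count are stand-ins invented for illustration.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(28 * 28, 10)  # stand-in model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Stand-in data: three random batches of four flattened 28x28 images.
    batches = [(torch.randn(4, 28 * 28), torch.randint(0, 10, (4,))) for _ in range(3)]

    for epoch in range(2):                           # handled by Lightning
        for batch_idx, batch in enumerate(batches):  # handled by Lightning
            # --- the "meat": this part becomes training_step ---
            x, y = batch
            y_hat = model(x)
            loss = F.cross_entropy(y_hat, y)
            # --- end of the "meat" ---
            optimizer.zero_grad()                    # handled by Lightning
            loss.backward()                          # handled by Lightning
            optimizer.step()                         # handled by Lightning
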

4. Find the val loop "meat"
===========================
To add an (optional) validation loop, add logic to the :func:`pytorch_lightning.core.LightningModule.validation_step` hook (make sure to use the hook parameters, ``self`` in this case).

.. testcode::

    class LitModel(LightningModule):

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            val_loss = F.cross_entropy(y_hat, y)
            return val_loss

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.

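
As a rough picture of what that automation replaces, a hand-written validation pass in plain PyTorch might look like the sketch below. The model and the random batches are stand-ins for illustration only.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(28 * 28, 10)  # stand-in model
    # Stand-in data: two random batches of four flattened 28x28 images.
    val_batches = [(torch.randn(4, 28 * 28), torch.randint(0, 10, (4,))) for _ in range(2)]

    model.eval()             # Lightning calls this for you
    with torch.no_grad():    # Lightning does this for you
        for batch_idx, batch in enumerate(val_batches):
            # This part becomes validation_step.
            x, y = batch
            y_hat = model(x)
            val_loss = F.cross_entropy(y_hat, y)
    model.train()            # switch back to training mode afterwards
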

5. Find the test loop "meat"
============================
To add an (optional) test loop, add logic to the :func:`pytorch_lightning.core.LightningModule.test_step` hook (make sure to use the hook parameters, ``self`` in this case).

.. testcode::

    class LitModel(LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

The test loop is not used until you call:

.. code-block:: python

    trainer.test()

.. note:: ``.test()`` loads the best checkpoint automatically.

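
Below is a minimal end-to-end sketch of how a test run might be wired up. Everything in it is an assumption for illustration: the random ``TensorDataset``, the single-layer model, and ``fast_dev_run=True`` (which runs a single batch and skips checkpointing, so the model is passed to ``test`` explicitly). Argument names of ``Trainer.fit`` and ``Trainer.test`` have changed between Lightning releases, so the loaders are passed positionally here; check the API reference for your installed version.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer_1(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.cross_entropy(self(x), y)

        def test_step(self, batch, batch_idx):
            x, y = batch
            self.log("test_loss", F.cross_entropy(self(x), y))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)


    # Random stand-in data: 8 fake images with integer labels.
    dataset = TensorDataset(torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,)))
    loader = DataLoader(dataset, batch_size=4)

    model = LitModel()
    trainer = pl.Trainer(fast_dev_run=True)  # one batch only, as a quick smoke test
    trainer.fit(model, loader)
    results = trainer.test(model, loader)    # the test loop runs only here
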
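
6. Remove any .cuda() or .to(device) calls
==========================================
Your :class:`~pytorch_lightning.core.LightningModule` can run on any hardware automatically, because Lightning moves the model and each batch to the right device for you. Delete explicit ``.cuda()`` and ``.to(device)`` calls from your code.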
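
When a step needs to create a new tensor, one device-agnostic option is ``Tensor.type_as``, which casts the new tensor to the type of an existing one and, in practice, places it on that tensor's device as well. Below is a small sketch in plain PyTorch; the tensor shapes and the ``float64`` input are arbitrary illustration values.

.. code-block:: python

    import torch

    # Stands in for a batch that has already been placed on the right device.
    x = torch.randn(2, 3, dtype=torch.float64)

    # Before: z = torch.zeros(2, 3).cuda()
    # After: follow whatever type and device x has.
    z = torch.zeros(2, 3).type_as(x)
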