.. testsetup:: *

    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.trainer.trainer import Trainer

.. _converting:

**************************************
How to organize PyTorch into Lightning
**************************************

To make your code work with Lightning, organize your PyTorch code as follows.

1. Move your computational code
===============================
Move the model architecture and forward pass to your :class:`~pytorch_lightning.core.LightningModule`.

.. code-block::

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 10)

        def forward(self, x):
            # flatten each input to a vector of 28 * 28 values
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x

2. Move the optimizer(s) and schedulers
=======================================
Move your optimizers to the :func:`pytorch_lightning.core.LightningModule.configure_optimizers` hook.
Make sure to use the hook parameters (self in this case).

.. code-block::

    class LitModel(pl.LightningModule):

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

3. Find the train loop "meat"
=============================
Lightning automates most of the training for you, including the epoch and batch iterations.
All you need to keep is the training step logic. This should go into the
:func:`pytorch_lightning.core.LightningModule.training_step` hook
(make sure to use the hook parameters, self in this case):

.. code-block::

    class LitModel(pl.LightningModule):

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

4. Find the val loop "meat"
===========================
To add an (optional) validation loop, add its logic to the
:func:`pytorch_lightning.core.LightningModule.validation_step` hook
(make sure to use the hook parameters, self in this case).

.. testcode::

    class LitModel(LightningModule):

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            val_loss = F.cross_entropy(y_hat, y)
            return val_loss

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.

5. Find the test loop "meat"
============================
To add an (optional) test loop, add its logic to the
:func:`pytorch_lightning.core.LightningModule.test_step` hook
(make sure to use the hook parameters, self in this case).

.. code-block::

    class LitModel(pl.LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

The test loop will not be used until you call:

.. code-block::

    trainer.test()

.. note:: ``.test()`` loads the best checkpoint automatically.

6. Remove any .cuda() or .to(device) calls
==========================================
Your :class:`~pytorch_lightning.core.LightningModule` can automatically run on any hardware!
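
As a minimal sketch of what this can look like in practice, the hypothetical module below
creates new tensors inside ``forward`` without any ``.cuda()`` or ``.to(device)`` call.
It assumes ``pytorch_lightning`` is importable as ``pl`` and relies on the ``self.device``
property of the ``LightningModule`` and on ``Tensor.type_as``. The class name
``LitNoisyModel`` and the added noise term are illustrative only and are not part of the
steps above.

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class LitNoisyModel(pl.LightningModule):  # hypothetical name, for illustration only

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            # Option 1: allocate directly on the module's current device.
            noise = torch.randn(x.size(0), 10, device=self.device)
            # Option 2: match the dtype and device of an existing tensor.
            scale = torch.tensor([0.1]).type_as(x)
            return self.layer(x) + scale * noise


    model = LitNoisyModel()
    x = torch.rand(4, 1, 28, 28)
    out = model(x)  # runs on whatever device the module and input live on

Because the new tensors take their device from the module or from the input, the same
``forward`` should work unchanged whether the module ends up on a CPU or an accelerator.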