Fix missing imports in converting.rst ()

This commit is contained in:
John Kilpatrick 2022-02-21 21:26:18 +00:00 committed by GitHub
parent 70f9c6fda3
commit b5c135896f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 5 deletions
docs/source/starter

View File

@ -23,7 +23,13 @@ Move the model architecture and forward pass to your :class:`~pytorch_lightning.
.. testcode::
class LitModel(LightningModule):
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.layer_1 = nn.Linear(28 * 28, 128)
@ -46,7 +52,7 @@ Move your optimizers to the :meth:`~pytorch_lightning.core.lightning.LightningMo
.. testcode::
class LitModel(LightningModule):
class LitModel(pl.LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
@ -67,7 +73,7 @@ as arguments. Optionally, it can take ``optimizer_idx`` if your LightningModule
.. testcode::
class LitModel(LightningModule):
class LitModel(pl.LightningModule):
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
@ -90,7 +96,7 @@ To add an (optional) validation loop add logic to the
.. testcode::
class LitModel(LightningModule):
class LitModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
@ -121,7 +127,7 @@ method. When using Lightning, simply override the :meth:`~pytorch_lightning.core
.. testcode::
class LitModel(LightningModule):
class LitModel(pl.LightningModule):
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)