Update README.md
This commit is contained in:
parent
619f984c36
commit
39584d08ad
31
README.md
31
README.md
|
@ -54,6 +54,37 @@ pip install pytorch-lightning
|
|||
[![Watch the video](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_images/general/tutorial_cover.png)](https://www.youtube.com/watch?v=QHww1JH7IDU)
|
||||
|
||||
## Demo
|
||||
Here's a minimal example without a validation or test loop.
|
||||
|
||||
```python
|
||||
# this is just a plain nn.Module with some structure
|
||||
|
||||
class MNISTModel(pl.LightningModule):
|
||||
|
||||
def __init__(self):
|
||||
super(MNISTModel, self).__init__()
|
||||
self.l1 = torch.nn.Linear(28 * 28, 10)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.relu(self.l1(x.view(x.size(0), -1)))
|
||||
|
||||
def training_step(self, batch, batch_nb):
|
||||
x, y = batch
|
||||
loss = F.cross_entropy(self(x), y)
|
||||
tensorboard_logs = {'train_loss': loss}
|
||||
return {'loss': loss, 'log': tensorboard_logs}
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=0.02)
|
||||
|
||||
# train!
|
||||
train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
|
||||
|
||||
mnist_model = MNISTModel()
|
||||
trainer = pl.Trainer(gpus=8, precision=16)
|
||||
trainer.fit(mnist_model, train_loader)
|
||||
```
|
||||
|
||||
[MNIST, GAN, BERT, DQN on COLAB!](https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=HOk9c4_35FKg)
|
||||
[MNIST on TPUs](https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3)
|
||||
|
||||
|
|
Loading…
Reference in New Issue