diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst
index 5f0d4574e3..e6d8b33ce8 100644
--- a/docs/source/introduction_guide.rst
+++ b/docs/source/introduction_guide.rst
@@ -15,7 +15,7 @@ code to work with Lightning.
.. raw:: html
-
+
|
diff --git a/docs/source/new-project.rst b/docs/source/new-project.rst
index fb0c55f6d3..68b4f54afa 100644
--- a/docs/source/new-project.rst
+++ b/docs/source/new-project.rst
@@ -6,6 +6,8 @@
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
+ import pytorch_lightning as pl
+ from torch.utils.data import random_split
.. _quick-start:
@@ -16,11 +18,11 @@ PyTorch Lightning is nothing more than organized PyTorch code.
Once you've organized it into a LightningModule, it automates most of the training for you.
-To illustrate, here's the typical PyTorch project structure organized in a LightningModule.
+Here's a 2-minute conversion guide for PyTorch projects:
.. raw:: html
-
+
----------
@@ -34,12 +36,16 @@ A lightningModule defines
- Model + system architecture
- Optimizer
-.. testcode::
- :skipif: not TORCHVISION_AVAILABLE
-
+.. code-block:: python
+ import os
+ import torch
+ import torch.nn.functional as F
+ from torchvision.datasets import MNIST
+ from torchvision import transforms
+ from torch.utils.data import DataLoader
import pytorch_lightning as pl
- from pytorch_lightning.metrics.functional import accuracy
+ from torch.utils.data import random_split
class LitModel(pl.LightningModule):
@@ -74,7 +80,7 @@ well across any accelerator.
Here's an example of using the Trainer:
.. code-block:: python
# dataloader
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
@@ -83,7 +89,15 @@ Here's an example of using the Trainer:
# init model
model = LitModel()
- # most basic trainer, uses good defaults
+ # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
trainer = pl.Trainer()
trainer.fit(model, train_loader)
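+
+These defaults can be overridden through Trainer flags. Here's a minimal sketch (illustrative values; ``gpus=1`` assumes a machine with a GPU):
+
+.. code-block:: python
+
+    # the same call, with a couple of common flags overridden
+    trainer = pl.Trainer(max_epochs=5, gpus=1)
+    trainer.fit(model, train_loader)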
@@ -350,30 +364,63 @@ And the matching code:
|
.. code-block:: python
- class MyDataModule(pl.DataModule):
+ class MNISTDataModule(pl.LightningDataModule):
- def __init__(self):
- ...
+ def __init__(self, batch_size=32):
+ super().__init__()
+ self.batch_size = batch_size
+
+ def prepare_data(self):
+        # optional; runs on a single process, so the download happens only once with multi-GPU or multi-TPU
+ MNIST(os.getcwd(), train=True, download=True)
+ MNIST(os.getcwd(), train=False, download=True)
+
+ def setup(self, stage):
+        transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,))
+ ])
+
+ if stage == 'fit':
+ mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
+ self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
+ if stage == 'test':
+            # reuse the same transform for the test split
+            self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
def train_dataloader(self):
- # your train transforms
- return DataLoader(YOUR_DATASET)
+ mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
+ return mnist_train
def val_dataloader(self):
- # your val transforms
- return DataLoader(YOUR_DATASET)
+ mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
+ return mnist_val
def test_dataloader(self):
- # your test transforms
- return DataLoader(YOUR_DATASET)
+        mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
+ return mnist_test
And train like so:
.. code-block:: python
- dm = MyDataModule()
+ dm = MNISTDataModule()
trainer.fit(model, dm)
When doing distributed training, Datamodules have two optional arguments for granular control
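+
+Here's a minimal sketch of that split of work (assuming the two hooks in question are ``prepare_data``, which runs on a single process, and ``setup``, which runs on every process):
+
+.. code-block:: python
+
+    class MNISTDataModule(pl.LightningDataModule):
+
+        def prepare_data(self):
+            # runs on a single process - download here, don't assign state
+            MNIST(os.getcwd(), train=True, download=True)
+
+        def setup(self, stage):
+            # runs on every process/GPU - safe to assign state here
+            self.mnist_train = MNIST(os.getcwd(), train=True)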