quick start docs changes (#3028)
* updated code example
* added warning when changing monitor and using results obj
parent ca18e11f6e
commit 5dfc7b157e
@@ -15,7 +15,7 @@ code to work with Lightning.

 .. raw:: html

-    <video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_module_vid.m4v"></video>
+    <video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_quick_start_full.m4v"></video>
@@ -6,6 +6,9 @@

     import torch
     from torch.nn import functional as F
     from torch.utils.data import DataLoader
+
+    import pytorch_lightning as pl
+    from torch.utils.data import random_split

 .. _quick-start:
@@ -16,11 +19,11 @@ PyTorch Lightning is nothing more than organized PyTorch code.

 Once you've organized it into a LightningModule, it automates most of the training for you.

-To illustrate, here's the typical PyTorch project structure organized in a LightningModule.
+Here's a 2 minute conversion guide for PyTorch projects:

 .. raw:: html

-    <video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_module_vid.m4v"></video>
+    <video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_quick_start_full.m4v"></video>

 ----------
@@ -34,12 +37,16 @@ A lightningModule defines

 - Model + system architecture
 - Optimizer

-.. testcode::
-    :skipif: not TORCHVISION_AVAILABLE
+.. code-block::

+    import os
+    import torch
+    import torch.nn.functional as F
+    from torchvision.datasets import MNIST
+    from torchvision import transforms
+    from torch.utils.data import DataLoader
     import pytorch_lightning as pl
-    from pytorch_lightning.metrics.functional import accuracy
+    from torch.utils.data import random_split

     class LitModel(pl.LightningModule):
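The hunk above cuts off at the class definition. For readers skimming the diff, here is a minimal sketch of the kind of `LitModel` body the quick start builds toward; the layer shape, loss, and learning rate are illustrative assumptions, not necessarily the exact values on the docs page.

.. code-block:: python

    class LitModel(pl.LightningModule):

        def __init__(self):
            super().__init__()
            # single linear layer: flattened 28x28 MNIST image -> 10 digit classes
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx):
            # the per-batch training loop body that Lightning calls for you
            x, y = batch
            loss = F.cross_entropy(self(x), y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)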
@@ -74,7 +81,7 @@ well across any accelerator.

 Here's an example of using the Trainer:

-.. code-block:: python
+.. code-block::

     # dataloader
     dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
@@ -83,7 +90,7 @@ Here's an example of using the Trainer:

     # init model
     model = LitModel()

-    # most basic trainer, uses good defaults
+    # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
     trainer = pl.Trainer()
     trainer.fit(model, train_loader)
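Taken together, these two hunks describe the full quick start script. A runnable sketch under the same imports; the 55000/5000 split and the batch size of 32 are assumptions carried over from the DataModule example later in this diff:

.. code-block:: python

    # download MNIST and build the training DataLoader
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    mnist_train, _ = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=32)

    # init model
    model = LitModel()

    # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
    trainer = pl.Trainer()
    trainer.fit(model, train_loader)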
@@ -350,30 +357,49 @@ And the matching code:

-.. code-block:: python
+.. code-block::

-    class MyDataModule(pl.DataModule):
+    class MNISTDataModule(pl.LightningDataModule):

-        def __init__(self):
-            ...
+        def __init__(self, batch_size=32):
+            super().__init__()
+            self.batch_size = batch_size
+
+        def prepare_data(self):
+            # optional to support downloading only once when using multi-GPU or multi-TPU
+            MNIST(os.getcwd(), train=True, download=True)
+            MNIST(os.getcwd(), train=False, download=True)
+
+        def setup(self, stage):
+            transform = transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize((0.1307,), (0.3081,))
+            ])
+
+            if stage == 'fit':
+                mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
+                self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
+            if stage == 'test':
+                self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)

         def train_dataloader(self):
-            # your train transforms
-            return DataLoader(YOUR_DATASET)
+            mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
+            return mnist_train

         def val_dataloader(self):
-            # your val transforms
-            return DataLoader(YOUR_DATASET)
+            mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
+            return mnist_val

         def test_dataloader(self):
-            # your test transforms
-            return DataLoader(YOUR_DATASET)
+            mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
+            return mnist_test
 And train like so:

 .. code-block:: python

-    dm = MyDataModule()
+    dm = MNISTDataModule()
     trainer.fit(model, dm)

 When doing distributed training, Datamodules have two optional arguments for granular control
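The diff view truncates that sentence, but it presumably refers to the `prepare_data`/`setup` hooks added in the class above: in LightningDataModule semantics, `prepare_data` runs on a single process (safe for one-time work such as downloads, so it should not assign state to `self`), while `setup` runs on every process, which is why the dataset splits are assigned to `self` there. A sketch of how that plays out with a multi-GPU Trainer, using the 0.9-era flags this PR targets:

.. code-block:: python

    dm = MNISTDataModule(batch_size=32)

    # prepare_data() is called once -- download only, no state assignment
    # setup() is called on every process -- assign splits to self there
    trainer = pl.Trainer(gpus=2, distributed_backend='ddp')
    trainer.fit(model, dm)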