Fix doctests
This commit is contained in:
parent
41b0ffe1f4
commit
8d0266f958
|
@ -1,9 +1,3 @@
|
||||||
.. testsetup:: *
|
|
||||||
|
|
||||||
from pytorch_lightning.core.lightning import LightningModule
|
|
||||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
|
||||||
from pytorch_lightning.trainer.trainer import Trainer
|
|
||||||
|
|
||||||
.. _introduction_guide:
|
.. _introduction_guide:
|
||||||
|
|
||||||
#########################
|
#########################
|
||||||
|
@ -72,7 +66,7 @@ Let's first start with the model. In this case, we'll design a 3-layer neural ne
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from pytorch_lightning.core.lightning import LightningModule
|
from pytorch_lightning import LightningModule
|
||||||
|
|
||||||
|
|
||||||
class LitMNIST(LightningModule):
|
class LitMNIST(LightningModule):
|
||||||
|
@ -187,14 +181,13 @@ Data
|
||||||
|
|
||||||
Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIST.
|
Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIST.
|
||||||
|
|
||||||
.. testcode::
|
.. code-block:: python
|
||||||
:skipif: not _TORCHVISION_AVAILABLE
|
|
||||||
|
|
||||||
from torch.utils.data import DataLoader, random_split
|
|
||||||
from torchvision.datasets import MNIST
|
|
||||||
import os
|
import os
|
||||||
from torchvision import datasets, transforms
|
|
||||||
from pytorch_lightning import Trainer
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.datasets import MNIST
|
||||||
|
|
||||||
# transforms
|
# transforms
|
||||||
# prepare transforms standard to MNIST
|
# prepare transforms standard to MNIST
|
||||||
|
@ -204,19 +197,6 @@ Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIS
|
||||||
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
|
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
|
||||||
mnist_train = DataLoader(mnist_train, batch_size=64)
|
mnist_train = DataLoader(mnist_train, batch_size=64)
|
||||||
|
|
||||||
.. testoutput::
|
|
||||||
:hide:
|
|
||||||
:skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not _TORCHVISION_AVAILABLE
|
|
||||||
|
|
||||||
Downloading ...
|
|
||||||
Extracting ...
|
|
||||||
Downloading ...
|
|
||||||
Extracting ...
|
|
||||||
Downloading ...
|
|
||||||
Extracting ...
|
|
||||||
Processing...
|
|
||||||
Done!
|
|
||||||
|
|
||||||
You can use DataLoaders in three ways:
|
You can use DataLoaders in three ways:
|
||||||
|
|
||||||
1. Pass DataLoaders to .fit()
|
1. Pass DataLoaders to .fit()
|
||||||
|
|
Loading…
Reference in New Issue