Fix doctests
This commit is contained in:
parent
41b0ffe1f4
commit
8d0266f958
|
@ -1,9 +1,3 @@
|
|||
.. testsetup:: *
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from pytorch_lightning.trainer.trainer import Trainer
|
||||
|
||||
.. _introduction_guide:
|
||||
|
||||
#########################
|
||||
|
@ -72,7 +66,7 @@ Let's first start with the model. In this case, we'll design a 3-layer neural ne
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
from torch import nn
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
|
||||
class LitMNIST(LightningModule):
|
||||
|
@ -187,14 +181,13 @@ Data
|
|||
|
||||
Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIST.
|
||||
|
||||
.. testcode::
|
||||
:skipif: not _TORCHVISION_AVAILABLE
|
||||
.. code-block:: python
|
||||
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from torchvision.datasets import MNIST
|
||||
import os
|
||||
from torchvision import datasets, transforms
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
# transforms
|
||||
# prepare transforms standard to MNIST
|
||||
|
@ -204,19 +197,6 @@ Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIS
|
|||
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
|
||||
mnist_train = DataLoader(mnist_train, batch_size=64)
|
||||
|
||||
.. testoutput::
|
||||
:hide:
|
||||
:skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not _TORCHVISION_AVAILABLE
|
||||
|
||||
Downloading ...
|
||||
Extracting ...
|
||||
Downloading ...
|
||||
Extracting ...
|
||||
Downloading ...
|
||||
Extracting ...
|
||||
Processing...
|
||||
Done!
|
||||
|
||||
You can use DataLoaders in three ways:
|
||||
|
||||
1. Pass DataLoaders to .fit()
|
||||
|
|
Loading…
Reference in New Issue