Fix misuse of transforms in docs (#3546)

* 📝 docs

* 📝 docs

* 📝 docs

* 📝 docs

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Nathan Raw 2020-09-18 06:49:45 -06:00 committed by GitHub
parent a9c0ed920a
commit c46de8a3d4
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 34 additions and 29 deletions


@@ -165,6 +165,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
 .. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
 ---------------
 LightningDataModule API
@@ -203,6 +204,7 @@ There are also data operations you might want to perform on every GPU. Use setup
 - count number of classes
 - build vocabulary
 - perform train/val/test splits
+- apply transforms (defined explicitly in your datamodule or assigned in init)
 - etc...
 .. code-block:: python
@@ -216,13 +218,23 @@ There are also data operations you might want to perform on every GPU. Use setup
             # Assign Train/val split(s) for use in Dataloaders
             if stage == 'fit' or stage is None:
-                mnist_full = MNIST(self.data_dir, train=True, download=True)
+                mnist_full = MNIST(
+                    self.data_dir,
+                    train=True,
+                    download=True,
+                    transform=self.transform
+                )
                 self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
                 self.dims = self.mnist_train[0][0].shape
             # Assign Test split(s) for use in Dataloaders
             if stage == 'test' or stage is None:
-                self.mnist_test = MNIST(self.data_dir, train=False, download=True)
+                self.mnist_test = MNIST(
+                    self.data_dir,
+                    train=False,
+                    download=True,
+                    transform=self.transform
+                )
                 self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
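Reviewer note on the ``setup`` hunk: a map-style dataset applies its transform when an item is fetched, which is why the transform is handed to the dataset constructor here rather than to the loader. The following is a minimal standard-library sketch of that idea. ``ToyDataset`` and ``normalize`` are hypothetical stand-ins for ``MNIST`` and ``self.transform``; they are not part of torchvision or Lightning.

```python
from typing import Callable, List, Optional, Tuple


class ToyDataset:
    """Hypothetical stand-in for a map-style dataset such as MNIST."""

    def __init__(
        self,
        samples: List[Tuple[float, int]],
        transform: Optional[Callable[[float], float]] = None,
    ):
        self.samples = samples
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[float, int]:
        x, y = self.samples[index]
        # The transform runs here, once per sample, when the item is fetched.
        if self.transform is not None:
            x = self.transform(x)
        return x, y


def normalize(x: float) -> float:
    """Hypothetical transform: map a pixel value in [0, 255] to [-1, 1]."""
    return (x / 255.0 - 0.5) / 0.5


raw = ToyDataset([(0.0, 3), (255.0, 7)])
transformed = ToyDataset([(0.0, 3), (255.0, 7)], transform=normalize)

print(raw[1])          # sample returned unchanged
print(transformed[1])  # sample after the transform ran inside __getitem__
```

Because the transform lives on the dataset, anything that indexes the dataset (including a shape probe like ``self.mnist_train[0][0]``) sees already-transformed items.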
@@ -231,7 +243,7 @@ There are also data operations you might want to perform on every GPU. Use setup
 train_dataloader
 ^^^^^^^^^^^^^^^^
-Use this method to generate the train dataloader. This is also a good place to place default transformations.
+Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in ``setup``.
 .. code-block:: python
@@ -240,25 +252,12 @@ Use this method to generate the train dataloader. This is also a good place to p
     class MNISTDataModule(pl.LightningDataModule):
         def train_dataloader(self):
-            transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
-            return DataLoader(self.train_dataset, transform=transforms, batch_size=64)
+            return DataLoader(self.mnist_train, batch_size=64)
-However, to decouple your data from transforms you can parametrize them via `__init__`.
-.. code-block:: python
-    class MNISTDataModule(pl.LightningDataModule):
-        def __init__(self, train_transforms, val_transforms, test_transforms):
-            self.train_transforms = train_transforms
-            self.val_transforms = val_transforms
-            self.test_transforms = test_transforms
 val_dataloader
 ^^^^^^^^^^^^^^
-Use this method to generate the val dataloader. This is also a good place to place default transformations.
+Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in ``setup``.
 .. code-block:: python
@@ -267,15 +266,12 @@ Use this method to generate the val dataloader. This is also a good place to pla
     class MNISTDataModule(pl.LightningDataModule):
         def val_dataloader(self):
-            transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
-            return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
+            return DataLoader(self.mnist_val, batch_size=64)
 test_dataloader
 ^^^^^^^^^^^^^^^
-Use this method to generate the test dataloader. This is also a good place to place default transformations.
+Use this method to generate the test dataloader. Usually you just wrap the dataset you defined in ``setup``.
 .. code-block:: python
@@ -284,11 +280,7 @@ Use this method to generate the test dataloader. This is also a good place to pl
     class MNISTDataModule(pl.LightningDataModule):
         def test_dataloader(self):
-            transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
-            return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
+            return DataLoader(self.mnist_test, batch_size=64)
 transfer_batch_to_device
 ^^^^^^^^^^^^^^^^^^^^^^^^
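Reviewer note on the three ``*_dataloader`` hunks: after this change the datasets are built per stage in ``setup``, and each loader method only wraps the matching dataset. Below is a rough standard-library sketch of that division of labour. ``ToyDataModule`` and ``batches`` are hypothetical stand-ins for a ``LightningDataModule`` subclass and ``DataLoader``, and the transform is applied eagerly in ``setup`` only to keep the sketch short.

```python
from typing import Callable, Iterator, List, Optional, Sequence


def batches(dataset: Sequence[int], batch_size: int) -> Iterator[List[int]]:
    """Hypothetical stand-in for DataLoader: it only groups items into batches."""
    for start in range(0, len(dataset), batch_size):
        yield list(dataset[start:start + batch_size])


class ToyDataModule:
    """Hypothetical stand-in mirroring the method layout used in the docs."""

    def __init__(self, transform: Callable[[int], int] = lambda x: x):
        self.transform = transform

    def setup(self, stage: Optional[str] = None) -> None:
        # Datasets are built here, per stage, with the transform already attached.
        if stage == 'fit' or stage is None:
            full = [self.transform(x) for x in range(10)]
            self.train, self.val = full[:8], full[8:]
        if stage == 'test' or stage is None:
            self.test = [self.transform(x) for x in range(10, 14)]

    # Each loader method only wraps the dataset that setup prepared.
    def train_dataloader(self) -> Iterator[List[int]]:
        return batches(self.train, batch_size=4)

    def val_dataloader(self) -> Iterator[List[int]]:
        return batches(self.val, batch_size=4)

    def test_dataloader(self) -> Iterator[List[int]]:
        return batches(self.test, batch_size=4)


dm = ToyDataModule(transform=lambda x: x * 2)
dm.setup('fit')
print(list(dm.train_dataloader()))
print(hasattr(dm, 'test'))  # the test split is not built during the 'fit' stage
```

Calling ``dm.setup('test')`` afterwards would create the test split, which matches the ``stage`` checks in the ``setup`` hunk.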
@@ -306,6 +298,19 @@ Override to define how you want to move an arbitrary batch to a device
             batch['x'].to(device)
             return batch
+.. note:: To decouple your data from transforms you can parametrize them via `__init__`.
+.. code-block:: python
+    class MNISTDataModule(pl.LightningDataModule):
+        def __init__(self, train_transforms, val_transforms, test_transforms):
+            super().__init__()
+            self.train_transforms = train_transforms
+            self.val_transforms = val_transforms
+            self.test_transforms = test_transforms
 ------------------
 Using a DataModule
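Reviewer note on the relocated ``__init__`` snippet: the docs show the transforms being stored but not where they are consumed. One plausible wiring, sketched with the standard library only, is to read them back in ``setup`` when each split is built. ``ToyDataModule``, ``augment`` and ``identity`` are hypothetical names, and the list comprehensions stand in for constructing real datasets with ``transform=...``.

```python
from typing import Callable, Optional

Transform = Callable[[float], float]


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule subclass."""

    def __init__(
        self,
        train_transforms: Transform,
        val_transforms: Transform,
        test_transforms: Transform,
    ):
        # A real subclass would call super().__init__() first, as the diff does.
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.test_transforms = test_transforms

    def setup(self, stage: Optional[str] = None) -> None:
        raw = [0.0, 1.0, 2.0, 3.0]
        # Each split picks up the transform that was injected for it.
        if stage == 'fit' or stage is None:
            self.train = [self.train_transforms(x) for x in raw[:3]]
            self.val = [self.val_transforms(x) for x in raw[3:]]
        if stage == 'test' or stage is None:
            self.test = [self.test_transforms(x) for x in raw]


def augment(x: float) -> float:
    """Hypothetical train-only augmentation."""
    return x + 0.5


def identity(x: float) -> float:
    """Hypothetical no-op transform for evaluation splits."""
    return x


dm = ToyDataModule(augment, identity, identity)
dm.setup()
print(dm.train)
print(dm.val)
print(dm.test)
```

Injecting the transforms this way lets the same data module be reused with different augmentation policies without editing ``setup``.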