Fix misuse of transforms in docs (#3546)
* 📝 docs * 📝 docs * 📝 docs * 📝 docs * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
a9c0ed920a
commit
c46de8a3d4
|
@ -165,6 +165,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
|
|||
|
||||
.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
|
||||
|
||||
|
||||
---------------
|
||||
|
||||
LightningDataModule API
|
||||
|
@ -203,6 +204,7 @@ There are also data operations you might want to perform on every GPU. Use setup
|
|||
- count number of classes
|
||||
- build vocabulary
|
||||
- perform train/val/test splits
|
||||
- apply transforms (defined explicitly in your datamodule or assigned in init)
|
||||
- etc...
|
||||
|
||||
.. code-block:: python
|
||||
|
@ -216,13 +218,23 @@ There are also data operations you might want to perform on every GPU. Use setup
|
|||
|
||||
# Assign Train/val split(s) for use in Dataloaders
|
||||
if stage == 'fit' or stage is None:
|
||||
mnist_full = MNIST(self.data_dir, train=True, download=True)
|
||||
mnist_full = MNIST(
|
||||
self.data_dir,
|
||||
train=True,
|
||||
download=True,
|
||||
transform=self.transform
|
||||
)
|
||||
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
|
||||
self.dims = self.mnist_train[0][0].shape
|
||||
|
||||
# Assign Test split(s) for use in Dataloaders
|
||||
if stage == 'test' or stage is None:
|
||||
self.mnist_test = MNIST(self.data_dir, train=False, download=True)
|
||||
self.mnist_test = MNIST(
|
||||
self.data_dir,
|
||||
train=False,
|
||||
download=True,
|
||||
transform=self.transform
|
||||
)
|
||||
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
|
||||
|
||||
|
||||
|
@ -231,7 +243,7 @@ There are also data operations you might want to perform on every GPU. Use setup
|
|||
|
||||
train_dataloader
|
||||
^^^^^^^^^^^^^^^^
|
||||
Use this method to generate the train dataloader. This is also a good place to place default transformations.
|
||||
Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in ``setup``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -240,25 +252,12 @@ Use this method to generate the train dataloader. This is also a good place to p
|
|||
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def train_dataloader(self):
|
||||
transforms = transform_lib.Compose([
|
||||
transform_lib.ToTensor(),
|
||||
transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
|
||||
])
|
||||
return DataLoader(self.train_dataset, transform=transforms, batch_size=64)
|
||||
return DataLoader(self.mnist_train, batch_size=64)
|
||||
|
||||
However, to decouple your data from transforms you can parametrize them via `__init__`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def __init__(self, train_transforms, val_transforms, test_transforms):
|
||||
self.train_transforms = train_transforms
|
||||
self.val_transforms = val_transforms
|
||||
self.test_transforms = test_transforms
|
||||
|
||||
val_dataloader
|
||||
^^^^^^^^^^^^^^
|
||||
Use this method to generate the val dataloader. This is also a good place to place default transformations.
|
||||
Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in ``setup``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -267,15 +266,12 @@ Use this method to generate the val dataloader. This is also a good place to pla
|
|||
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def val_dataloader(self):
|
||||
transforms = transform_lib.Compose([
|
||||
transform_lib.ToTensor(),
|
||||
transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
|
||||
])
|
||||
return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
|
||||
return DataLoader(self.mnist_val, batch_size=64)
|
||||
|
||||
|
||||
test_dataloader
|
||||
^^^^^^^^^^^^^^^
|
||||
Use this method to generate the test dataloader. This is also a good place to place default transformations.
|
||||
Use this method to generate the test dataloader. Usually you just wrap the dataset you defined in ``setup``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -284,11 +280,7 @@ Use this method to generate the test dataloader. This is also a good place to pl
|
|||
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def test_dataloader(self):
|
||||
transforms = transform_lib.Compose([
|
||||
transform_lib.ToTensor(),
|
||||
transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
|
||||
])
|
||||
return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
|
||||
return DataLoader(self.mnist_test, batch_size=64)
|
||||
|
||||
transfer_batch_to_device
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
@ -306,6 +298,19 @@ Override to define how you want to move an arbitrary batch to a device
|
|||
batch['x'].to(device)
|
||||
return batch
|
||||
|
||||
|
||||
.. note:: To decouple your data from transforms you can parametrize them via `__init__`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MNISTDataModule(pl.LightningDataModule):
|
||||
def __init__(self, train_transforms, val_transforms, test_transforms):
|
||||
super().__init__()
|
||||
self.train_transforms = train_transforms
|
||||
self.val_transforms = val_transforms
|
||||
self.test_transforms = test_transforms
|
||||
|
||||
|
||||
------------------
|
||||
|
||||
Using a DataModule
|
||||
|
|
Loading…
Reference in New Issue