diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst
index 998ccf49c5..e1fbeabd96 100644
--- a/docs/source/datamodules.rst
+++ b/docs/source/datamodules.rst
@@ -165,6 +165,7 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
 
 .. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
 
+
 ---------------
 
 LightningDataModule API
@@ -203,6 +204,7 @@ There are also data operations you might want to perform on every GPU. Use setup
 - count number of classes
 - build vocabulary
 - perform train/val/test splits
+- apply transforms (defined explicitly in your datamodule or assigned in init)
 - etc...
 
 .. code-block:: python
@@ -216,13 +218,23 @@ There are also data operations you might want to perform on every GPU. Use setup
 
             # Assign Train/val split(s) for use in Dataloaders
             if stage == 'fit' or stage is None:
-                mnist_full = MNIST(self.data_dir, train=True, download=True)
+                mnist_full = MNIST(
+                    self.data_dir,
+                    train=True,
+                    download=True,
+                    transform=self.transform
+                )
                 self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
                 self.dims = self.mnist_train[0][0].shape
 
             # Assign Test split(s) for use in Dataloaders
             if stage == 'test' or stage is None:
-                self.mnist_test = MNIST(self.data_dir, train=False, download=True)
+                self.mnist_test = MNIST(
+                    self.data_dir,
+                    train=False,
+                    download=True,
+                    transform=self.transform
+                )
                 self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
 
 
@@ -231,7 +243,7 @@ There are also data operations you might want to perform on every GPU. Use setup
 train_dataloader
 ^^^^^^^^^^^^^^^^
 
-Use this method to generate the train dataloader. This is also a good place to place default transformations.
+Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in ``setup``.
 
 .. code-block:: python
 
@@ -240,25 +252,12 @@ Use this method to generate the train dataloader. This is also a good place to p
 
     class MNISTDataModule(pl.LightningDataModule):
         def train_dataloader(self):
-            transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
-            return DataLoader(self.train_dataset, transform=transforms, batch_size=64)
+            return DataLoader(self.mnist_train, batch_size=64)
 
-However, to decouple your data from transforms you can parametrize them via `__init__`.
-
-.. code-block:: python
-
-    class MNISTDataModule(pl.LightningDataModule):
-        def __init__(self, train_transforms, val_transforms, test_transforms):
-            self.train_transforms = train_transforms
-            self.val_transforms = val_transforms
-            self.test_transforms = test_transforms
 
 val_dataloader
 ^^^^^^^^^^^^^^
 
-Use this method to generate the val dataloader. This is also a good place to place default transformations.
+Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in ``setup``.
 
 .. code-block:: python
 
@@ -267,15 +266,12 @@ Use this method to generate the val dataloader. This is also a good place to pla
 
     class MNISTDataModule(pl.LightningDataModule):
         def val_dataloader(self):
-            transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
-            return DataLoader(self.val_dataset, transform=transforms, batch_size=64)
+            return DataLoader(self.mnist_val, batch_size=64)
+
 
 test_dataloader
 ^^^^^^^^^^^^^^^
 
-Use this method to generate the test dataloader. This is also a good place to place default transformations.
+Use this method to generate the test dataloader. Usually you just wrap the dataset you defined in ``setup``.
 
 .. code-block:: python
 
@@ -284,11 +280,7 @@ Use this method to generate the test dataloader. This is also a good place to pl
 
     class MNISTDataModule(pl.LightningDataModule):
         def test_dataloader(self):
-            transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
-            return DataLoader(self.test_dataset, transform=transforms, batch_size=64)
+            return DataLoader(self.mnist_test, batch_size=64)
 
 transfer_batch_to_device
 ^^^^^^^^^^^^^^^^^^^^^^^^
@@ -306,6 +298,19 @@ Override to define how you want to move an arbitrary batch to a device
         batch['x'].to(device)
         return batch
 
+
+.. note:: To decouple your data from transforms, you can parametrize them via ``__init__``.
+
+.. code-block:: python
+
+    class MNISTDataModule(pl.LightningDataModule):
+        def __init__(self, train_transforms, val_transforms, test_transforms):
+            super().__init__()
+            self.train_transforms = train_transforms
+            self.val_transforms = val_transforms
+            self.test_transforms = test_transforms
+
+
 ------------------
 
 Using a DataModule
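Note: the ``setup`` hunks above pass ``transform=self.transform``, an attribute this diff assumes is defined elsewhere in the datamodule. For reference, here is a minimal end-to-end sketch of the pattern the hunks add up to (transform defined in ``__init__``, consumed in ``setup``, splits wrapped by the dataloaders). The normalization values, ``data_dir`` default, and batch size are illustrative and not taken from this diff:

.. code-block:: python

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST


    class MNISTDataModule(pl.LightningDataModule):

        def __init__(self, data_dir='./', batch_size=64):
            super().__init__()
            self.data_dir = data_dir
            self.batch_size = batch_size
            # Defined once here, reused for every split in ``setup``
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5,), std=(0.5,)),
            ])

        def setup(self, stage=None):
            # Assign Train/val split(s) for use in Dataloaders
            if stage == 'fit' or stage is None:
                mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
                self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

            # Assign Test split(s) for use in Dataloaders
            if stage == 'test' or stage is None:
                self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform)

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.mnist_test, batch_size=self.batch_size)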