107 lines
3.3 KiB
ReStructuredText
107 lines
3.3 KiB
ReStructuredText
.. testsetup:: *
|
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
|
|
|
.. _multiple_loaders:
|
|
|
|
Multiple Datasets
=================
|
|
Lightning supports multiple dataloaders in a few ways.
|
|
|
|
1. Create a dataloader that iterates over multiple datasets under the hood.
2. In the training loop, you can pass multiple loaders as a dict or list/tuple, and Lightning
   will automatically combine the batches from the different loaders.
3. In the validation and test loops, you also have the option to return multiple dataloaders,
   which Lightning will call sequentially.
|
|
|
|
----------
|
|
|
|
Multiple training dataloaders
-----------------------------
|
|
For training, the usual way to use multiple datasets is to create a ``Dataset`` class
which wraps your multiple datasets, and to serve it through a single ``DataLoader``
(this of course also works for validation and test dataloaders).
|
|
|
|
(`reference <https://discuss.pytorch.org/t/train-simultaneously-on-two-datasets/649/2>`_)
|
|
|
|
.. testcode::
|
|
|
|
class ConcatDataset(torch.utils.data.Dataset):
    """Zip several indexable datasets into a single dataset.

    Item ``i`` is a tuple holding item ``i`` of every wrapped dataset, in the
    order the datasets were passed. The length is that of the shortest
    wrapped dataset.
    """

    def __init__(self, *datasets):
        # Keep the wrapped datasets exactly as given (a tuple, in order).
        self.datasets = datasets

    def __getitem__(self, i):
        items = [dataset[i] for dataset in self.datasets]
        return tuple(items)

    def __len__(self):
        # The shortest dataset bounds how many aligned tuples can be produced.
        lengths = map(len, self.datasets)
        return min(lengths)
|
|
|
|
class LitModel(LightningModule):

    def train_dataloader(self):
        # Wrap both image folders so that each sample is a tuple holding one
        # item from each folder (see ConcatDataset.__getitem__).
        paired_dataset = ConcatDataset(
            datasets.ImageFolder(traindir_A),
            datasets.ImageFolder(traindir_B),
        )

        # A single DataLoader then batches the paired samples.
        return torch.utils.data.DataLoader(
            paired_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True,
        )

    def val_dataloader(self):
        # SAME
        ...

    def test_dataloader(self):
        # SAME
        ...
|
|
|
|
However, with Lightning you can also return multiple loaders, and Lightning will take care of combining the batches.
|
|
|
|
For more details, please have a look at :attr:`~pytorch_lightning.trainer.trainer.Trainer.multiple_trainloader_mode`.
|
|
|
|
.. testcode::
|
|
|
|
class LitModel(LightningModule):

    def train_dataloader(self):
        loader_a = torch.utils.data.DataLoader(range(6), batch_size=4)
        loader_b = torch.utils.data.DataLoader(range(15), batch_size=5)

        # Option 1 -- pass the loaders as a dict, which yields batches like
        # {'a': batch from loader_a, 'b': batch from loader_b}
        loaders = {'a': loader_a, 'b': loader_b}

        # Option 2 -- pass the loaders as a sequence, which yields batches like
        # [batch from loader_a, batch from loader_b]
        # (this reassignment replaces the dict, so the list is what is returned)
        loaders = [loader_a, loader_b]

        return loaders
|
|
|
|
----------
|
|
|
|
Test/Val dataloaders
--------------------
|
|
For validation and test dataloaders, Lightning also gives you the additional
option of returning multiple dataloaders from each call.
|
|
|
|
See the following for more details:
|
|
|
|
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.val_dataloader`
- :meth:`~pytorch_lightning.core.datamodule.LightningDataModule.test_dataloader`
|
|
|
|
.. testcode::
|
|
|
|
def val_dataloader(self):
    """Return several validation loaders as a list.

    Per the surrounding documentation, Lightning calls the returned
    dataloaders sequentially.
    """
    # ``Dataloader`` is not a real class: the PyTorch class is
    # ``torch.utils.data.DataLoader``, and it needs a dataset to wrap.
    loader_1 = torch.utils.data.DataLoader(range(6), batch_size=4)
    loader_2 = torch.utils.data.DataLoader(range(15), batch_size=5)
    return [loader_1, loader_2]
|