Feature: LightningDataModule.from_datasets(...) (#5133)

* add class method

* add tests

* docstring

* pep

* Add type annotations

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* pep

* fix import

* remove num_workers inference

* Update pytorch_lightning/core/datamodule.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/core/datamodule.py

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Update pytorch_lightning/core/datamodule.py

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* fix syntax

* typing fix

* list -> sequence

* list -> sequence

* missing import

* fix test

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Teddy Koker 2021-02-11 09:32:41 -05:00 committed by GitHub
parent 31da16344c
commit 253e57c2c2
2 changed files with 120 additions and 11 deletions

pytorch_lightning/core/datamodule.py

@@ -15,12 +15,13 @@
 import functools
 import inspect
+import os
 from abc import abstractmethod
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, Dict, Sequence, Mapping

 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.utilities import parsing, rank_zero_only

@@ -266,8 +267,7 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):

     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
-        r"""Extends existing argparse by default `LightningDataModule` attributes.
-        """
+        r"""Extends existing argparse by default `LightningDataModule` attributes."""
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         added_args = [x.dest for x in parser._actions]

@@ -364,3 +364,59 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
             name_type_default.append((arg, arg_types, arg_default))

         return name_type_default
+
+    @classmethod
+    def from_datasets(
+        cls,
+        train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
+        val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+        test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+        batch_size: int = 1,
+        num_workers: int = 0,
+    ):
+        r"""
+        Create an instance from torch.utils.data.Dataset.
+
+        Args:
+            train_dataset: (optional) Dataset to be used for train_dataloader()
+            val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
+            test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
+            batch_size: Batch size to use for each dataloader. Default is 1.
+            num_workers: Number of subprocesses to use for data loading. 0 means that the
+                data will be loaded in the main process.
+        """
+
+        def dataloader(ds, shuffle=False):
+            return DataLoader(
+                ds,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                pin_memory=True,
+            )
+
+        def train_dataloader():
+            if isinstance(train_dataset, Mapping):
+                return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
+            if isinstance(train_dataset, Sequence):
+                return [dataloader(ds, shuffle=True) for ds in train_dataset]
+            return dataloader(train_dataset, shuffle=True)
+
+        def val_dataloader():
+            if isinstance(val_dataset, Sequence):
+                return [dataloader(ds) for ds in val_dataset]
+            return dataloader(val_dataset)
+
+        def test_dataloader():
+            if isinstance(test_dataset, Sequence):
+                return [dataloader(ds) for ds in test_dataset]
+            return dataloader(test_dataset)
+
+        datamodule = cls()
+        if train_dataset is not None:
+            datamodule.train_dataloader = train_dataloader
+        if val_dataset is not None:
+            datamodule.val_dataloader = val_dataloader
+        if test_dataset is not None:
+            datamodule.test_dataloader = test_dataloader
+        return datamodule
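
For context (not part of the diff): a minimal usage sketch of the new class method. RandomDataset is a hypothetical toy dataset introduced only for illustration.

    import torch
    from torch.utils.data import Dataset

    from pytorch_lightning import LightningDataModule


    class RandomDataset(Dataset):
        # Hypothetical toy dataset: 64 fixed-size random samples.
        def __getitem__(self, index):
            return torch.randn(32)

        def __len__(self):
            return 64


    # A single Dataset, a sequence of Datasets, or (for train only) a mapping of
    # Datasets can be passed; a sequence produces a list of DataLoaders.
    dm = LightningDataModule.from_datasets(
        train_dataset=RandomDataset(),
        val_dataset=[RandomDataset(), RandomDataset()],
        batch_size=8,
    )
    train_batch = next(iter(dm.train_dataloader()))  # tensor of shape [8, 32]
    val_loaders = dm.val_dataloader()                # list of two DataLoaders

Note that the factory simply patches train_dataloader / val_dataloader / test_dataloader closures onto a fresh instance, so no subclassing of LightningDataModule is needed for the plain-Dataset case.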

tests/core/test_datamodule.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import pickle
 from argparse import ArgumentParser
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 from unittest.mock import MagicMock

 import pytest
@@ -419,7 +419,6 @@ def test_full_loop_dp(tmpdir):

 def test_dm_transfer_batch_to_device(tmpdir):
     class CustomBatch:
-
         def __init__(self, data):
             self.samples = data[0]
             self.targets = data[1]
@@ -452,6 +451,28 @@ def test_dm_transfer_batch_to_device(tmpdir):
     assert batch_gpu.samples.device == batch_gpu.targets.device == expected


+class CustomMNISTDataModule(LightningDataModule):
+    def __init__(self, data_dir: str = "./"):
+        super().__init__()
+        self.data_dir = data_dir
+        self._epochs_called_for = []
+
+    def prepare_data(self):
+        TrialMNIST(self.data_dir, train=True, download=True)
+
+    def setup(self, stage: Optional[str] = None):
+        mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
+        self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
+        self.dims = self.mnist_train[0][0].shape
+
+    def train_dataloader(self):
+        assert self.trainer.current_epoch not in self._epochs_called_for
+        self._epochs_called_for.append(self.trainer.current_epoch)
+        return DataLoader(self.mnist_train, batch_size=4)
+
+
 def test_dm_reload_dataloaders_every_epoch(tmpdir):
     """Test datamodule, where trainer argument
     reload_dataloaders_every_epoch is set to True/False"""
@@ -483,5 +504,37 @@ def test_dm_reload_dataloaders_every_epoch(tmpdir):
         limit_train_batches=0.01,
         reload_dataloaders_every_epoch=True,
     )
-    results = trainer.fit(model, dm)
-    assert results
+    trainer.fit(model, dm)
+
+
+class DummyDS(torch.utils.data.Dataset):
+    def __getitem__(self, index):
+        return 1
+
+    def __len__(self):
+        return 100
+
+
+def test_dm_init_from_datasets(tmpdir):
+
+    train_ds = DummyDS()
+    valid_ds = DummyDS()
+    test_ds = DummyDS()
+
+    valid_dss = [DummyDS(), DummyDS()]
+    test_dss = [DummyDS(), DummyDS()]
+
+    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.train_dataloader())) == torch.ones(4))
+    assert dm.val_dataloader() is None
+    assert dm.test_dataloader() is None
+
+    dm = LightningDataModule.from_datasets(train_ds, valid_ds, test_ds, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.val_dataloader())) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader())) == torch.ones(4))
+
+    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.val_dataloader()[0])) == torch.ones(4))
+    assert torch.all(next(iter(dm.val_dataloader()[1])) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader()[0])) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader()[1])) == torch.ones(4))
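
A hedged end-to-end sketch of how the factory is meant to pair with a Trainer; ToyDataset, ToyModel, and the hyper-parameters below are assumptions for illustration, not part of this commit.

    import torch
    from torch import nn
    from torch.utils.data import Dataset

    import pytorch_lightning as pl


    class ToyDataset(Dataset):
        # Hypothetical dataset: random features with a scalar regression target.
        def __getitem__(self, index):
            return torch.randn(32), torch.randn(1)

        def __len__(self):
            return 256


    class ToyModel(pl.LightningModule):
        # Hypothetical minimal LightningModule used only to exercise the datamodule.
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    dm = pl.LightningDataModule.from_datasets(train_dataset=ToyDataset(), batch_size=16)
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=4)
    trainer.fit(ToyModel(), datamodule=dm)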