Feature: LightningDataModule.from_datasets(...) (#5133)
* add class method
* add tests
* docstring
* pep
* add type annotations
* fix import
* remove num_workers inference
* apply review suggestions to pytorch_lightning/core/datamodule.py
* fix syntax
* typing fix
* list -> sequence
* missing import
* fix test

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
commit 253e57c2c2
parent 31da16344c
@@ -15,12 +15,13 @@
 import functools
 import inspect
+import os
 from abc import abstractmethod
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, Dict, Sequence, Mapping

 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.utilities import parsing, rank_zero_only
@@ -266,8 +267,7 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):

     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
-        r"""Extends existing argparse by default `LightningDataModule` attributes.
-        """
+        r"""Extends existing argparse by default `LightningDataModule` attributes."""
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         added_args = [x.dest for x in parser._actions]

@@ -364,3 +364,59 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
             name_type_default.append((arg, arg_types, arg_default))

         return name_type_default
+
+    @classmethod
+    def from_datasets(
+        cls,
+        train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
+        val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+        test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+        batch_size: int = 1,
+        num_workers: int = 0,
+    ):
+        r"""
+        Create an instance from torch.utils.data.Dataset.
+
+        Args:
+            train_dataset: (optional) Dataset to be used for train_dataloader()
+            val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
+            test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
+            batch_size: Batch size to use for each dataloader. Default is 1.
+            num_workers: Number of subprocesses to use for data loading. 0 means that the
+                data will be loaded in the main process.
+        """
+        def dataloader(ds, shuffle=False):
+            return DataLoader(
+                ds,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                pin_memory=True,
+            )
+
+        def train_dataloader():
+            if isinstance(train_dataset, Mapping):
+                return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
+            if isinstance(train_dataset, Sequence):
+                return [dataloader(ds, shuffle=True) for ds in train_dataset]
+            return dataloader(train_dataset, shuffle=True)
+
+        def val_dataloader():
+            if isinstance(val_dataset, Sequence):
+                return [dataloader(ds) for ds in val_dataset]
+            return dataloader(val_dataset)
+
+        def test_dataloader():
+            if isinstance(test_dataset, Sequence):
+                return [dataloader(ds) for ds in test_dataset]
+            return dataloader(test_dataset)
+
+        datamodule = cls()
+        if train_dataset is not None:
+            datamodule.train_dataloader = train_dataloader
+        if val_dataset is not None:
+            datamodule.val_dataloader = val_dataloader
+        if test_dataset is not None:
+            datamodule.test_dataloader = test_dataloader
+        return datamodule
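For context while reviewing: a minimal usage sketch of the new classmethod. This is an illustration, not part of the diff; ToyDataset is a hypothetical stand-in dataset.

    from torch.utils.data import Dataset

    from pytorch_lightning import LightningDataModule


    class ToyDataset(Dataset):
        # hypothetical stand-in dataset, used only for this sketch
        def __getitem__(self, index):
            return index

        def __len__(self):
            return 8


    # one Dataset per split -> each *_dataloader() returns a single DataLoader
    dm = LightningDataModule.from_datasets(
        train_dataset=ToyDataset(),
        val_dataset=ToyDataset(),
        test_dataset=ToyDataset(),
        batch_size=4,
        num_workers=0,
    )
    batch = next(iter(dm.train_dataloader()))  # shuffled batch of 4 indices

    # a sequence of Datasets -> a list of DataLoaders for that split
    dm = LightningDataModule.from_datasets(
        train_dataset=ToyDataset(),
        val_dataset=[ToyDataset(), ToyDataset()],
        batch_size=4,
    )
    assert len(dm.val_dataloader()) == 2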
@@ -13,7 +13,7 @@
 # limitations under the License.
 import pickle
 from argparse import ArgumentParser
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 from unittest.mock import MagicMock

 import pytest
@@ -419,7 +419,6 @@ def test_full_loop_dp(tmpdir):
 def test_dm_transfer_batch_to_device(tmpdir):

     class CustomBatch:

         def __init__(self, data):
             self.samples = data[0]
             self.targets = data[1]
@@ -452,6 +451,28 @@ def test_dm_transfer_batch_to_device(tmpdir):
     assert batch_gpu.samples.device == batch_gpu.targets.device == expected


+class CustomMNISTDataModule(LightningDataModule):
+
+    def __init__(self, data_dir: str = "./"):
+        super().__init__()
+        self.data_dir = data_dir
+        self._epochs_called_for = []
+
+    def prepare_data(self):
+        TrialMNIST(self.data_dir, train=True, download=True)
+
+    def setup(self, stage: Optional[str] = None):
+        mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
+        self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
+        self.dims = self.mnist_train[0][0].shape
+
+    def train_dataloader(self):
+        assert self.trainer.current_epoch not in self._epochs_called_for
+        self._epochs_called_for.append(self.trainer.current_epoch)
+        return DataLoader(self.mnist_train, batch_size=4)
+
+
 def test_dm_reload_dataloaders_every_epoch(tmpdir):
     """Test datamodule, where trainer argument
     reload_dataloaders_every_epoch is set to True/False"""
@@ -483,5 +504,37 @@ def test_dm_reload_dataloaders_every_epoch(tmpdir):
         limit_train_batches=0.01,
         reload_dataloaders_every_epoch=True,
     )
-    results = trainer.fit(model, dm)
-    assert results
+    trainer.fit(model, dm)
+
+
+class DummyDS(torch.utils.data.Dataset):
+
+    def __getitem__(self, index):
+        return 1
+
+    def __len__(self):
+        return 100
+
+
+def test_dm_init_from_datasets(tmpdir):
+    train_ds = DummyDS()
+    valid_ds = DummyDS()
+    test_ds = DummyDS()
+
+    valid_dss = [DummyDS(), DummyDS()]
+    test_dss = [DummyDS(), DummyDS()]
+
+    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.train_dataloader())) == torch.ones(4))
+    assert dm.val_dataloader() is None
+    assert dm.test_dataloader() is None
+
+    dm = LightningDataModule.from_datasets(train_ds, valid_ds, test_ds, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.val_dataloader())) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader())) == torch.ones(4))
+
+    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.val_dataloader()[0])) == torch.ones(4))
+    assert torch.all(next(iter(dm.val_dataloader()[1])) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader()[0])) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader()[1])) == torch.ones(4))
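Note for reviewers: the test above covers the single-Dataset and Sequence cases, but not the Mapping[str, Dataset] branch of train_dataloader. A minimal sketch of that case (an illustration, not part of the diff), reusing the DummyDS helper defined above:

    def test_dm_init_from_dataset_mapping():
        # hypothetical extra test: a dict of named train datasets yields a dict of loaders
        dm = LightningDataModule.from_datasets(
            train_dataset={"a": DummyDS(), "b": DummyDS()},
            batch_size=4,
            num_workers=0,
        )
        loaders = dm.train_dataloader()
        assert set(loaders) == {"a", "b"}
        assert torch.all(next(iter(loaders["a"])) == torch.ones(4))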