Feature: LightningDataModule.from_datasets(...) (#5133)
* add class method
* add tests
* docstring
* pep
* Add type annotations
  Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
* pep
* fix import
* remove num_workers inference
* Update pytorch_lightning/core/datamodule.py
  Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* Update pytorch_lightning/core/datamodule.py
  Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
* Update pytorch_lightning/core/datamodule.py
  Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
* fix syntax
* typing fix
* list -> sequence
* list -> sequence
* missing import
* fix test

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent 31da16344c
commit 253e57c2c2
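For orientation before the diff, here is a minimal usage sketch of the new classmethod. The toy RandomDataset below is illustrative only and not part of this commit; any map-style torch.utils.data.Dataset works.

    # Sketch only: a hypothetical toy dataset used to exercise from_datasets.
    from torch.utils.data import Dataset

    from pytorch_lightning import LightningDataModule


    class RandomDataset(Dataset):
        def __getitem__(self, index):
            return index

        def __len__(self):
            return 8


    dm = LightningDataModule.from_datasets(
        train_dataset=RandomDataset(),
        val_dataset=[RandomDataset(), RandomDataset()],  # a sequence yields a list of loaders
        batch_size=2,
    )
    train_loader = dm.train_dataloader()  # a single DataLoader
    val_loaders = dm.val_dataloader()     # a list with two DataLoaders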
pytorch_lightning/core/datamodule.py

@@ -15,12 +15,13 @@
 import functools
 import inspect
 import os
 from abc import abstractmethod
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, Dict, Sequence, Mapping

 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.utilities import parsing, rank_zero_only
@@ -266,8 +267,7 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapp
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
-        r"""Extends existing argparse by default `LightningDataModule` attributes.
-        """
+        r"""Extends existing argparse by default `LightningDataModule` attributes."""
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         added_args = [x.dest for x in parser._actions]

@@ -364,3 +364,59 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapp
             name_type_default.append((arg, arg_types, arg_default))

         return name_type_default
+
+    @classmethod
+    def from_datasets(
+        cls,
+        train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
+        val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+        test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+        batch_size: int = 1,
+        num_workers: int = 0,
+    ):
+        r"""
+        Create an instance from torch.utils.data.Dataset.
+
+        Args:
+            train_dataset: (optional) Dataset to be used for train_dataloader()
+            val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
+            test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
+            batch_size: Batch size to use for each dataloader. Default is 1.
+            num_workers: Number of subprocesses to use for data loading. 0 means that the
+                data will be loaded in the main process.
+
+        """
+        def dataloader(ds, shuffle=False):
+            return DataLoader(
+                ds,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                pin_memory=True,
+            )
+
+        def train_dataloader():
+            if isinstance(train_dataset, Mapping):
+                return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
+            if isinstance(train_dataset, Sequence):
+                return [dataloader(ds, shuffle=True) for ds in train_dataset]
+            return dataloader(train_dataset, shuffle=True)
+
+        def val_dataloader():
+            if isinstance(val_dataset, Sequence):
+                return [dataloader(ds) for ds in val_dataset]
+            return dataloader(val_dataset)
+
+        def test_dataloader():
+            if isinstance(test_dataset, Sequence):
+                return [dataloader(ds) for ds in test_dataset]
+            return dataloader(test_dataset)
+
+        datamodule = cls()
+        if train_dataset is not None:
+            datamodule.train_dataloader = train_dataloader
+        if val_dataset is not None:
+            datamodule.val_dataloader = val_dataloader
+        if test_dataset is not None:
+            datamodule.test_dataloader = test_dataloader
+        return datamodule
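One detail the docstring does not spell out: train_dataset may also be a mapping of names to datasets, in which case the patched train_dataloader() returns a dict of DataLoaders keyed the same way. A hedged sketch of that branch follows; the ToyDS class and the "images"/"extra" keys are illustrative, not from the commit.

    from torch.utils.data import Dataset

    from pytorch_lightning import LightningDataModule


    class ToyDS(Dataset):
        def __getitem__(self, index):
            return 1

        def __len__(self):
            return 16


    # A dict of train datasets is mapped key-by-key onto a dict of DataLoaders.
    dm = LightningDataModule.from_datasets(
        train_dataset={"images": ToyDS(), "extra": ToyDS()},
        batch_size=4,
    )
    loaders = dm.train_dataloader()
    assert set(loaders) == {"images", "extra"}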
The remaining hunks are in the DataModule tests.

@@ -13,7 +13,7 @@
 # limitations under the License.
 import pickle
 from argparse import ArgumentParser
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 from unittest.mock import MagicMock

 import pytest
@@ -419,7 +419,6 @@ def test_full_loop_dp(tmpdir):
 def test_dm_transfer_batch_to_device(tmpdir):
-
     class CustomBatch:

         def __init__(self, data):
             self.samples = data[0]
             self.targets = data[1]
@@ -452,6 +451,28 @@ def test_dm_transfer_batch_to_device(tmpdir):
     assert batch_gpu.samples.device == batch_gpu.targets.device == expected


+class CustomMNISTDataModule(LightningDataModule):
+    def __init__(self, data_dir: str = "./"):
+        super().__init__()
+        self.data_dir = data_dir
+        self._epochs_called_for = []
+
+    def prepare_data(self):
+        TrialMNIST(self.data_dir, train=True, download=True)
+
+    def setup(self, stage: Optional[str] = None):
+
+        mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
+        self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
+        self.dims = self.mnist_train[0][0].shape
+
+    def train_dataloader(self):
+        assert self.trainer.current_epoch not in self._epochs_called_for
+        self._epochs_called_for.append(self.trainer.current_epoch)
+
+        return DataLoader(self.mnist_train, batch_size=4)
+
+
 def test_dm_reload_dataloaders_every_epoch(tmpdir):
     """Test datamodule, where trainer argument
     reload_dataloaders_every_epoch is set to True/False"""
@@ -483,5 +504,37 @@
         limit_train_batches=0.01,
         reload_dataloaders_every_epoch=True,
     )
-    results = trainer.fit(model, dm)
-    assert results
+    trainer.fit(model, dm)
+
+
+class DummyDS(torch.utils.data.Dataset):
+    def __getitem__(self, index):
+        return 1
+
+    def __len__(self):
+        return 100
+
+
+def test_dm_init_from_datasets(tmpdir):
+
+    train_ds = DummyDS()
+    valid_ds = DummyDS()
+    test_ds = DummyDS()
+
+    valid_dss = [DummyDS(), DummyDS()]
+    test_dss = [DummyDS(), DummyDS()]
+
+    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.train_dataloader())) == torch.ones(4))
+    assert dm.val_dataloader() is None
+    assert dm.test_dataloader() is None
+
+    dm = LightningDataModule.from_datasets(train_ds, valid_ds, test_ds, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.val_dataloader())) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader())) == torch.ones(4))
+
+    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.val_dataloader()[0])) == torch.ones(4))
+    assert torch.all(next(iter(dm.val_dataloader()[1])) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader()[0])) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader()[1])) == torch.ones(4))
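For context on the assertions in test_dm_init_from_datasets: every DummyDS item is the integer 1, so PyTorch's default collate stacks a batch of four items into a tensor of ones, which is exactly what the torch.ones(4) comparisons check. A quick standalone illustration, not part of the commit, using default_collate directly:

    import torch
    from torch.utils.data.dataloader import default_collate

    # Four Python ints of value 1 collate into tensor([1, 1, 1, 1]),
    # hence the comparisons against torch.ones(4) above.
    batch = default_collate([1, 1, 1, 1])
    assert torch.all(batch == torch.ones(4))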