From 253e57c2c24ce4fbf5b100d1a84fea3a3ebe24a8 Mon Sep 17 00:00:00 2001
From: Teddy Koker
Date: Thu, 11 Feb 2021 09:32:41 -0500
Subject: [PATCH] Feature: LightningDataModule.from_datasets(...) (#5133)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add class method

* add tests

* docstring

* pep

* Add type annotations

Co-authored-by: Nicki Skafte

* pep

* fix import

* remove num_workers inference

* Update pytorch_lightning/core/datamodule.py

Co-authored-by: Carlos Mocholí

* Update pytorch_lightning/core/datamodule.py

Co-authored-by: Nicki Skafte

* Update pytorch_lightning/core/datamodule.py

Co-authored-by: Nicki Skafte

* fix syntax

* typing fix

* list -> sequence

* list -> sequence

* missing import

* fix test

Co-authored-by: Nicki Skafte
Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/core/datamodule.py | 70 +++++++++++++++++++++++++---
 tests/core/test_datamodules.py       | 61 ++++++++++++++++++++++--
 2 files changed, 120 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index f46c945a0d..ecf5a99e70 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -15,12 +15,13 @@
 import functools
 import inspect
+import os
 from abc import abstractmethod
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union, Dict, Sequence, Mapping
 
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.utilities import parsing, rank_zero_only
@@ -35,9 +36,9 @@ class _DataModuleWrapper(type):
     def __call__(cls, *args, **kwargs):
         """A wrapper for LightningDataModule that:
 
-        1. Runs user defined subclass's __init__
-        2. Assures prepare_data() runs on rank 0
-        3. Lets you check prepare_data and setup to see if they've been called
+            1. Runs user defined subclass's __init__
+            2. Assures prepare_data() runs on rank 0
+            3. Lets you check prepare_data and setup to see if they've been called
         """
         if not cls.__has_added_checks:
             cls.__has_added_checks = True
@@ -266,8 +267,7 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapp
 
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
-        r"""Extends existing argparse by default `LightningDataModule` attributes.
-        """
+        r"""Extends existing argparse by default `LightningDataModule` attributes."""
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         added_args = [x.dest for x in parser._actions]
 
@@ -364,3 +364,59 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapp
             name_type_default.append((arg, arg_types, arg_default))
 
         return name_type_default
+
+    @classmethod
+    def from_datasets(
+        cls,
+        train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
+        val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+        test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
+        batch_size: int = 1,
+        num_workers: int = 0,
+    ):
+        r"""
+        Create an instance from torch.utils.data.Dataset.
+
+        Args:
+            train_dataset: (optional) Dataset to be used for train_dataloader()
+            val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
+            test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
+            batch_size: Batch size to use for each dataloader. Default is 1.
+            num_workers: Number of subprocesses to use for data loading. 0 means that the
+                data will be loaded in the main process.
+
+        """
+        def dataloader(ds, shuffle=False):
+            return DataLoader(
+                ds,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                pin_memory=True,
+            )
+
+        def train_dataloader():
+            if isinstance(train_dataset, Mapping):
+                return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
+            if isinstance(train_dataset, Sequence):
+                return [dataloader(ds, shuffle=True) for ds in train_dataset]
+            return dataloader(train_dataset, shuffle=True)
+
+        def val_dataloader():
+            if isinstance(val_dataset, Sequence):
+                return [dataloader(ds) for ds in val_dataset]
+            return dataloader(val_dataset)
+
+        def test_dataloader():
+            if isinstance(test_dataset, Sequence):
+                return [dataloader(ds) for ds in test_dataset]
+            return dataloader(test_dataset)
+
+        datamodule = cls()
+        if train_dataset is not None:
+            datamodule.train_dataloader = train_dataloader
+        if val_dataset is not None:
+            datamodule.val_dataloader = val_dataloader
+        if test_dataset is not None:
+            datamodule.test_dataloader = test_dataloader
+        return datamodule
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 76fdca0fed..a5c7c1cab3 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import pickle
 from argparse import ArgumentParser
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 from unittest.mock import MagicMock
 
 import pytest
@@ -419,7 +419,6 @@ def test_full_loop_dp(tmpdir):
 
 def test_dm_transfer_batch_to_device(tmpdir):
     class CustomBatch:
-
         def __init__(self, data):
             self.samples = data[0]
             self.targets = data[1]
@@ -452,6 +451,28 @@ def test_dm_transfer_batch_to_device(tmpdir):
     assert batch_gpu.samples.device == batch_gpu.targets.device == expected
 
 
+class CustomMNISTDataModule(LightningDataModule):
+    def __init__(self, data_dir: str = "./"):
+        super().__init__()
+        self.data_dir = data_dir
+        self._epochs_called_for = []
+
+    def prepare_data(self):
+        TrialMNIST(self.data_dir, train=True, download=True)
+
+    def setup(self, stage: Optional[str] = None):
+
+        mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
+        self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
+        self.dims = self.mnist_train[0][0].shape
+
+    def train_dataloader(self):
+        assert self.trainer.current_epoch not in self._epochs_called_for
+        self._epochs_called_for.append(self.trainer.current_epoch)
+
+        return DataLoader(self.mnist_train, batch_size=4)
+
+
 def test_dm_reload_dataloaders_every_epoch(tmpdir):
     """Test datamodule, where trainer argument
     reload_dataloaders_every_epoch is set to True/False"""
@@ -483,5 +504,37 @@ def test_dm_reload_dataloaders_every_epoch(tmpdir):
         limit_train_batches=0.01,
         reload_dataloaders_every_epoch=True,
     )
-    results = trainer.fit(model, dm)
-    assert results
+    trainer.fit(model, dm)
+
+
+class DummyDS(torch.utils.data.Dataset):
+    def __getitem__(self, index):
+        return 1
+
+    def __len__(self):
+        return 100
+
+
+def test_dm_init_from_datasets(tmpdir):
+
+    train_ds = DummyDS()
+    valid_ds = DummyDS()
+    test_ds = DummyDS()
+
+    valid_dss = [DummyDS(), DummyDS()]
+    test_dss = [DummyDS(), DummyDS()]
+
+    dm = LightningDataModule.from_datasets(train_ds, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.train_dataloader())) == torch.ones(4))
+    assert dm.val_dataloader() is None
+    assert dm.test_dataloader() is None
+
+    dm = LightningDataModule.from_datasets(train_ds, valid_ds, test_ds, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.val_dataloader())) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader())) == torch.ones(4))
+
+    dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
+    assert torch.all(next(iter(dm.val_dataloader()[0])) == torch.ones(4))
+    assert torch.all(next(iter(dm.val_dataloader()[1])) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader()[0])) == torch.ones(4))
+    assert torch.all(next(iter(dm.test_dataloader()[1])) == torch.ones(4))
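
Below is a minimal usage sketch of the new classmethod, separate from the patch itself. Only LightningDataModule.from_datasets and its arguments come from the diff above; RandomDataset, its sizes, and the batch settings are illustrative assumptions. A single Dataset yields a single DataLoader, while a sequence of Datasets (or, for training, a mapping) yields one loader per dataset.

import torch
from torch.utils.data import Dataset

from pytorch_lightning import LightningDataModule


class RandomDataset(Dataset):
    """Stand-in dataset used only for this sketch; any torch.utils.data.Dataset works."""

    def __init__(self, size: int = 64, dim: int = 32):
        self.data = torch.randn(size, dim)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


# Single train Dataset -> one shuffled DataLoader; list of val Datasets -> list of DataLoaders.
dm = LightningDataModule.from_datasets(
    train_dataset=RandomDataset(),
    val_dataset=[RandomDataset(), RandomDataset()],
    batch_size=8,
    num_workers=0,
)

train_batch = next(iter(dm.train_dataloader()))  # tensor of shape [8, 32]
val_loaders = dm.val_dataloader()                # list with two DataLoaders

Because from_datasets only overrides the dataloader hooks for datasets that were actually passed in, the resulting object behaves like a hand-written LightningDataModule and can be handed to Trainer.fit as usual.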