diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index ecf5a99e70..d0e1725b2c 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -15,10 +15,9 @@
 import functools
 import inspect
-import os
 from abc import abstractmethod
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Tuple, Union, Dict, Sequence, Mapping
+from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
 from torch.utils.data import DataLoader, Dataset
 
@@ -382,10 +381,11 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
             val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
             test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
             batch_size: Batch size to use for each dataloader. Default is 1.
-            num_workers: Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
-                number of CPUs available.
+            num_workers: Number of subprocesses to use for data loading. 0 means that
+                the data will be loaded in the main process.
         """
+
         def dataloader(ds, shuffle=False):
             return DataLoader(
                 ds,
                 batch_size=batch_size,
@@ -399,7 +399,7 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
             if isinstance(train_dataset, Mapping):
                 return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
             if isinstance(train_dataset, Sequence):
-                return [dataloader(ds, shuffle=True) for ds in train_dataset] 
+                return [dataloader(ds, shuffle=True) for ds in train_dataset]
             return dataloader(train_dataset, shuffle=True)
 
         def val_dataloader():
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index a5c7c1cab3..a83a6a41c9 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import pickle
 from argparse import ArgumentParser
-from typing import Any, Dict, Optional
+from typing import Any, Dict
 from unittest.mock import MagicMock
 
 import pytest
@@ -381,8 +381,8 @@ def test_full_loop_dp(tmpdir):
         def training_step(self, batch, batch_idx):
             _, y = batch
             out = self._step(batch, batch_idx)
-            out['loss'] = F.cross_entropy(out['logits'], y)
-            return out
+            loss = F.cross_entropy(out['logits'], y)
+            return loss
 
         def validation_step(self, batch, batch_idx):
             return self._step(batch, batch_idx)
@@ -419,6 +419,7 @@ def test_dm_transfer_batch_to_device(tmpdir):
 
     class CustomBatch:
+
         def __init__(self, data):
             self.samples = data[0]
             self.targets = data[1]
 
@@ -451,28 +452,6 @@ def test_dm_transfer_batch_to_device(tmpdir):
     assert batch_gpu.samples.device == batch_gpu.targets.device == expected
 
 
-class CustomMNISTDataModule(LightningDataModule):
-    def __init__(self, data_dir: str = "./"):
-        super().__init__()
-        self.data_dir = data_dir
-        self._epochs_called_for = []
-
-    def prepare_data(self):
-        TrialMNIST(self.data_dir, train=True, download=True)
-
-    def setup(self, stage: Optional[str] = None):
-
-        mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
-        self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
-        self.dims = self.mnist_train[0][0].shape
-
-    def train_dataloader(self):
-        assert self.trainer.current_epoch not in self._epochs_called_for
-        self._epochs_called_for.append(self.trainer.current_epoch)
-
-        return DataLoader(self.mnist_train, batch_size=4)
-
-
 def test_dm_reload_dataloaders_every_epoch(tmpdir):
     """Test datamodule, where trainer argument reload_dataloaders_every_epoch is set to True/False"""
 
@@ -508,6 +487,7 @@
 
 
 class DummyDS(torch.utils.data.Dataset):
+
     def __getitem__(self, index):
         return 1
 
diff --git a/tests/helpers/simple_models.py b/tests/helpers/simple_models.py
old mode 100644
new mode 100755
index 9288a3c802..c33c470d04
--- a/tests/helpers/simple_models.py
+++ b/tests/helpers/simple_models.py
@@ -21,7 +21,8 @@ from pytorch_lightning.metrics import Accuracy, MeanSquaredError
 
 
 class ClassificationModel(LightningModule):
-    def __init__(self):
+    def __init__(self, lr=0.01):
         super().__init__()
+        self.lr = lr
         for i in range(3):
             setattr(self, f"layer_{i}", nn.Linear(32, 32))
@@ -44,7 +45,7 @@ class ClassificationModel(LightningModule):
         return logits
 
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
         return [optimizer], []
 
     def training_step(self, batch, batch_idx):
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
old mode 100644
new mode 100755
index 114ebf3368..28e3e65a87
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -208,8 +208,8 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
         def training_step(self, batch, batch_idx):
             _, y = batch
             out = self._step(batch, batch_idx)
-            out['loss'] = F.cross_entropy(out['logits'], y)
-            return out
+            loss = F.cross_entropy(out['logits'], y)
+            return loss
 
         def validation_step(self, batch, batch_idx):
             return self._step(batch, batch_idx)
@@ -221,7 +221,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
             self.log('val_acc', self.valid_acc(outputs['logits'], outputs['y']))
 
     dm = ClassifDataModule()
-    model = CustomClassificationModelDP()
+    model = CustomClassificationModelDP(lr=0.1)
 
     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)
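
For context, a minimal usage sketch of the LightningDataModule.from_datasets classmethod whose docstring and inner dataloader() helper are touched above. The synthetic TensorDataset inputs, batch size, and worker count are illustrative assumptions, not part of this patch:

import torch
from torch.utils.data import TensorDataset
from pytorch_lightning import LightningDataModule

# Synthetic stand-in datasets (illustrative only).
train_ds = TensorDataset(torch.randn(64, 32), torch.randint(0, 3, (64,)))
val_ds = TensorDataset(torch.randn(16, 32), torch.randint(0, 3, (16,)))

# Each dataset is wrapped in a DataLoader by the inner dataloader() helper
# shown in the hunk above; the train loader gets shuffle=True, and
# num_workers=0 loads data in the main process.
dm = LightningDataModule.from_datasets(
    train_dataset=train_ds,
    val_dataset=val_ds,
    batch_size=8,
    num_workers=0,
)
train_loader = dm.train_dataloader()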