Fix: Failing test in data_modules(dp) (#5924)

* Update test_datamodules.py

* fix code format issue

* fix test restore

* fix code format issue
Kaushik B 2021-02-11 23:02:46 +05:30 committed by GitHub
parent e676ff96b1
commit 4857546c25
4 changed files with 16 additions and 35 deletions

pytorch_lightning/core/datamodule.py

@@ -15,10 +15,9 @@
import functools
import inspect
import os
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional, Tuple, Union, Dict, Sequence, Mapping
+from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
import torch
from torch.utils.data import DataLoader, Dataset
@@ -382,10 +381,11 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
batch_size: Batch size to use for each dataloader. Default is 1.
-num_workers: Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
-    number of CPUs available.
+num_workers: Number of subprocesses to use for data loading. 0 means that the
+    data will be loaded in the main process. Number of CPUs available.
"""
def dataloader(ds, shuffle=False):
return DataLoader(
ds,
@@ -399,7 +399,7 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapper):
if isinstance(train_dataset, Mapping):
return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
if isinstance(train_dataset, Sequence):
-return [dataloader(ds, shuffle=True) for ds in train_dataset]
+return [dataloader(ds, shuffle=True) for ds in train_dataset]
return dataloader(train_dataset, shuffle=True)
def val_dataloader():

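For reference, a minimal usage sketch of the from_datasets factory whose docstring and dataloader helper the hunks above touch; torchvision's MNIST is only a stand-in here, and any torch Dataset, or a sequence or mapping of them, works the same way:

    # Minimal sketch, not part of the diff; MNIST is a stand-in Dataset.
    from torch.utils.data import random_split
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor

    from pytorch_lightning import LightningDataModule

    mnist = MNIST(".", train=True, download=True, transform=ToTensor())
    train_ds, val_ds = random_split(mnist, [55_000, 5_000])

    # Each *_dataset argument may be a single Dataset, a Sequence, or a
    # Mapping of Datasets; num_workers=0 loads data in the main process.
    dm = LightningDataModule.from_datasets(
        train_dataset=train_ds,
        val_dataset=val_ds,
        batch_size=32,
        num_workers=0,
    )
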
tests/core/test_datamodules.py

@@ -13,7 +13,7 @@
# limitations under the License.
import pickle
from argparse import ArgumentParser
-from typing import Any, Dict, Optional
+from typing import Any, Dict
from unittest.mock import MagicMock
import pytest
@@ -381,8 +381,8 @@ def test_full_loop_dp(tmpdir):
def training_step(self, batch, batch_idx):
_, y = batch
out = self._step(batch, batch_idx)
-out['loss'] = F.cross_entropy(out['logits'], y)
-return out
+loss = F.cross_entropy(out['logits'], y)
+return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx)
@@ -419,6 +419,7 @@ def test_full_loop_dp(tmpdir):
def test_dm_transfer_batch_to_device(tmpdir):
class CustomBatch:
def __init__(self, data):
self.samples = data[0]
self.targets = data[1]
@@ -451,28 +452,6 @@ def test_dm_transfer_batch_to_device(tmpdir):
assert batch_gpu.samples.device == batch_gpu.targets.device == expected
-class CustomMNISTDataModule(LightningDataModule):
-def __init__(self, data_dir: str = "./"):
-super().__init__()
-self.data_dir = data_dir
-self._epochs_called_for = []
-def prepare_data(self):
-TrialMNIST(self.data_dir, train=True, download=True)
-def setup(self, stage: Optional[str] = None):
-mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
-self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
-self.dims = self.mnist_train[0][0].shape
-def train_dataloader(self):
-assert self.trainer.current_epoch not in self._epochs_called_for
-self._epochs_called_for.append(self.trainer.current_epoch)
-return DataLoader(self.mnist_train, batch_size=4)
def test_dm_reload_dataloaders_every_epoch(tmpdir):
"""Test datamodule, where trainer argument
reload_dataloaders_every_epoch is set to True/False"""
@@ -508,6 +487,7 @@ def test_dm_reload_dataloaders_every_epoch(tmpdir):
class DummyDS(torch.utils.data.Dataset):
def __getitem__(self, index):
return 1

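The fix in both DP tests switches training_step from returning the full step-output dict to returning the loss tensor itself. A sketch of the resulting pattern; the _step helper is assumed from the surrounding test code, since this diff does not show it:

    import torch.nn.functional as F

    from tests.helpers.simple_models import ClassificationModel

    class CustomClassificationModelDP(ClassificationModel):

        # Assumed shape of the _step helper (not shown in this diff):
        # forward pass returning the logits together with the targets.
        def _step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            return {'logits': logits, 'y': y}

        def training_step(self, batch, batch_idx):
            _, y = batch
            out = self._step(batch, batch_idx)
            # Return the scalar loss itself rather than assigning it to
            # out['loss'] and returning the whole dict.
            loss = F.cross_entropy(out['logits'], y)
            return loss
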
tests/helpers/simple_models.py Normal file → Executable file

@@ -21,7 +21,8 @@ from pytorch_lightning.metrics import Accuracy, MeanSquaredError
class ClassificationModel(LightningModule):
-def __init__(self):
+def __init__(self, lr=0.01):
+self.lr = lr
super().__init__()
for i in range(3):
setattr(self, f"layer_{i}", nn.Linear(32, 32))
@@ -44,7 +45,7 @@ class ClassificationModel(LightningModule):
return logits
def configure_optimizers(self):
-optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
+optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return [optimizer], []
def training_step(self, batch, batch_idx):

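Downstream, the new constructor argument lets the DP tests raise the learning rate so the short run converges, while the default keeps existing callers unchanged:

    model = ClassificationModel()                 # lr defaults to 0.01, as before
    model = CustomClassificationModelDP(lr=0.1)   # what the restore test below passes
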
tests/models/test_restore.py Normal file → Executable file

@@ -208,8 +208,8 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
def training_step(self, batch, batch_idx):
_, y = batch
out = self._step(batch, batch_idx)
-out['loss'] = F.cross_entropy(out['logits'], y)
-return out
+loss = F.cross_entropy(out['logits'], y)
+return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx)
@@ -221,7 +221,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
self.log('val_acc', self.valid_acc(outputs['logits'], outputs['y']))
dm = ClassifDataModule()
-model = CustomClassificationModelDP()
+model = CustomClassificationModelDP(lr=0.1)
# exp file to get meta
logger = tutils.get_default_logger(tmpdir)