Fix: Failing test in data_modules(dp) (#5924)
* Update test_datamodules.py * fix code format issue * fix test restore * fix code format issue
This commit is contained in:
parent
e676ff96b1
commit
4857546c25
|
@ -15,10 +15,9 @@
|
|||
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Any, List, Optional, Tuple, Union, Dict, Sequence, Mapping
|
||||
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
@ -382,10 +381,11 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapp
|
|||
val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
|
||||
test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
|
||||
batch_size: Batch size to use for each dataloader. Default is 1.
|
||||
num_workers: Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
|
||||
number of CPUs available.
|
||||
num_workers: Number of subprocesses to use for data loading. 0 means that the
|
||||
data will be loaded in the main process. Number of CPUs available.
|
||||
|
||||
"""
|
||||
|
||||
def dataloader(ds, shuffle=False):
|
||||
return DataLoader(
|
||||
ds,
|
||||
|
@ -399,7 +399,7 @@ class LightningDataModule(DataHooks, CheckpointHooks, metaclass=_DataModuleWrapp
|
|||
if isinstance(train_dataset, Mapping):
|
||||
return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
|
||||
if isinstance(train_dataset, Sequence):
|
||||
return [dataloader(ds, shuffle=True) for ds in train_dataset]
|
||||
return [dataloader(ds, shuffle=True) for ds in train_dataset]
|
||||
return dataloader(train_dataset, shuffle=True)
|
||||
|
||||
def val_dataloader():
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import pickle
|
||||
from argparse import ArgumentParser
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
@ -381,8 +381,8 @@ def test_full_loop_dp(tmpdir):
|
|||
def training_step(self, batch, batch_idx):
|
||||
_, y = batch
|
||||
out = self._step(batch, batch_idx)
|
||||
out['loss'] = F.cross_entropy(out['logits'], y)
|
||||
return out
|
||||
loss = F.cross_entropy(out['logits'], y)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
return self._step(batch, batch_idx)
|
||||
|
@ -419,6 +419,7 @@ def test_full_loop_dp(tmpdir):
|
|||
def test_dm_transfer_batch_to_device(tmpdir):
|
||||
|
||||
class CustomBatch:
|
||||
|
||||
def __init__(self, data):
|
||||
self.samples = data[0]
|
||||
self.targets = data[1]
|
||||
|
@ -451,28 +452,6 @@ def test_dm_transfer_batch_to_device(tmpdir):
|
|||
assert batch_gpu.samples.device == batch_gpu.targets.device == expected
|
||||
|
||||
|
||||
class CustomMNISTDataModule(LightningDataModule):
|
||||
def __init__(self, data_dir: str = "./"):
|
||||
super().__init__()
|
||||
self.data_dir = data_dir
|
||||
self._epochs_called_for = []
|
||||
|
||||
def prepare_data(self):
|
||||
TrialMNIST(self.data_dir, train=True, download=True)
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
|
||||
mnist_full = TrialMNIST(root=self.data_dir, train=True, num_samples=64, download=True)
|
||||
self.mnist_train, self.mnist_val = random_split(mnist_full, [128, 64])
|
||||
self.dims = self.mnist_train[0][0].shape
|
||||
|
||||
def train_dataloader(self):
|
||||
assert self.trainer.current_epoch not in self._epochs_called_for
|
||||
self._epochs_called_for.append(self.trainer.current_epoch)
|
||||
|
||||
return DataLoader(self.mnist_train, batch_size=4)
|
||||
|
||||
|
||||
def test_dm_reload_dataloaders_every_epoch(tmpdir):
|
||||
"""Test datamodule, where trainer argument
|
||||
reload_dataloaders_every_epoch is set to True/False"""
|
||||
|
@ -508,6 +487,7 @@ def test_dm_reload_dataloaders_every_epoch(tmpdir):
|
|||
|
||||
|
||||
class DummyDS(torch.utils.data.Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
return 1
|
||||
|
||||
|
|
|
@ -21,7 +21,8 @@ from pytorch_lightning.metrics import Accuracy, MeanSquaredError
|
|||
|
||||
class ClassificationModel(LightningModule):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, lr=0.01):
|
||||
self.lr = lr
|
||||
super().__init__()
|
||||
for i in range(3):
|
||||
setattr(self, f"layer_{i}", nn.Linear(32, 32))
|
||||
|
@ -44,7 +45,7 @@ class ClassificationModel(LightningModule):
|
|||
return logits
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
return [optimizer], []
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
|
|
|
@ -208,8 +208,8 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
|
|||
def training_step(self, batch, batch_idx):
|
||||
_, y = batch
|
||||
out = self._step(batch, batch_idx)
|
||||
out['loss'] = F.cross_entropy(out['logits'], y)
|
||||
return out
|
||||
loss = F.cross_entropy(out['logits'], y)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
return self._step(batch, batch_idx)
|
||||
|
@ -221,7 +221,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
|
|||
self.log('val_acc', self.valid_acc(outputs['logits'], outputs['y']))
|
||||
|
||||
dm = ClassifDataModule()
|
||||
model = CustomClassificationModelDP()
|
||||
model = CustomClassificationModelDP(lr=0.1)
|
||||
|
||||
# exp file to get meta
|
||||
logger = tutils.get_default_logger(tmpdir)
|
||||
|
|
Loading…
Reference in New Issue