Add warning for few workers (#1378)

* Add warning for few workers
* Fix style issue
* Update CHANGELOG.md
* Update test
* formatting
* formatting

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

parent fdcf9cd518
commit b18accc64c
CHANGELOG.md

@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
 - Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
 - Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
+- Added a warning when the number of data loader workers is small. ([#1378](https://github.com/PyTorchLightning/pytorch-lightning/pull/1378))

 ### Changed
pytorch_lightning/trainer/data_loading.py

@@ -1,8 +1,9 @@
+import warnings
 from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable

 import torch.distributed as torch_distrib
-from torch.utils.data import SequentialSampler, DataLoader
+from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from pytorch_lightning.core import LightningModule
@@ -73,6 +74,12 @@ class TrainerDataLoadingMixin(ABC):
         if not 0. <= value <= 1.:
             raise ValueError(msg)

+    def _worker_check(self, dataloader: DataLoader, name: str) -> None:
+        if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2:
+            warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
+                          ' Consider increasing the value of the `num_workers` argument'
+                          ' in the `DataLoader` init to improve performance.')
+
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

         # don't do anything if it's not a dataloader
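For orientation, here is a minimal standalone sketch (not part of the commit) of when the new check fires: `num_workers` defaults to 0, so any `DataLoader` built without it falls under the `<= 2` threshold.

# Standalone sketch, assuming a toy TensorDataset; it mirrors the threshold
# in _worker_check without depending on the trainer mixin.
import warnings

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(8, 1))

for n in (0, 2, 4):
    loader = DataLoader(dataset, num_workers=n)
    if loader.num_workers <= 2:  # same condition the mixin checks
        warnings.warn(f'num_workers={n} would trigger the new warning')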
@@ -112,11 +119,13 @@ class TrainerDataLoadingMixin(ABC):
             model: The current `LightningModule`
         """
         self.train_dataloader = self.request_dataloader(model.train_dataloader)

         self.num_training_batches = 0

         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
+
+        self._worker_check(self.train_dataloader, 'train dataloader')
         self._percent_range_check('train_percent_check')

         if not _has_len(self.train_dataloader):
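On the user side, a hedged sketch (illustrative, not from the commit) of how to keep this new check quiet: pass a higher `num_workers` in the `LightningModule`'s `train_dataloader`. The `self.train_dataset` attribute here is hypothetical.

from torch.utils.data import DataLoader

from pytorch_lightning.core import LightningModule


class MyModel(LightningModule):
    def train_dataloader(self):
        # 4 workers clears the `num_workers <= 2` threshold in _worker_check
        return DataLoader(self.train_dataset, batch_size=32, num_workers=4)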
@@ -176,10 +185,10 @@ class TrainerDataLoadingMixin(ABC):
         # determine number of batches
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
-            for dataloader in dataloaders:
+            for i, dataloader in enumerate(dataloaders):
+                self._worker_check(dataloader, f'{mode} dataloader {i}')
                 if not _has_len(dataloader):
                     num_batches = float('inf')
                     break

         percent_check = getattr(self, f'{mode}_percent_check')
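The `enumerate` change gives each eval loader a distinct name in the warning. An illustrative sketch (hypothetical datasets `val_set_a`/`val_set_b`, not from the commit) of a model returning two validation loaders:

from torch.utils.data import DataLoader


def val_dataloader(self):
    # each entry is checked separately; an under-provisioned one would be
    # reported as 'val dataloader 0' or 'val dataloader 1'
    return [
        DataLoader(self.val_set_a, num_workers=4),
        DataLoader(self.val_set_b, num_workers=4),
    ]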
tests/trainer/test_dataloaders.py

@@ -15,6 +15,7 @@ from tests.base import (
     LightValStepFitMultipleDataloadersMixin,
     LightValStepFitSingleDataloaderMixin,
     LightTrainDataloader,
+    LightValidationDataloader,
     LightInfTrainDataloader,
     LightInfValDataloader,
     LightInfTestDataloader,
@@ -485,6 +486,47 @@ def test_error_on_zero_len_dataloader(tmpdir):
     trainer.fit(model)


+def test_warning_with_few_workers(tmpdir):
+    """ Test that a warning is raised if a dataloader with only a few workers is used """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightTrainDataloader,
+        LightValStepFitSingleDataloaderMixin,
+        LightTestFitSingleTestDataloadersMixin,
+        LightEmptyTestStep,
+        TestModelBase,
+    ):
+        pass
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    fit_options = dict(train_dataloader=model._dataloader(train=True),
+                       val_dataloaders=model._dataloader(train=False),
+                       test_dataloaders=model._dataloader(train=False))
+
+    trainer = Trainer(**trainer_options)
+
+    # fit model
+    with pytest.warns(UserWarning, match='train'):
+        trainer.fit(model, **fit_options)
+
+    with pytest.warns(UserWarning, match='val'):
+        trainer.fit(model, **fit_options)
+
+    with pytest.warns(UserWarning, match='test'):
+        trainer.test()
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass():
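Worth noting for the test: `pytest.warns(..., match=...)` treats `match` as a regular expression searched against the warning message, so matching on 'train', 'val', or 'test' keys off the name string passed to `_worker_check`. A minimal self-contained sketch (not from the commit):

import warnings

import pytest


def emit(name: str) -> None:
    warnings.warn(f'The dataloader, {name}, does not have many workers.')


def test_match_is_regex_search():
    # `match` is searched (not full-matched) against the message,
    # so 'train' matches 'train dataloader'
    with pytest.warns(UserWarning, match='train'):
        emit('train dataloader')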