Add warning for few workers (#1378)

* Add warning for few workers

* Fix style issue

* Update CHANGELOG.md

* Update test

* formatting

* formatting

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Ethan Harris 2020-04-05 16:07:16 +01:00 committed by GitHub
parent fdcf9cd518
commit b18accc64c
3 changed files with 55 additions and 3 deletions
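For reviewers who want to see the new behaviour end to end, here is a minimal, self-contained sketch of the check this PR adds. The dataset and loaders below are placeholders for illustration, not part of the diff; only the `_worker_check` body mirrors the mixin method added here.

    import warnings

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def _worker_check(dataloader, name):
        # same condition as the mixin method added in this PR
        if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2:
            warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                          ' Consider increasing the value of the `num_workers` argument'
                          ' in the `DataLoader` init to improve performance.')

    dataset = TensorDataset(torch.randn(8, 2))

    loader = DataLoader(dataset, num_workers=0)
    _worker_check(loader, 'train dataloader')   # emits a UserWarning

    quiet = DataLoader(dataset, num_workers=4)
    _worker_check(quiet, 'train dataloader')    # num_workers > 2, no warning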

CHANGELOG.md

@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
 - Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
 - Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
+- Added a warning when the number of data loader workers is small. ([#1378](https://github.com/PyTorchLightning/pytorch-lightning/pull/1378))

 ### Changed

pytorch_lightning/trainer/data_loading.py

@@ -1,8 +1,9 @@
+import warnings
 from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable

 import torch.distributed as torch_distrib
-from torch.utils.data import SequentialSampler, DataLoader
+from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from pytorch_lightning.core import LightningModule
@@ -73,6 +74,12 @@ class TrainerDataLoadingMixin(ABC):
         if not 0. <= value <= 1.:
             raise ValueError(msg)

+    def _worker_check(self, dataloader: DataLoader, name: str) -> None:
+        if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2:
+            warnings.warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
+                          ' Consider increasing the value of the `num_workers` argument'
+                          ' in the `DataLoader` init to improve performance.')
+
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         # don't do anything if it's not a dataloader
@@ -112,11 +119,13 @@ class TrainerDataLoadingMixin(ABC):
             model: The current `LightningModule`
         """
         self.train_dataloader = self.request_dataloader(model.train_dataloader)
         self.num_training_batches = 0

         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

+        self._worker_check(self.train_dataloader, 'train dataloader')
+
         self._percent_range_check('train_percent_check')

         if not _has_len(self.train_dataloader):
@@ -176,10 +185,10 @@ class TrainerDataLoadingMixin(ABC):
         # determine number of batches
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
-            for dataloader in dataloaders:
+            for i, dataloader in enumerate(dataloaders):
+                self._worker_check(dataloader, f'{mode} dataloader {i}')
                 if not _has_len(dataloader):
                     num_batches = float('inf')
                     break

             percent_check = getattr(self, f'{mode}_percent_check')
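One detail worth noting before the tests: `warnings.warn` called with a bare message string defaults to the `UserWarning` category, which is what lets the tests below assert on the new warnings with `pytest.warns`. A small sketch of that mechanism (plain `warnings` and `pytest`, nothing Lightning-specific; the `emit` helper is hypothetical):

    import warnings

    import pytest

    def emit(name):
        # bare message string -> category defaults to UserWarning
        warnings.warn(f'The dataloader, {name}, does not have many workers.')

    # `match` is a regex searched against the warning message
    with pytest.warns(UserWarning, match='train'):
        emit('train dataloader')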

tests/trainer/test_dataloaders.py

@@ -15,6 +15,7 @@ from tests.base import (
     LightValStepFitMultipleDataloadersMixin,
     LightValStepFitSingleDataloaderMixin,
     LightTrainDataloader,
+    LightValidationDataloader,
     LightInfTrainDataloader,
     LightInfValDataloader,
     LightInfTestDataloader,
@@ -485,6 +486,47 @@ def test_error_on_zero_len_dataloader(tmpdir):
     trainer.fit(model)

+
+def test_warning_with_few_workers(tmpdir):
+    """ Test that a warning is raised if a dataloader with only a few workers is used """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+        LightTrainDataloader,
+        LightValStepFitSingleDataloaderMixin,
+        LightTestFitSingleTestDataloadersMixin,
+        LightEmptyTestStep,
+        TestModelBase,
+    ):
+        pass
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    fit_options = dict(train_dataloader=model._dataloader(train=True),
+                       val_dataloaders=model._dataloader(train=False),
+                       test_dataloaders=model._dataloader(train=False))
+
+    trainer = Trainer(**trainer_options)
+
+    # fit model
+    with pytest.warns(UserWarning, match='train'):
+        trainer.fit(model, **fit_options)
+
+    with pytest.warns(UserWarning, match='val'):
+        trainer.fit(model, **fit_options)
+
+    with pytest.warns(UserWarning, match='test'):
+        trainer.test()
+

 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass():
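As a usage note on the `enumerate` change in the trainer diff above: indexing the name string means a project with several val/test loaders can tell which one triggered the warning. A hedged sketch of the resulting messages, with illustrative placeholder loaders rather than real Lightning dataloaders:

    import warnings

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(8, 2))
    dataloaders = [DataLoader(dataset, num_workers=0),
                   DataLoader(dataset, num_workers=0)]
    mode = 'val'

    for i, dataloader in enumerate(dataloaders):
        if dataloader.num_workers <= 2:
            # mirrors the f-string in the diff:
            # 'val dataloader 0', 'val dataloader 1', ...
            warnings.warn(f'The dataloader, {mode} dataloader {i}, does not have many workers.')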