From 20f37b85b68f9903df8d61e79fcebdbadacf6422 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 7 Jun 2021 14:40:43 +0200 Subject: [PATCH] add warning when Trainer(log_every_n_steps) not well chosen (#7734) * add warning * update changelog * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * logger check * add docstring for test Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte --- CHANGELOG.md | 3 +++ pytorch_lightning/trainer/data_loading.py | 8 ++++++++ tests/trainer/test_dataloaders.py | 19 +++++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bc8ffcf1d..7b2d79ba0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864)) +- Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734)) + + ### Changed - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index a16ac0c7f5..77835a1976 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -51,6 +51,7 @@ class TrainerDataLoadingMixin(ABC): test_dataloaders: Optional[List[DataLoader]] num_test_batches: List[Union[int, float]] limit_train_batches: Union[int, float] + log_every_n_steps: int overfit_batches: Union[int, float] distributed_sampler_kwargs: dict accelerator: Accelerator @@ -302,6 +303,13 @@ class TrainerDataLoadingMixin(ABC): self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) + if self.logger and self.num_training_batches < self.log_every_n_steps: + rank_zero_warn( + f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval" + f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if" + f" you want to see logs for the training epoch." + ) + def _reset_eval_dataloader( self, model: LightningModule, diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index c2e5e1c24a..eceba60ae5 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -895,6 +895,25 @@ def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): trainer.fit(model, train_dataloader=dataloader) +def test_warning_with_small_dataloader_and_logging_interval(tmpdir): + """ Test that a warning message is shown if the dataloader length is too short for the chosen logging interval. """ + model = BoringModel() + dataloader = DataLoader(RandomDataset(32, length=10)) + model.train_dataloader = lambda: dataloader + + with pytest.warns(UserWarning, match=r"The number of training samples \(10\) is smaller than the logging interval"): + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + log_every_n_steps=11, + ) + trainer.fit(model) + + with pytest.warns(UserWarning, match=r"The number of training samples \(1\) is smaller than the logging interval"): + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1) + trainer.fit(model) + + def test_warning_with_iterable_dataset_and_len(tmpdir): """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """ model = BoringModel()