From a4026fea0501ad94f48b68930c883aa2aab1c134 Mon Sep 17 00:00:00 2001
From: semaphore-egg <56214204+semaphore-egg@users.noreply.github.com>
Date: Sat, 9 Apr 2022 19:48:32 +0800
Subject: [PATCH] Fix a typo in warning message inside
 Trainer.reset_train_dataloader (#12645)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: kd
Co-authored-by: Adrian Wälchli
---
 pytorch_lightning/trainer/trainer.py | 2 +-
 tests/trainer/test_dataloaders.py    | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index d99c26a5de..088b0fb71e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1936,7 +1936,7 @@ class Trainer(

         if self.loggers and self.num_training_batches < self.log_every_n_steps:
             rank_zero_warn(
-                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
+                f"The number of training batches ({self.num_training_batches}) is smaller than the logging interval"
                 f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                 " you want to see logs for the training epoch.",
                 category=PossibleUserWarning,
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 08d54e05bf..66b5be243b 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -702,11 +702,11 @@ def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
     dataloader = DataLoader(RandomDataset(32, length=10))
     model.train_dataloader = lambda: dataloader

-    with pytest.warns(UserWarning, match=r"The number of training samples \(10\) is smaller than the logging interval"):
+    with pytest.warns(UserWarning, match=r"The number of training batches \(10\) is smaller than the logging interval"):
         trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=11)
         trainer.fit(model)

-    with pytest.warns(UserWarning, match=r"The number of training samples \(1\) is smaller than the logging interval"):
+    with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"):
         trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1)
         trainer.fit(model)

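For reference, the corrected warning can be reproduced outside the test suite with a
minimal, self-contained script. This is a sketch assuming the PyTorch Lightning ~1.6
Trainer API; RandomDataset and TinyModel below are illustrative stand-ins for the
repository's test helpers, not part of the library itself.

    # Reproduces the warning fixed above: a dataloader yielding 10 batches combined
    # with log_every_n_steps=11 trips the check
    # `self.num_training_batches < self.log_every_n_steps` in
    # Trainer.reset_train_dataloader (written against PyTorch Lightning ~1.6).
    import torch
    from torch.utils.data import DataLoader, Dataset
    from pytorch_lightning import LightningModule, Trainer


    class RandomDataset(Dataset):
        """10 samples of size 32, so batch_size=1 yields 10 training batches."""

        def __init__(self, size: int = 32, length: int = 10):
            self.data = torch.randn(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return len(self.data)


    class TinyModel(LightningModule):
        """Hypothetical minimal model, standing in for the tests' BoringModel."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        model = TinyModel()
        trainer = Trainer(max_epochs=1, log_every_n_steps=11)
        # Warns: "The number of training batches (10) is smaller than the logging
        # interval Trainer(log_every_n_steps=11). ..."
        trainer.fit(model, DataLoader(RandomDataset()))

Before this patch, the same message reported "training samples" even though the
compared quantity, num_training_batches, counts batches, which is why both the
warning text and the regexes in the test were updated together.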