Fix a typo in warning message inside Trainer.reset_train_dataloader (#12645)

Co-authored-by: kd <li_jide_ok@126.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
semaphore-egg 2022-04-09 19:48:32 +08:00 committed by lexierule
parent 3576dad964
commit a4026fea05
2 changed files with 3 additions and 3 deletions

@@ -1936,7 +1936,7 @@ class Trainer(
         if self.loggers and self.num_training_batches < self.log_every_n_steps:
             rank_zero_warn(
-                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
+                f"The number of training batches ({self.num_training_batches}) is smaller than the logging interval"
                 f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
                 " you want to see logs for the training epoch.",
                 category=PossibleUserWarning,
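For context, a minimal sketch of a run that triggers the reworded warning (not part of this commit; the dataset and model names below are illustrative, and the snippet assumes the PyTorch Lightning 1.6-era API): a dataloader yielding 10 batches combined with log_every_n_steps=11 makes num_training_batches smaller than the logging interval.

# Hypothetical reproduction sketch, not from the repository.
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class TinyDataset(Dataset):  # illustrative name
    def __len__(self):
        return 10  # 10 samples -> 10 training batches with the default batch_size=1

    def __getitem__(self, idx):
        return torch.randn(32)


class TinyModel(LightningModule):  # illustrative name
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# 10 training batches but log_every_n_steps=11 -> the warning above is emitted:
# "The number of training batches (10) is smaller than the logging interval ..."
trainer = Trainer(max_epochs=1, log_every_n_steps=11)
trainer.fit(TinyModel(), DataLoader(TinyDataset()))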

@@ -702,11 +702,11 @@ def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
     dataloader = DataLoader(RandomDataset(32, length=10))
     model.train_dataloader = lambda: dataloader
-    with pytest.warns(UserWarning, match=r"The number of training samples \(10\) is smaller than the logging interval"):
+    with pytest.warns(UserWarning, match=r"The number of training batches \(10\) is smaller than the logging interval"):
         trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=11)
         trainer.fit(model)
-    with pytest.warns(UserWarning, match=r"The number of training samples \(1\) is smaller than the logging interval"):
+    with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"):
         trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1)
         trainer.fit(model)
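The updated test pins the new "training batches" wording. If a short dataloader is intentional, the warning can be silenced like any other Python warning; a small sketch, assuming PossibleUserWarning is importable from pytorch_lightning.utilities.warnings as in the 1.6 release line:

# Sketch only: suppress the batch-count warning when a tiny dataloader is deliberate.
import warnings

from pytorch_lightning.utilities.warnings import PossibleUserWarning  # import path assumed

warnings.filterwarnings("ignore", category=PossibleUserWarning)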