Fix a typo in warning message inside Trainer.reset_train_dataloader (#12645)
Co-authored-by: kd <li_jide_ok@126.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
3576dad964
commit
a4026fea05
|
@ -1936,7 +1936,7 @@ class Trainer(
|
|||
|
||||
if self.loggers and self.num_training_batches < self.log_every_n_steps:
|
||||
rank_zero_warn(
|
||||
f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
|
||||
f"The number of training batches ({self.num_training_batches}) is smaller than the logging interval"
|
||||
f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
|
||||
" you want to see logs for the training epoch.",
|
||||
category=PossibleUserWarning,
|
||||
|
|
|
@ -702,11 +702,11 @@ def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
|
|||
dataloader = DataLoader(RandomDataset(32, length=10))
|
||||
model.train_dataloader = lambda: dataloader
|
||||
|
||||
with pytest.warns(UserWarning, match=r"The number of training samples \(10\) is smaller than the logging interval"):
|
||||
with pytest.warns(UserWarning, match=r"The number of training batches \(10\) is smaller than the logging interval"):
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=11)
|
||||
trainer.fit(model)
|
||||
|
||||
with pytest.warns(UserWarning, match=r"The number of training samples \(1\) is smaller than the logging interval"):
|
||||
with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"):
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1)
|
||||
trainer.fit(model)
|
||||
|
||||
|
|
Loading…
Reference in New Issue