Fix a typo in warning message inside Trainer.reset_train_dataloader (#12645)
Co-authored-by: kd <li_jide_ok@126.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
3576dad964
commit
a4026fea05
|
@ -1936,7 +1936,7 @@ class Trainer(
|
||||||
|
|
||||||
if self.loggers and self.num_training_batches < self.log_every_n_steps:
|
if self.loggers and self.num_training_batches < self.log_every_n_steps:
|
||||||
rank_zero_warn(
|
rank_zero_warn(
|
||||||
f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
|
f"The number of training batches ({self.num_training_batches}) is smaller than the logging interval"
|
||||||
f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
|
f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
|
||||||
" you want to see logs for the training epoch.",
|
" you want to see logs for the training epoch.",
|
||||||
category=PossibleUserWarning,
|
category=PossibleUserWarning,
|
||||||
|
|
|
@ -702,11 +702,11 @@ def test_warning_with_small_dataloader_and_logging_interval(tmpdir):
|
||||||
dataloader = DataLoader(RandomDataset(32, length=10))
|
dataloader = DataLoader(RandomDataset(32, length=10))
|
||||||
model.train_dataloader = lambda: dataloader
|
model.train_dataloader = lambda: dataloader
|
||||||
|
|
||||||
with pytest.warns(UserWarning, match=r"The number of training samples \(10\) is smaller than the logging interval"):
|
with pytest.warns(UserWarning, match=r"The number of training batches \(10\) is smaller than the logging interval"):
|
||||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=11)
|
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=11)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
||||||
with pytest.warns(UserWarning, match=r"The number of training samples \(1\) is smaller than the logging interval"):
|
with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"):
|
||||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1)
|
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, log_every_n_steps=2, limit_train_batches=1)
|
||||||
trainer.fit(model)
|
trainer.fit(model)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue