Remove on epoch guard from the should stop validation check (#7701)

* Remove on epoch guard from the should stop validation check

* Formatting
This commit is contained in:
Carlos Mocholí 2021-05-25 16:59:42 +02:00 committed by GitHub
parent e2ead9abd7
commit a1c40f3207
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 20 deletions

View File

@ -211,6 +211,4 @@ class GPUStatsMonitor(Callback):
@staticmethod
def _should_log(trainer) -> bool:
should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
return should_log
return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop

View File

@ -202,6 +202,4 @@ class LearningRateMonitor(Callback):
@staticmethod
def _should_log(trainer) -> bool:
should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
return should_log
return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop

View File

@ -529,21 +529,11 @@ class TrainLoop:
self.total_batch_idx += 1
# max steps reached, end training
if (
max_steps_reached = (
self.max_steps is not None and self.max_steps <= self.global_step + 1
and self._accumulated_batches_reached()
):
break
# end epoch early
# stop when the flag is changed or we've gone past the amount
# requested in the batches
if self.trainer.should_stop:
break
# stop epoch if we limited the number of training batches
if self._num_training_batches_reached(is_last_batch):
)
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
break
# progress global step according to grads progress
@ -906,7 +896,7 @@ class TrainLoop:
if on_epoch and is_last_batch and is_infinite_dataset:
return True
if on_epoch and self.trainer.should_stop:
if self.trainer.should_stop:
return True
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch

View File

@ -110,3 +110,35 @@ def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_):
else:
assert trainer.train_loop.batch_idx == batch_idx_
assert trainer.global_step == batch_idx_ * max_epochs
def test_should_stop_mid_epoch(tmpdir):
"""Test that training correctly stops mid epoch and that validation is still called at the right time"""
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.validation_called_at = None
def training_step(self, batch, batch_idx):
if batch_idx == 4:
self.trainer.should_stop = True
return super().training_step(batch, batch_idx)
def validation_step(self, *args):
self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step)
return super().validation_step(*args)
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=10,
limit_val_batches=1,
)
trainer.fit(model)
assert trainer.current_epoch == 0
assert trainer.global_step == 5
assert model.validation_called_at == (0, 4) # TODO(@carmocca): should be 5 - will be fixed in next PR