Remove on epoch guard from the should stop validation check (#7701)
* Remove on epoch guard from the should stop validation check
* Formatting
parent e2ead9abd7
commit a1c40f3207
@@ -211,6 +211,4 @@ class GPUStatsMonitor(Callback):
     @staticmethod
     def _should_log(trainer) -> bool:
-        should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
-
-        return should_log
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
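The refactor in this hunk (and the identical one in LearningRateMonitor below) is purely cosmetic: the temporary should_log variable and the redundant parentheses are dropped, but the boolean logic is unchanged. A minimal sketch, using a hypothetical SimpleNamespace stand-in for the trainer (not part of the diff), shows the behaviour the helper preserves:

# Minimal sketch (not part of the diff): the simplified _should_log collapses
# the temporary variable into one boolean expression with identical behaviour.
from types import SimpleNamespace

def _should_log(trainer) -> bool:
    return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop

# Hypothetical stand-in trainer: logs every 50 steps, or as soon as a stop is requested.
trainer = SimpleNamespace(global_step=49, log_every_n_steps=50, should_stop=False)
assert _should_log(trainer)        # (49 + 1) % 50 == 0 -> logging step
trainer.global_step = 10
assert not _should_log(trainer)    # (10 + 1) % 50 != 0 and no stop requested
trainer.should_stop = True
assert _should_log(trainer)        # should_stop forces a final log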
@@ -202,6 +202,4 @@ class LearningRateMonitor(Callback):
     @staticmethod
     def _should_log(trainer) -> bool:
-        should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
-
-        return should_log
+        return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
@@ -529,21 +529,11 @@ class TrainLoop:
             self.total_batch_idx += 1

             # max steps reached, end training
-            if (
+            max_steps_reached = (
                 self.max_steps is not None and self.max_steps <= self.global_step + 1
                 and self._accumulated_batches_reached()
-            ):
-                break
-
-            # end epoch early
-            # stop when the flag is changed or we've gone past the amount
-            # requested in the batches
-            if self.trainer.should_stop:
-                break
-
-            # stop epoch if we limited the number of training batches
-            if self._num_training_batches_reached(is_last_batch):
+            )
+            if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
                 break

             # progress global step according to grads progress
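This hunk folds three separate break checks into a single condition. Since `or` short-circuits left to right, max_steps is still evaluated before the stop flag and the batch limit, so the exit behaviour is unchanged. A simplified, hypothetical sketch of the loop structure (not the real TrainLoop, which carries much more state) illustrates the equivalence:

# Hypothetical, simplified loop (not the real TrainLoop): the three separate
# break checks and the single combined check exit at exactly the same batch.
def run_epoch(batches, max_steps=None, should_stop_at=None, limit_batches=None):
    global_step = 0
    for batch_idx in range(batches):
        global_step += 1  # stand-in for the optimizer-step bookkeeping
        should_stop = should_stop_at is not None and batch_idx >= should_stop_at
        max_steps_reached = max_steps is not None and max_steps <= global_step
        batches_reached = limit_batches is not None and batch_idx + 1 >= limit_batches
        # combined condition, mirroring the new single `if ... or ... or ...: break`
        if max_steps_reached or should_stop or batches_reached:
            break
    return global_step

assert run_epoch(batches=100, max_steps=5) == 5        # max_steps ends training
assert run_epoch(batches=100, should_stop_at=4) == 5   # flag at batch 4 -> 5 steps, as in the test below
assert run_epoch(batches=100, limit_batches=10) == 10  # batch limit ends the epoch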
@@ -906,7 +896,7 @@ class TrainLoop:
         if on_epoch and is_last_batch and is_infinite_dataset:
             return True

-        if on_epoch and self.trainer.should_stop:
+        if self.trainer.should_stop:
             return True

         # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
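Dropping on_epoch from this guard is the behavioural core of the PR: previously, a should_stop raised mid epoch only triggered the validation check on the epoch-level path, so training could end without a final validation pass. A condensed, hypothetical sketch of the check this hunk edits (argument names are assumptions, not the method's real signature):

# Condensed, hypothetical sketch of the validation check this hunk edits
# (argument names are assumptions; the real method takes more parameters).
def should_check_val(on_epoch, is_last_batch, is_infinite_dataset, should_stop):
    if on_epoch and is_last_batch and is_infinite_dataset:
        return True
    # before this PR: `if on_epoch and should_stop:` -- a stop requested
    # mid epoch (on_epoch=False) could skip the final validation pass
    if should_stop:
        return True
    return False

# mid-epoch stop: validation now runs before training ends
assert should_check_val(on_epoch=False, is_last_batch=False,
                        is_infinite_dataset=False, should_stop=True)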
@@ -110,3 +110,35 @@ def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_):
     else:
         assert trainer.train_loop.batch_idx == batch_idx_
         assert trainer.global_step == batch_idx_ * max_epochs
+
+
+def test_should_stop_mid_epoch(tmpdir):
+    """Test that training correctly stops mid epoch and that validation is still called at the right time"""
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.validation_called_at = None
+
+        def training_step(self, batch, batch_idx):
+            if batch_idx == 4:
+                self.trainer.should_stop = True
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, *args):
+            self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step)
+            return super().validation_step(*args)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=10,
+        limit_val_batches=1,
+    )
+    trainer.fit(model)
+
+    assert trainer.current_epoch == 0
+    assert trainer.global_step == 5
+    assert model.validation_called_at == (0, 4)  # TODO(@carmocca): should be 5 - will be fixed in next PR
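Note that the test asserts global_step == 5 after the flag is raised at batch_idx == 4: the step for batch 4 still completes before the consolidated break fires, matching the loop sketch above. Assuming the repository's usual pytest layout, the new test can be exercised on its own with something like the following (the exact file path is an assumption):

pytest tests/trainer/test_training_loop.py -k test_should_stop_mid_epoch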