Remove the `on_epoch` guard from the `should_stop` validation check (#7701)
* Remove the on-epoch guard from the should-stop validation check
* Formatting
This commit is contained in:
parent
e2ead9abd7
commit
a1c40f3207
|
@ -211,6 +211,4 @@ class GPUStatsMonitor(Callback):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _should_log(trainer) -> bool:
|
def _should_log(trainer) -> bool:
|
||||||
should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
|
return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
|
||||||
|
|
||||||
return should_log
|
|
||||||
|
|
|
@ -202,6 +202,4 @@ class LearningRateMonitor(Callback):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _should_log(trainer) -> bool:
|
def _should_log(trainer) -> bool:
|
||||||
should_log = ((trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop)
|
return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop
|
||||||
|
|
||||||
return should_log
|
|
||||||
|
|
|
@ -529,21 +529,11 @@ class TrainLoop:
|
||||||
|
|
||||||
self.total_batch_idx += 1
|
self.total_batch_idx += 1
|
||||||
|
|
||||||
# max steps reached, end training
|
max_steps_reached = (
|
||||||
if (
|
|
||||||
self.max_steps is not None and self.max_steps <= self.global_step + 1
|
self.max_steps is not None and self.max_steps <= self.global_step + 1
|
||||||
and self._accumulated_batches_reached()
|
and self._accumulated_batches_reached()
|
||||||
):
|
)
|
||||||
break
|
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
|
||||||
|
|
||||||
# end epoch early
|
|
||||||
# stop when the flag is changed or we've gone past the amount
|
|
||||||
# requested in the batches
|
|
||||||
if self.trainer.should_stop:
|
|
||||||
break
|
|
||||||
|
|
||||||
# stop epoch if we limited the number of training batches
|
|
||||||
if self._num_training_batches_reached(is_last_batch):
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# progress global step according to grads progress
|
# progress global step according to grads progress
|
||||||
|
@ -906,7 +896,7 @@ class TrainLoop:
|
||||||
if on_epoch and is_last_batch and is_infinite_dataset:
|
if on_epoch and is_last_batch and is_infinite_dataset:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if on_epoch and self.trainer.should_stop:
|
if self.trainer.should_stop:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
|
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
|
||||||
|
|
|
@ -110,3 +110,35 @@ def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_):
|
||||||
else:
|
else:
|
||||||
assert trainer.train_loop.batch_idx == batch_idx_
|
assert trainer.train_loop.batch_idx == batch_idx_
|
||||||
assert trainer.global_step == batch_idx_ * max_epochs
|
assert trainer.global_step == batch_idx_ * max_epochs
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_stop_mid_epoch(tmpdir):
|
||||||
|
"""Test that training correctly stops mid epoch and that validation is still called at the right time"""
|
||||||
|
|
||||||
|
class TestModel(BoringModel):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.validation_called_at = None
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
if batch_idx == 4:
|
||||||
|
self.trainer.should_stop = True
|
||||||
|
return super().training_step(batch, batch_idx)
|
||||||
|
|
||||||
|
def validation_step(self, *args):
|
||||||
|
self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step)
|
||||||
|
return super().validation_step(*args)
|
||||||
|
|
||||||
|
model = TestModel()
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
max_epochs=1,
|
||||||
|
limit_train_batches=10,
|
||||||
|
limit_val_batches=1,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
assert trainer.current_epoch == 0
|
||||||
|
assert trainer.global_step == 5
|
||||||
|
assert model.validation_called_at == (0, 4) # TODO(@carmocca): should be 5 - will be fixed in next PR
|
||||||
|
|
Loading…
Reference in New Issue