diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index abe21b9d28..b7711e8aae 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -13,15 +13,17 @@ # limitations under the License. import os import pickle -from unittest import mock import cloudpickle +import numpy as np import pytest import torch +from unittest import mock +from pytorch_lightning import _logger from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from tests.base import EvalModelTemplate +from tests.base import EvalModelTemplate, BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -204,3 +206,92 @@ def test_early_stopping_functionality_arbitrary_key(tmpdir): ) trainer.fit(model) assert trainer.current_epoch >= 5, 'early_stopping failed' + + +@pytest.mark.parametrize('step_freeze, min_steps, min_epochs',[(5, 1, 1), (5, 1, 3), (3, 15, 1)]) +def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, min_steps, min_epochs): + """Excepted Behaviour: + IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered, + THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop. + + IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is being triggered, + THEN the trainer should continue until reaching `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop. + This test validate this expected behaviour + + IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when `early_stopping` is being triggered, + THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached. + + Caviat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader) + + This test validate those expected behaviours + """ + + _logger.disabled = True + + original_loss_value = 10 + limit_train_batches = 3 + patience = 3 + + class Model(BoringModel): + + def __init__(self, step_freeze): + super(Model, self).__init__() + + self._step_freeze = step_freeze + + self._loss_value = 10.0 + self._eps = 1e-1 + self._count_decrease = 0 + self._values = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + return {"test_val_loss": self._loss_value} + + def validation_epoch_end(self, outputs): + _mean = np.mean([x['test_val_loss'] for x in outputs]) + if self.trainer.global_step <= self._step_freeze: + self._count_decrease += 1 + self._loss_value -= self._eps + self._values.append(_mean) + return {"test_val_loss": _mean} + + model = Model(step_freeze) + model.training_step_end = None + model.test_dataloader = None + early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[early_stop_callback], + limit_train_batches=limit_train_batches, + limit_val_batches=2, + min_steps=min_steps, + min_epochs=min_epochs + ) + trainer.fit(model) + + # Make sure loss was properly decreased + assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6 + + pos_diff = (np.diff(model._values) == 0).nonzero()[0][0] + + # Compute when the latest validation epoch end happened + latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches + if pos_diff % limit_train_batches == 0: + latest_validation_epoch_end += limit_train_batches + + # Compute early stopping latest step + by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience + + # Compute min_epochs latest step + by_min_epochs = min_epochs * limit_train_batches + + # Make sure the trainer stops for the max of all minimun requirements + assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \ + (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs) + + _logger.disabled = False