[TEST] Min steps override early stopping (#4283)

* test to make sure behaviour is enforced

* test_min_steps_override_early_stopping_functionality

* make sure expected behaviour is reproduced

* remove pollution from extra logging

* update docstring

* reduce test time

* resolve pep8
chaton 2020-12-04 16:10:14 +00:00 committed by GitHub
parent 342a2b6f25
commit 62903717a4
1 changed file with 93 additions and 2 deletions
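For orientation before the diff: the behaviour under test is that `min_steps` and `min_epochs` act as lower bounds that outrank an early-stopping request. A minimal sketch of that interaction, assuming illustrative values (the `val_loss` monitor name and the numbers here are placeholders, not taken from the diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Even if the monitored metric plateaus immediately and EarlyStopping asks
    # to stop, the Trainer keeps training until both lower bounds are satisfied,
    # i.e. until global_step >= max(min_steps, min_epochs * steps_per_epoch).
    trainer = Trainer(
        callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
        min_steps=15,   # lower bound in optimizer steps
        min_epochs=1,   # lower bound in epochs
    )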


@@ -13,15 +13,17 @@
 # limitations under the License.
 import os
 import pickle
-from unittest import mock
 
 import cloudpickle
 import numpy as np
 import pytest
 import torch
+from unittest import mock
 
+from pytorch_lightning import _logger
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from tests.base import EvalModelTemplate
+from tests.base import EvalModelTemplate, BoringModel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -204,3 +206,92 @@ def test_early_stopping_functionality_arbitrary_key(tmpdir):
     )
     trainer.fit(model)
     assert trainer.current_epoch >= 5, 'early_stopping failed'
+
+
+@pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
+def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, min_steps, min_epochs):
+    """Expected Behaviour:
+    IF `min_steps` was set to a higher value than `trainer.global_step` when `early_stopping` is triggered,
+    THEN the trainer should continue until `trainer.global_step` == `min_steps`, and then stop.
+
+    IF `min_epochs` results in a higher number of steps than `trainer.global_step` when `early_stopping` is triggered,
+    THEN the trainer should continue until `trainer.global_step` == `min_epochs * len(train_dataloader)`, and then stop.
+
+    IF both `min_epochs` and `min_steps` are provided and higher than `trainer.global_step` when `early_stopping` is triggered,
+    THEN the trainer should continue until the higher of `min_epochs * len(train_dataloader)` and `min_steps` is reached.
+
+    Caveat: IF `min_steps` is divisible by `len(train_dataloader)`, then it will run for `min_steps + len(train_dataloader)` steps.
+
+    This test validates those expected behaviours.
+    """
+    _logger.disabled = True
+
+    original_loss_value = 10
+    limit_train_batches = 3
+    patience = 3
+
+    class Model(BoringModel):
+
+        def __init__(self, step_freeze):
+            super().__init__()
+            self._step_freeze = step_freeze
+            self._loss_value = 10.0
+            self._eps = 1e-1
+            self._count_decrease = 0
+            self._values = []
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"loss": loss}
+
+        def validation_step(self, batch, batch_idx):
+            return {"test_val_loss": self._loss_value}
+
+        def validation_epoch_end(self, outputs):
+            _mean = np.mean([x['test_val_loss'] for x in outputs])
+            # decrease the monitored value until `step_freeze` is reached, then plateau
+            if self.trainer.global_step <= self._step_freeze:
+                self._count_decrease += 1
+                self._loss_value -= self._eps
+            self._values.append(_mean)
+            return {"test_val_loss": _mean}
+
+    model = Model(step_freeze)
+    model.training_step_end = None
+    model.test_dataloader = None
+    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[early_stop_callback],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        min_steps=min_steps,
+        min_epochs=min_epochs,
+    )
+    trainer.fit(model)
+
+    # Make sure the loss was properly decreased
+    assert abs(original_loss_value - model._count_decrease * model._eps - model._loss_value) < 1e-6
+
+    # First index at which the recorded validation values stopped decreasing
+    pos_diff = (np.diff(model._values) == 0).nonzero()[0][0]
+
+    # Compute when the latest validation epoch end happened
+    latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches
+    if pos_diff % limit_train_batches == 0:
+        latest_validation_epoch_end += limit_train_batches
+
+    # Compute the latest step at which early stopping can trigger
+    by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience
+
+    # Compute the latest step required by min_epochs
+    by_min_epochs = min_epochs * limit_train_batches
+
+    # Make sure the trainer stops at the max of all minimum requirements
+    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \
+        (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
+
+    _logger.disabled = False
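To make the final assertion concrete, here is a self-contained walk-through of its arithmetic for the first parametrization (step_freeze=5, min_steps=1, min_epochs=1). The list of recorded validation values is a hypothetical example chosen to plateau after three decreases (so `pos_diff == 3`, exercising the divisibility caveat); the actual values depend on the run:

    import numpy as np

    # assumed recording: three decreases, then a plateau
    values = [10.0, 9.9, 9.8, 9.7, 9.7, 9.7]
    limit_train_batches, patience = 3, 3
    min_steps, min_epochs = 1, 1

    pos_diff = (np.diff(values) == 0).nonzero()[0][0]        # first zero diff -> 3

    latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches
    if pos_diff % limit_train_batches == 0:                   # 3 % 3 == 0: the caveat branch
        latest_validation_epoch_end += limit_train_batches    # 3 -> 6

    by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience  # 6 + 4 * 3 = 18
    by_min_epochs = min_epochs * limit_train_batches          # 1 * 3 = 3

    print(max(min_steps, by_early_stopping, by_min_epochs))   # 18 == expected trainer.global_step

So in this scenario the early-stopping bound (18 steps) dominates both `min_steps` and `min_epochs`, which is exactly what the test's `max(...)` assertion encodes.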