[TEST] Min steps override early stopping (#4283)
* test to make sure behaviour is enforced
* test_min_steps_override_early_stopping_functionality
* make sure Expected Behaviour is reproduced
* remove pollution from extra logging
* update docstring
* reduce test time
* resolve pep8
parent 342a2b6f25
commit 62903717a4
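For orientation before the diff: the behaviour being pinned down is that the `min_steps` / `min_epochs` lower bounds take precedence over `EarlyStopping`. A minimal sketch of that configuration (not part of the commit; the callback monitors a `val_loss` metric that the fitted LightningModule is assumed to log):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# EarlyStopping may ask to stop as soon as `val_loss` plateaus, but the Trainer
# keeps training until both lower bounds below are satisfied.
trainer = Trainer(
    callbacks=[EarlyStopping(monitor='val_loss', patience=3)],
    min_steps=100,   # never stop before 100 optimizer steps
    min_epochs=1,    # never stop before 1 full epoch
    max_epochs=20,
)
# trainer.fit(model)  # `model` would be any LightningModule logging `val_loss`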
@@ -13,15 +13,17 @@
 # limitations under the License.
 import os
 import pickle
+from unittest import mock
 
 import cloudpickle
 import numpy as np
 import pytest
 import torch
-from unittest import mock
 
+from pytorch_lightning import _logger
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from tests.base import EvalModelTemplate
+from tests.base import EvalModelTemplate, BoringModel
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+
@@ -204,3 +206,92 @@ def test_early_stopping_functionality_arbitrary_key(tmpdir):
     )
     trainer.fit(model)
    assert trainer.current_epoch >= 5, 'early_stopping failed'
+
+
+@pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
+def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, min_steps, min_epochs):
+    """Expected behaviour:
+    IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered,
+    THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop.
+
+    IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is being triggered,
+    THEN the trainer should continue until reaching `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
+    This test validates this expected behaviour.
+
+    IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when `early_stopping` is being triggered,
+    THEN the higher of `min_epochs * len(train_dataloader)` and `min_steps` should be reached.
+
+    Caveat: IF `min_steps` is divisible by `len(train_dataloader)`, then it will run `min_steps + len(train_dataloader)` steps.
+
+    This test validates those expected behaviours.
+    """
+
+    _logger.disabled = True
+
+    original_loss_value = 10
+    limit_train_batches = 3
+    patience = 3
+
+    class Model(BoringModel):
+
+        def __init__(self, step_freeze):
+            super(Model, self).__init__()
+
+            self._step_freeze = step_freeze
+
+            self._loss_value = 10.0
+            self._eps = 1e-1
+            self._count_decrease = 0
+            self._values = []
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"loss": loss}
+
+        def validation_step(self, batch, batch_idx):
+            return {"test_val_loss": self._loss_value}
+
+        def validation_epoch_end(self, outputs):
+            _mean = np.mean([x['test_val_loss'] for x in outputs])
+            if self.trainer.global_step <= self._step_freeze:
+                self._count_decrease += 1
+                self._loss_value -= self._eps
+            self._values.append(_mean)
+            return {"test_val_loss": _mean}
+
+    model = Model(step_freeze)
+    model.training_step_end = None
+    model.test_dataloader = None
+    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[early_stop_callback],
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        min_steps=min_steps,
+        min_epochs=min_epochs
+    )
+    trainer.fit(model)
+
+    # Make sure loss was properly decreased
+    assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6
+
+    pos_diff = (np.diff(model._values) == 0).nonzero()[0][0]
+
+    # Compute when the latest validation epoch end happened
+    latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches
+    if pos_diff % limit_train_batches == 0:
+        latest_validation_epoch_end += limit_train_batches
+
+    # Compute early stopping latest step
+    by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience
+
+    # Compute min_epochs latest step
+    by_min_epochs = min_epochs * limit_train_batches
+
+    # Make sure the trainer stops for the max of all minimum requirements
+    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \
+        (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
+
+    _logger.disabled = False
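Taken together, the assertions at the end of the test reduce to a small piece of arithmetic. A standalone restatement of it (a sketch, not part of the commit; `expected_stop_step` is a hypothetical helper that mirrors the test's computation):

def expected_stop_step(min_steps, min_epochs, limit_train_batches, patience,
                       latest_validation_epoch_end):
    # Step at which EarlyStopping would be allowed to stop (mirrors the test's arithmetic).
    by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience
    # Step implied by the min_epochs lower bound.
    by_min_epochs = min_epochs * limit_train_batches
    # The trainer should stop at the largest of all minimum requirements.
    return max(min_steps, by_early_stopping, by_min_epochs)

For instance, with `limit_train_batches=3`, `patience=3`, a plateau detected at step 3 and `min_steps=15`, the expected stop step is `max(15, 3 + (1 + 3) * 3, 1 * 3) == 15`, i.e. `min_steps` wins over early stopping.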