Refactor early stopping test (#11866)

This commit is contained in:
Carlos Mocholí 2022-02-18 00:20:39 +01:00 committed by GitHub
parent 25b505508d
commit a0ca8d076f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 45 additions and 84 deletions

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import pickle
from typing import List, Optional
from unittest import mock
@ -264,100 +265,60 @@ def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
assert early_stopping.stopped_epoch == expected_stop_epoch
@pytest.mark.parametrize("step_freeze, min_steps, min_epochs", [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
"""Expected Behaviour: IF `min_steps` was set to a higher value than the `trainer.global_step` when
`early_stopping` is being triggered, THEN the trainer should continue until reaching `trainer.global_step` ==
`min_steps`, and stop.
IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step`
when `early_stopping` is being triggered,
THEN the trainer should continue until reaching
`trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
This test validates this expected behaviour
IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step`
when `early_stopping` is being triggered,
THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
This test validates those expected behaviours
"""
_logger.disabled = True
original_loss_value = 10
limit_train_batches = 3
patience = 3
class Model(BoringModel):
def __init__(self, step_freeze):
super().__init__()
self._step_freeze = step_freeze
self._loss_value = 10.0
self._eps = 1e-1
self._count_decrease = 0
self._values = []
@pytest.mark.parametrize("limit_train_batches", (3, 5))
@pytest.mark.parametrize(
["min_epochs", "min_steps"],
[
# IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
# triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
(0, 10),
# IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
# being triggered, THEN the trainer should continue until reaching
# `trainer.global_step` == `min_epochs * len(train_dataloader)`
(2, 0),
# IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
# `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
# `min_steps` would be reached
(1, 10),
(3, 10),
],
)
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
if min_steps:
assert limit_train_batches < min_steps
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}
self.log("foo", batch_idx)
return super().training_step(batch, batch_idx)
def validation_step(self, batch, batch_idx):
return {"test_val_loss": self._loss_value}
def validation_epoch_end(self, outputs):
_mean = np.mean([x["test_val_loss"] for x in outputs])
if self.trainer.global_step <= self._step_freeze:
self._count_decrease += 1
self._loss_value -= self._eps
self._values.append(_mean)
self.log("test_val_loss", _mean)
model = Model(step_freeze)
model.training_step_end = None
model.test_dataloader = None
early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
es_callback = EarlyStopping("foo")
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stop_callback],
callbacks=es_callback,
limit_val_batches=0,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
min_steps=min_steps,
min_epochs=min_epochs,
min_steps=min_steps,
logger=False,
enable_checkpointing=False,
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.fit(model)
model = TestModel()
# Make sure loss was properly decreased
assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6
expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
# trigger early stopping directly after the first epoch
side_effect = [(True, "")] * expected_epochs
with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect):
trainer.fit(model)
pos_diff = (np.diff(model._values) == 0).nonzero()[0][0]
# Compute when the latest validation epoch end happened
latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches
if pos_diff % limit_train_batches == 0:
latest_validation_epoch_end += limit_train_batches
# Compute early stopping latest step
by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience
# Compute min_epochs latest step
by_min_epochs = min_epochs * limit_train_batches
# Make sure the trainer stops for the max of all minimum requirements
assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
trainer.global_step,
max(min_steps, by_early_stopping, by_min_epochs),
step_freeze,
min_steps,
min_epochs,
)
_logger.disabled = False
assert trainer.should_stop
# epochs continue until min steps are reached
assert trainer.current_epoch == expected_epochs
# steps continue until min steps are reached AND the epoch is exhausted
# stopping mid-epoch is not supported
assert trainer.global_step == limit_train_batches * expected_epochs
def test_early_stopping_mode_options():