Refactor early stopping test (#11866)
This commit is contained in:
parent
25b505508d
commit
a0ca8d076f
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
import pickle
|
||||
from typing import List, Optional
|
||||
from unittest import mock
|
||||
|
@ -264,100 +265,60 @@ def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
|
|||
assert early_stopping.stopped_epoch == expected_stop_epoch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("step_freeze, min_steps, min_epochs", [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
|
||||
def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
|
||||
"""Excepted Behaviour: IF `min_steps` was set to a higher value than the `trainer.global_step` when
|
||||
`early_stopping` is being triggered, THEN the trainer should continue until reaching `trainer.global_step` ==
|
||||
`min_steps`, and stop.
|
||||
|
||||
IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step`
|
||||
when `early_stopping` is being triggered,
|
||||
THEN the trainer should continue until reaching
|
||||
`trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
|
||||
This test validate this expected behaviour
|
||||
|
||||
IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step`
|
||||
when `early_stopping` is being triggered,
|
||||
THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
|
||||
|
||||
Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
|
||||
|
||||
This test validate those expected behaviours
|
||||
"""
|
||||
|
||||
_logger.disabled = True
|
||||
|
||||
original_loss_value = 10
|
||||
limit_train_batches = 3
|
||||
patience = 3
|
||||
|
||||
class Model(BoringModel):
|
||||
def __init__(self, step_freeze):
|
||||
super().__init__()
|
||||
|
||||
self._step_freeze = step_freeze
|
||||
|
||||
self._loss_value = 10.0
|
||||
self._eps = 1e-1
|
||||
self._count_decrease = 0
|
||||
self._values = []
|
||||
@pytest.mark.parametrize("limit_train_batches", (3, 5))
|
||||
@pytest.mark.parametrize(
|
||||
["min_epochs", "min_steps"],
|
||||
[
|
||||
# IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
|
||||
# triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
|
||||
(0, 10),
|
||||
# IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
|
||||
# being triggered, THEN the trainer should continue until reaching
|
||||
# `trainer.global_step` == `min_epochs * len(train_dataloader)`
|
||||
(2, 0),
|
||||
# IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
|
||||
# `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
|
||||
# `min_steps` would be reached
|
||||
(1, 10),
|
||||
(3, 10),
|
||||
],
|
||||
)
|
||||
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
|
||||
if min_steps:
|
||||
assert limit_train_batches < min_steps
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
output = self.layer(batch)
|
||||
loss = self.loss(batch, output)
|
||||
return {"loss": loss}
|
||||
self.log("foo", batch_idx)
|
||||
return super().training_step(batch, batch_idx)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
return {"test_val_loss": self._loss_value}
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
_mean = np.mean([x["test_val_loss"] for x in outputs])
|
||||
if self.trainer.global_step <= self._step_freeze:
|
||||
self._count_decrease += 1
|
||||
self._loss_value -= self._eps
|
||||
self._values.append(_mean)
|
||||
self.log("test_val_loss", _mean)
|
||||
|
||||
model = Model(step_freeze)
|
||||
model.training_step_end = None
|
||||
model.test_dataloader = None
|
||||
early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
|
||||
es_callback = EarlyStopping("foo")
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
callbacks=[early_stop_callback],
|
||||
callbacks=es_callback,
|
||||
limit_val_batches=0,
|
||||
limit_train_batches=limit_train_batches,
|
||||
limit_val_batches=2,
|
||||
min_steps=min_steps,
|
||||
min_epochs=min_epochs,
|
||||
min_steps=min_steps,
|
||||
logger=False,
|
||||
enable_checkpointing=False,
|
||||
enable_progress_bar=False,
|
||||
enable_model_summary=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
model = TestModel()
|
||||
|
||||
# Make sure loss was properly decreased
|
||||
assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6
|
||||
expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
|
||||
# trigger early stopping directly after the first epoch
|
||||
side_effect = [(True, "")] * expected_epochs
|
||||
with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect):
|
||||
trainer.fit(model)
|
||||
|
||||
pos_diff = (np.diff(model._values) == 0).nonzero()[0][0]
|
||||
|
||||
# Compute when the latest validation epoch end happened
|
||||
latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches
|
||||
if pos_diff % limit_train_batches == 0:
|
||||
latest_validation_epoch_end += limit_train_batches
|
||||
|
||||
# Compute early stopping latest step
|
||||
by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience
|
||||
|
||||
# Compute min_epochs latest step
|
||||
by_min_epochs = min_epochs * limit_train_batches
|
||||
|
||||
# Make sure the trainer stops for the max of all minimum requirements
|
||||
assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
|
||||
trainer.global_step,
|
||||
max(min_steps, by_early_stopping, by_min_epochs),
|
||||
step_freeze,
|
||||
min_steps,
|
||||
min_epochs,
|
||||
)
|
||||
|
||||
_logger.disabled = False
|
||||
assert trainer.should_stop
|
||||
# epochs continue until min steps are reached
|
||||
assert trainer.current_epoch == expected_epochs
|
||||
# steps continue until min steps are reached AND the epoch is exhausted
|
||||
# stopping mid-epoch is not supported
|
||||
assert trainer.global_step == limit_train_batches * expected_epochs
|
||||
|
||||
|
||||
def test_early_stopping_mode_options():
|
||||
|
|
Loading…
Reference in New Issue