lightning/tests/tests_pytorch/callbacks/test_timer.py

186 lines
6.9 KiB
Python

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from datetime import timedelta
from unittest.mock import Mock, patch
import pytest
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf
def test_trainer_flag(caplog):
    """Check how the ``max_time`` Trainer argument shows up as a ``Timer`` callback.

    Every ``fit`` call is aborted from ``on_fit_start`` by raising ``SystemExit``, so the
    test only inspects the resulting Trainer configuration.
    """

    class TestModel(BoringModel):
        def on_fit_start(self):
            raise SystemExit()

    def first_timer(trainer):
        # pick the first Timer instance found among the trainer's callbacks
        return [cb for cb in trainer.callbacks if isinstance(cb, Timer)][0]

    # max_time on its own: the Timer found in the callbacks carries that duration
    trainer = Trainer(max_time={"seconds": 1337})
    with pytest.raises(SystemExit):
        trainer.fit(TestModel())
    assert first_timer(trainer)._duration == 1337

    # max_time together with a user-supplied Timer: the log text must mention the existing Timer
    trainer = Trainer(max_time={"seconds": 1337}, callbacks=[Timer()])
    with pytest.raises(SystemExit):
        with caplog.at_level(level=logging.INFO):
            trainer.fit(TestModel())
    assert "callbacks list already contains a Timer" in caplog.text

    # max_time must still be honored when max_epochs == -1
    trainer = Trainer(max_time={"seconds": 1}, max_epochs=-1)
    with pytest.raises(SystemExit):
        trainer.fit(TestModel())
    assert first_timer(trainer)._duration == 1
    assert trainer.max_epochs == -1
    assert trainer.max_steps == -1
@pytest.mark.parametrize(
    ("duration", "expected"),
    [
        (None, None),
        ("00:00:00:22", timedelta(seconds=22)),
        ("12:34:56:65", timedelta(days=12, hours=34, minutes=56, seconds=65)),
        (timedelta(weeks=52, milliseconds=1), timedelta(weeks=52, milliseconds=1)),
        ({"weeks": 52, "days": 1}, timedelta(weeks=52, days=1)),
    ],
)
def test_timer_parse_duration(duration, expected):
    """Check that each supported ``duration`` format yields the expected remaining time.

    Args:
        duration: Value handed to ``Timer(duration=...)``: ``None``, a colon-separated string,
            a ``timedelta``, or a dict of ``timedelta`` keyword arguments.
        expected: ``None`` when no remaining time should be reported, otherwise the
            ``timedelta`` whose ``total_seconds()`` must equal ``Timer.time_remaining()``.
    """
    timer = Timer(duration=duration)
    remaining = timer.time_remaining()

    # Branch explicitly instead of relying on the chained comparison `a == expected is None`:
    # that form only worked by accident and would surface a regression in the `None` case as an
    # AttributeError on `None.total_seconds()` rather than as a readable assertion failure.
    if expected is None:
        assert remaining is None
    else:
        assert remaining == expected.total_seconds()
def test_timer_interval_choice():
    """Accept the ``"step"`` and ``"epoch"`` intervals and reject an unknown one."""
    for supported in ("step", "epoch"):
        Timer(duration=timedelta(), interval=supported)

    with pytest.raises(MisconfigurationException, match="Unsupported parameter value"):
        Timer(duration=timedelta(), interval="invalid")
@patch("lightning.pytorch.callbacks.timer.time")
def test_timer_time_remaining(time_mock):
    """Verify elapsed/remaining time bookkeeping against a mocked monotonic clock."""
    t0 = time.monotonic()
    budget = timedelta(seconds=10)

    time_mock.monotonic.return_value = t0
    timer = Timer(duration=budget)
    assert timer.time_remaining() == budget.total_seconds()
    assert timer.time_elapsed() == 0

    # the timer has not been started: moving the clock must not change anything
    time_mock.monotonic.return_value = t0 + 60
    assert timer.start_time() is None
    assert timer.time_remaining() == 10
    assert timer.time_elapsed() == 0

    # start the timer with the clock reset to t0
    time_mock.monotonic.return_value = t0
    timer.on_train_start(trainer=Mock(), pl_module=Mock())
    assert timer.start_time() == t0

    # pretend 3 seconds have passed since the start
    time_mock.monotonic.return_value = t0 + 3
    assert timer.start_time() == t0
    assert round(timer.time_remaining()) == 7
    assert round(timer.time_elapsed()) == 3
def test_timer_stops_training(tmpdir, caplog):
    """A 100 ms time budget should end the run before ``max_epochs`` is reached."""
    model = BoringModel()
    timer = Timer(duration=timedelta(milliseconds=100))
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1000, callbacks=[timer])

    with caplog.at_level(logging.INFO):
        trainer.fit(model)

    # some steps ran, but nowhere near the configured number of epochs
    assert trainer.global_step > 1
    assert trainer.current_epoch < 999

    # both parts of the stop message must have been logged
    logged = caplog.text
    assert "Time limit reached." in logged
    assert "Signaling Trainer to stop." in logged
@pytest.mark.parametrize("interval", ["step", "epoch"])
def test_timer_zero_duration_stop(tmpdir, interval):
    """With a zero time budget, training should end right after the first timer check."""
    model = BoringModel()
    zero_timer = Timer(duration=timedelta(0), interval=interval)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[zero_timer])

    trainer.fit(model)

    # neither a step nor an epoch may have been completed
    assert (trainer.global_step, trainer.current_epoch) == (0, 0)
@pytest.mark.parametrize(("min_steps", "min_epochs"), [(None, 2), (3, None), (3, 2)])
def test_timer_duration_min_steps_override(tmpdir, min_steps, min_epochs):
    """Check that ``min_steps``/``min_epochs`` are still satisfied with a zero-duration timer."""
    model = BoringModel()
    budget = timedelta(0)
    timer = Timer(duration=budget)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[timer],
        min_steps=min_steps,
        min_epochs=min_epochs,
    )

    trainer.fit(model)

    # only the minimums that were actually configured are checked
    if min_epochs:
        assert trainer.current_epoch >= min_epochs
    if min_steps:
        assert trainer.global_step >= min_steps - 1

    # the run lasted longer than the (zero) time budget
    assert timer.time_elapsed() > budget.total_seconds()
def test_timer_resume_training(tmpdir):
    """Check that a fresh ``Timer`` picks up the spent time when resuming from a checkpoint."""
    model = BoringModel()
    first_timer = Timer(duration=timedelta(milliseconds=200))
    ckpt_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)

    # first run: train until the 200 ms budget is used up
    first_trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=100,
        callbacks=[first_timer, ckpt_callback],
    )
    first_trainer.fit(model)
    assert not first_timer._offset
    assert first_timer.time_remaining() <= 0
    assert first_trainer.current_epoch < 99
    steps_before_resume = first_trainer.global_step

    # second run: resume from the checkpoint with a new timer, whose budget is already depleted
    resumed_timer = Timer(duration=timedelta(milliseconds=200))
    resumed_trainer = Trainer(default_root_dir=tmpdir, callbacks=resumed_timer)
    resumed_trainer.fit(model, ckpt_path=ckpt_callback.best_model_path)
    assert resumed_timer._offset > 0
    assert resumed_trainer.global_step == steps_before_resume
@RunIf(skip_windows=True)
def test_timer_track_stages(tmpdir):
    """Check that elapsed time is recorded per stage (train, validate, test)."""
    # note: skipped on windows because time resolution of time.monotonic() is not high enough for this fast test
    model = BoringModel()
    timer = Timer()
    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, callbacks=[timer])

    trainer.fit(model)

    # the default stage reported by time_elapsed() is "train"
    default_elapsed = timer.time_elapsed()
    train_elapsed = timer.time_elapsed("train")
    assert default_elapsed == train_elapsed
    assert train_elapsed > 0
    assert timer.time_elapsed("validate") > 0
    # no test run has happened so far
    assert timer.time_elapsed("test") == 0

    trainer.test(model)
    assert timer.time_elapsed("test") > 0