lightning/tests/callbacks/test_timer.py

196 lines
7.1 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from datetime import timedelta
from unittest.mock import Mock, patch
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
def test_trainer_flag(caplog):
class TestModel(BoringModel):
def on_fit_start(self):
raise SystemExit()
trainer = Trainer(max_time=dict(seconds=1337))
with pytest.raises(SystemExit):
trainer.fit(TestModel())
timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
assert timer._duration == 1337
trainer = Trainer(max_time=dict(seconds=1337), callbacks=[Timer()])
with pytest.raises(SystemExit), caplog.at_level(level=logging.INFO):
trainer.fit(TestModel())
assert "callbacks list already contains a Timer" in caplog.text
# Make sure max_time still honored even if max_epochs == -1
trainer = Trainer(max_time=dict(seconds=1), max_epochs=-1)
with pytest.raises(SystemExit):
trainer.fit(TestModel())
timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
assert timer._duration == 1
assert trainer.max_epochs == -1
assert trainer.max_steps is None
@pytest.mark.parametrize(
"duration,expected",
[
(None, None),
("00:00:00:22", timedelta(seconds=22)),
("12:34:56:65", timedelta(days=12, hours=34, minutes=56, seconds=65)),
(timedelta(weeks=52, milliseconds=1), timedelta(weeks=52, milliseconds=1)),
(dict(weeks=52, days=1), timedelta(weeks=52, days=1)),
],
)
def test_timer_parse_duration(duration, expected):
timer = Timer(duration=duration)
assert (timer.time_remaining() == expected is None) or (timer.time_remaining() == expected.total_seconds())
def test_timer_interval_choice():
Timer(duration=timedelta(), interval="step")
Timer(duration=timedelta(), interval="epoch")
with pytest.raises(MisconfigurationException, match="Unsupported parameter value"):
Timer(duration=timedelta(), interval="invalid")
@patch("pytorch_lightning.callbacks.timer.time")
def test_timer_time_remaining(time_mock):
"""Test that the timer tracks the elapsed and remaining time correctly."""
start_time = time.monotonic()
duration = timedelta(seconds=10)
time_mock.monotonic.return_value = start_time
timer = Timer(duration=duration)
assert timer.time_remaining() == duration.total_seconds()
assert timer.time_elapsed() == 0
# timer not started yet
time_mock.monotonic.return_value = start_time + 60
assert timer.start_time() is None
assert timer.time_remaining() == 10
assert timer.time_elapsed() == 0
# start timer
time_mock.monotonic.return_value = start_time
timer.on_train_start(trainer=Mock(), pl_module=Mock())
assert timer.start_time() == start_time
# pretend time has elapsed
elapsed = 3
time_mock.monotonic.return_value = start_time + elapsed
assert timer.start_time() == start_time
assert round(timer.time_remaining()) == 7
assert round(timer.time_elapsed()) == 3
def test_timer_stops_training(tmpdir, caplog):
"""Test that the timer stops training before reaching max_epochs"""
model = BoringModel()
duration = timedelta(milliseconds=100)
timer = Timer(duration=duration)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1000, callbacks=[timer])
with caplog.at_level(logging.INFO):
trainer.fit(model)
assert trainer.global_step > 1
assert trainer.current_epoch < 999
assert "Time limit reached." in caplog.text
assert "Signaling Trainer to stop." in caplog.text
@pytest.mark.parametrize("interval", ["step", "epoch"])
def test_timer_zero_duration_stop(tmpdir, interval):
"""Test that the timer stops training immediately after the first check occurs."""
model = BoringModel()
duration = timedelta(0)
timer = Timer(duration=duration, interval=interval)
trainer = Trainer(default_root_dir=tmpdir, callbacks=[timer])
trainer.fit(model)
if interval == "step":
# timer triggers stop on step end
assert trainer.global_step == 1
assert trainer.current_epoch == 0
else:
# timer triggers stop on epoch end
assert trainer.global_step == len(trainer.train_dataloader)
assert trainer.current_epoch == 0
@pytest.mark.parametrize("min_steps,min_epochs", [(None, 2), (3, None), (3, 2)])
def test_timer_duration_min_steps_override(tmpdir, min_steps, min_epochs):
model = BoringModel()
duration = timedelta(0)
timer = Timer(duration=duration)
trainer = Trainer(default_root_dir=tmpdir, callbacks=[timer], min_steps=min_steps, min_epochs=min_epochs)
trainer.fit(model)
if min_epochs:
assert trainer.current_epoch >= min_epochs - 1
if min_steps:
assert trainer.global_step >= min_steps - 1
assert timer.time_elapsed() > duration.total_seconds()
def test_timer_resume_training(tmpdir):
"""Test that the timer can resume together with the Trainer."""
model = BoringModel()
timer = Timer(duration=timedelta(milliseconds=200))
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
# initial training
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=100,
callbacks=[timer, checkpoint_callback],
)
trainer.fit(model)
assert not timer._offset
assert timer.time_remaining() <= 0
assert trainer.current_epoch < 99
saved_global_step = trainer.global_step
# resume training (with depleted timer
timer = Timer(duration=timedelta(milliseconds=200))
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[timer, checkpoint_callback],
resume_from_checkpoint=checkpoint_callback.best_model_path,
)
trainer.fit(model)
assert timer._offset > 0
assert trainer.global_step == saved_global_step + 1
@RunIf(skip_windows=True)
def test_timer_track_stages(tmpdir):
"""Test that the timer tracks time also for other stages (train/val/test)."""
# note: skipped on windows because time resolution of time.monotonic() is not high enough for this fast test
model = BoringModel()
timer = Timer()
trainer = Trainer(default_root_dir=tmpdir, max_steps=5, callbacks=[timer])
trainer.fit(model)
assert timer.time_elapsed() == timer.time_elapsed("train") > 0
assert timer.time_elapsed("validate") > 0
assert timer.time_elapsed("test") == 0
trainer.test(model)
assert timer.time_elapsed("test") > 0