186 lines
6.9 KiB
Python
186 lines
6.9 KiB
Python
# Copyright The Lightning AI team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import time
|
|
from datetime import timedelta
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
from lightning.pytorch import Trainer
|
|
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
from lightning.pytorch.callbacks.timer import Timer
|
|
from lightning.pytorch.demos.boring_classes import BoringModel
|
|
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
|
|
|
from tests_pytorch.helpers.runif import RunIf
|
|
|
|
|
|
def test_trainer_flag(caplog):
    """Check how the ``max_time`` Trainer argument is reflected in the Trainer's ``Timer`` callback."""

    class TestModel(BoringModel):
        """Model that exits as soon as fitting starts, so no actual training is needed."""

        def on_fit_start(self):
            raise SystemExit()

    def first_timer(trainer_obj):
        # Indexing with [0] makes the test fail if no Timer is among the callbacks.
        timers = [cb for cb in trainer_obj.callbacks if isinstance(cb, Timer)]
        return timers[0]

    # max_time on its own: a Timer with the requested duration is among the callbacks
    trainer = Trainer(max_time={"seconds": 1337})
    with pytest.raises(SystemExit):
        trainer.fit(TestModel())
    assert first_timer(trainer)._duration == 1337

    # max_time combined with a user-supplied Timer: an info message about the existing Timer is logged
    trainer = Trainer(max_time={"seconds": 1337}, callbacks=[Timer()])
    with pytest.raises(SystemExit):
        with caplog.at_level(level=logging.INFO):
            trainer.fit(TestModel())
    assert "callbacks list already contains a Timer" in caplog.text

    # Make sure max_time is still honored even if max_epochs == -1
    trainer = Trainer(max_time={"seconds": 1}, max_epochs=-1)
    with pytest.raises(SystemExit):
        trainer.fit(TestModel())
    assert first_timer(trainer)._duration == 1
    assert trainer.max_epochs == -1
    assert trainer.max_steps == -1
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("duration", "expected"),
    [
        (None, None),
        ("00:00:00:22", timedelta(seconds=22)),
        ("12:34:56:65", timedelta(days=12, hours=34, minutes=56, seconds=65)),
        (timedelta(weeks=52, milliseconds=1), timedelta(weeks=52, milliseconds=1)),
        ({"weeks": 52, "days": 1}, timedelta(weeks=52, days=1)),
    ],
)
def test_timer_parse_duration(duration, expected):
    """Test that the supported ``duration`` formats are parsed into the expected time budget.

    Covered inputs are ``None``, a ``DD:HH:MM:SS`` string, a ``timedelta`` and a dict of ``timedelta`` keyword
    arguments. For ``None`` the remaining time is expected to be ``None``; otherwise it must equal the expected
    duration expressed in seconds.

    """
    timer = Timer(duration=duration)
    remaining = timer.time_remaining()
    if expected is None:
        # Checked in its own branch so that a regression here surfaces as an assertion failure
        # rather than an AttributeError from calling ``total_seconds()`` on ``None``.
        assert remaining is None
    else:
        assert remaining == expected.total_seconds()
|
|
|
|
|
|
def test_timer_interval_choice():
    """Test that the "step" and "epoch" intervals are accepted and an unsupported value raises."""
    for supported in ("step", "epoch"):
        Timer(duration=timedelta(), interval=supported)

    with pytest.raises(MisconfigurationException, match="Unsupported parameter value"):
        Timer(duration=timedelta(), interval="invalid")
|
|
|
|
|
|
@patch("lightning.pytorch.callbacks.timer.time")
def test_timer_time_remaining(time_mock):
    """Check the timer's elapsed/remaining bookkeeping against a mocked ``time`` module."""
    t0 = time.monotonic()
    budget = timedelta(seconds=10)
    time_mock.monotonic.return_value = t0

    timer = Timer(duration=budget)
    assert timer.time_remaining() == budget.total_seconds()
    assert timer.time_elapsed() == 0

    # the mocked clock moves forward, but the timer has not been started yet
    time_mock.monotonic.return_value = t0 + 60
    assert timer.start_time() is None
    assert timer.time_remaining() == 10
    assert timer.time_elapsed() == 0

    # reset the clock and start the timer through the train-start hook
    time_mock.monotonic.return_value = t0
    timer.on_train_start(trainer=Mock(), pl_module=Mock())
    assert timer.start_time() == t0

    # advance the mocked clock and check both elapsed and remaining time
    seconds_passed = 3
    time_mock.monotonic.return_value = t0 + seconds_passed
    assert timer.start_time() == t0
    assert round(timer.time_remaining()) == 7
    assert round(timer.time_elapsed()) == 3
|
|
|
|
|
|
def test_timer_stops_training(tmpdir, caplog):
    """Verify that a short time budget ends the fit before ``max_epochs`` is reached."""
    model = BoringModel()
    timer = Timer(duration=timedelta(milliseconds=100))
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1000, callbacks=[timer])

    with caplog.at_level(logging.INFO):
        trainer.fit(model)

    # some training happened, but far fewer epochs than requested
    assert trainer.global_step > 1
    assert trainer.current_epoch < 999

    # both parts of the stop notice must appear in the captured log
    for message in ("Time limit reached.", "Signaling Trainer to stop."):
        assert message in caplog.text
|
|
|
|
|
|
@pytest.mark.parametrize("interval", ["step", "epoch"])
def test_timer_zero_duration_stop(tmpdir, interval):
    """Verify that a zero time budget leaves the Trainer at step 0 and epoch 0 for either interval."""
    model = BoringModel()
    zero_budget_timer = Timer(duration=timedelta(0), interval=interval)
    trainer = Trainer(default_root_dir=tmpdir, callbacks=[zero_budget_timer])

    trainer.fit(model)

    assert trainer.global_step == 0
    assert trainer.current_epoch == 0
|
|
|
|
|
|
@pytest.mark.parametrize(("min_steps", "min_epochs"), [(None, 2), (3, None), (3, 2)])
def test_timer_duration_min_steps_override(tmpdir, min_steps, min_epochs):
    """Check the ``min_steps``/``min_epochs`` bounds asserted below when the timer has a zero budget."""
    model = BoringModel()
    zero_budget = timedelta(0)
    timer = Timer(duration=zero_budget)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[timer],
        min_steps=min_steps,
        min_epochs=min_epochs,
    )

    trainer.fit(model)

    # only the minimums that were actually requested are checked
    if min_epochs:
        assert trainer.current_epoch >= min_epochs
    if min_steps:
        assert trainer.global_step >= min_steps - 1

    # the timer ran past its (zero) budget
    assert timer.time_elapsed() > zero_budget.total_seconds()
|
|
|
|
|
|
def test_timer_resume_training(tmpdir):
    """Verify that timer state is carried over when a fit is resumed from a checkpoint."""
    model = BoringModel()
    first_timer = Timer(duration=timedelta(milliseconds=200))
    ckpt_cb = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)

    # first run: the time budget is used up before max_epochs is reached
    first_trainer = Trainer(default_root_dir=tmpdir, max_epochs=100, callbacks=[first_timer, ckpt_cb])
    first_trainer.fit(model)
    assert not first_timer._offset
    assert first_timer.time_remaining() <= 0
    assert first_trainer.current_epoch < 99
    steps_before_resume = first_trainer.global_step

    # second run: a fresh timer resumed from the checkpoint (its budget is already depleted)
    resumed_timer = Timer(duration=timedelta(milliseconds=200))
    resumed_trainer = Trainer(default_root_dir=tmpdir, callbacks=resumed_timer)
    resumed_trainer.fit(model, ckpt_path=ckpt_cb.best_model_path)
    assert resumed_timer._offset > 0
    assert resumed_trainer.global_step == steps_before_resume
|
|
|
|
|
|
@RunIf(skip_windows=True)
def test_timer_track_stages(tmpdir):
    """Verify that elapsed time is recorded separately for the train, validate and test stages."""
    # note: skipped on windows because time resolution of time.monotonic() is not high enough for this fast test
    model = BoringModel()
    timer = Timer()
    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, callbacks=[timer])

    trainer.fit(model)

    # calling time_elapsed() without a stage gives the same value as the "train" stage
    overall_elapsed = timer.time_elapsed()
    train_elapsed = timer.time_elapsed("train")
    assert overall_elapsed == train_elapsed
    assert train_elapsed > 0
    assert timer.time_elapsed("validate") > 0

    # nothing is recorded for the test stage until trainer.test() has run
    assert timer.time_elapsed("test") == 0
    trainer.test(model)
    assert timer.time_elapsed("test") > 0
|