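# Tests for the CallbackConnector: how Trainer- and model-defined callbacks are merged,
# how checkpoint callbacks are ordered, and how callback state is persisted in checkpoints.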
import logging
from unittest.mock import Mock

import torch

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import (
    EarlyStopping,
    GradientAccumulationScheduler,
    LearningRateMonitor,
    ModelCheckpoint,
    ProgressBar,
)
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from tests.helpers import BoringModel


def test_checkpoint_callbacks_are_last(tmpdir):
    """Test that checkpoint callbacks always get moved to the end of the list, with preserved order."""
    checkpoint1 = ModelCheckpoint(tmpdir)
    checkpoint2 = ModelCheckpoint(tmpdir)
    early_stopping = EarlyStopping()
    lr_monitor = LearningRateMonitor()
    progress_bar = ProgressBar()

    # no model callbacks
    model = Mock()
    model.configure_callbacks.return_value = []
    trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2])
    cb_connector = CallbackConnector(trainer)
    cb_connector._attach_model_callbacks(model, trainer)
    assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]

    # with model-specific callbacks that substitute ones in Trainer
    model = Mock()
    model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2]
    trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)])
    cb_connector = CallbackConnector(trainer)
    cb_connector._attach_model_callbacks(model, trainer)
    assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, checkpoint1, checkpoint2]
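

# A minimal illustrative sketch (prefixed with ``_`` so pytest does not collect it): the same
# reordering applies when ``configure_callbacks`` is implemented on a real LightningModule
# instead of a Mock. ``_sketch_checkpoint_moved_last_with_real_model`` and
# ``_CheckpointFirstModel`` are hypothetical helpers introduced here for illustration only.
def _sketch_checkpoint_moved_last_with_real_model(tmpdir):
    class _CheckpointFirstModel(BoringModel):

        def configure_callbacks(self):
            # the checkpoint is deliberately listed first; the connector should still move it last
            return [ModelCheckpoint(tmpdir), EarlyStopping()]

    model = _CheckpointFirstModel()
    trainer = Trainer(callbacks=[ProgressBar()])
    CallbackConnector(trainer)._attach_model_callbacks(model, trainer)
    assert isinstance(trainer.callbacks[-1], ModelCheckpoint)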


class StatefulCallback0(Callback):

    def on_save_checkpoint(self, *args):
        return {"content0": 0}


class StatefulCallback1(Callback):

    def on_save_checkpoint(self, *args):
        return {"content1": 1}


def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
    """Test that all callback states get saved even if the ModelCheckpoint is not given as last."""

    callback0 = StatefulCallback0()
    callback1 = StatefulCallback1()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_val_batches=1,
        callbacks=[callback0, checkpoint_callback, callback1]
    )
    trainer.fit(model)

    ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
    state0 = ckpt["callbacks"][type(callback0)]
    state1 = ckpt["callbacks"][type(callback1)]
    assert "content0" in state0 and state0["content0"] == 0
    assert "content1" in state1 and state1["content1"] == 1
    assert type(checkpoint_callback) in ckpt["callbacks"]
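

# A minimal sketch of the restore counterpart: the dict returned by ``on_save_checkpoint`` is
# handed back to the callback through the ``on_load_checkpoint`` hook when a checkpoint is
# loaded. ``_StatefulCallbackWithRestore`` is a hypothetical helper; it reads the state from the
# last positional argument to stay agnostic to the exact hook signature of the installed
# PyTorch Lightning version (an assumption, mirroring the ``*args`` style used above).
class _StatefulCallbackWithRestore(Callback):

    def __init__(self):
        self.restored = None

    def on_save_checkpoint(self, *args):
        return {"content": 123}

    def on_load_checkpoint(self, *args):
        # the saved state dict is assumed to be the last positional argument
        self.restored = args[-1]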


def test_attach_model_callbacks():
    """Test that the callbacks defined in the model and through Trainer get merged correctly."""

    def assert_composition(trainer_callbacks, model_callbacks, expected):
        model = Mock()
        model.configure_callbacks.return_value = model_callbacks
        trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks)
        cb_connector = CallbackConnector(trainer)
        cb_connector._attach_model_callbacks(model, trainer)
        assert trainer.callbacks == expected

    early_stopping = EarlyStopping()
    progress_bar = ProgressBar()
    lr_monitor = LearningRateMonitor()
    grad_accumulation = GradientAccumulationScheduler({1: 1})

    # no callbacks
    assert_composition(trainer_callbacks=[], model_callbacks=[], expected=[])

    # callbacks of different types
    assert_composition(
        trainer_callbacks=[early_stopping], model_callbacks=[progress_bar], expected=[early_stopping, progress_bar]
    )

    # same callback type twice, different instance
    assert_composition(
        trainer_callbacks=[progress_bar, EarlyStopping()],
        model_callbacks=[early_stopping],
        expected=[progress_bar, early_stopping]
    )

    # multiple callbacks of the same type in trainer
    assert_composition(
        trainer_callbacks=[LearningRateMonitor(), EarlyStopping(), LearningRateMonitor(), EarlyStopping()],
        model_callbacks=[early_stopping, lr_monitor],
        expected=[early_stopping, lr_monitor]
    )

    # multiple callbacks of the same type, in both trainer and model
    assert_composition(
        trainer_callbacks=[
            LearningRateMonitor(), progress_bar, EarlyStopping(), LearningRateMonitor(), EarlyStopping()
        ],
        model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
        expected=[progress_bar, early_stopping, lr_monitor, grad_accumulation, early_stopping]
    )
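

# Summary of the merge semantics exercised above: callbacks returned by the model's
# ``configure_callbacks`` replace all Trainer callbacks of the same type, while
# non-overridden Trainer callbacks keep their original relative order.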


def test_attach_model_callbacks_override_info(caplog):
    """Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
    model = Mock()
    model.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()]
    trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()])
    cb_connector = CallbackConnector(trainer)
    with caplog.at_level(logging.INFO):
        cb_connector._attach_model_callbacks(model, trainer)

    assert "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor" in caplog.text