lightning/tests/trainer/connectors/test_callback_connector.py

56 lines
1.9 KiB
Python
Raw Normal View History

from unittest.mock import Mock
import torch
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ProgressBar
from tests.base import BoringModel
def test_checkpoint_callbacks_are_last(tmpdir):
    """Test that checkpoint callbacks always get moved to the end of the list, with preserved order.

    The two ``ModelCheckpoint`` instances are passed interleaved with other callbacks. After
    ``Trainer`` construction, ``trainer.callbacks`` must list the non-checkpoint callbacks first,
    in their given order, followed by both checkpoints, also in their given order.
    """
    checkpoint1 = ModelCheckpoint(tmpdir)
    checkpoint2 = ModelCheckpoint(tmpdir)
    lr_monitor = LearningRateMonitor()
    progress_bar = ProgressBar()

    # Only Trainer construction is exercised, so no model is needed. `default_root_dir`
    # keeps the trainer's default logs/weights location inside tmpdir rather than the cwd.
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2],
    )
    assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]
class StatefulCallback0(Callback):
    """Callback whose ``on_save_checkpoint`` hook returns the fixed state ``{"content0": 0}``."""

    def on_save_checkpoint(self, trainer, pl_module):
        # A small, recognizable state dict so its presence in a checkpoint can be asserted.
        return dict(content0=0)
class StatefulCallback1(Callback):
    """Callback whose ``on_save_checkpoint`` hook returns the fixed state ``{"content1": 1}``."""

    def on_save_checkpoint(self, trainer, pl_module):
        # A small, recognizable state dict so its presence in a checkpoint can be asserted.
        return dict(content1=1)
def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
    """Test that all callback states get saved even if the ModelCheckpoint is not given as last.

    The ModelCheckpoint is placed between two stateful callbacks. After a one-step fit, the
    checkpoint file it wrote must hold, under ``"callbacks"`` and keyed by callback type, the
    state returned by each stateful callback plus an entry for the ModelCheckpoint itself.
    """
    first_stateful = StatefulCallback0()
    second_stateful = StatefulCallback1()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
    boring_model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_val_batches=1,
        callbacks=[first_stateful, checkpoint_callback, second_stateful],
    )
    trainer.fit(boring_model)

    # The filename passed to ModelCheckpoint above determines this path.
    saved_states = torch.load(str(tmpdir / "all_states.ckpt"))["callbacks"]
    first_state = saved_states[type(first_stateful)]
    second_state = saved_states[type(second_stateful)]

    assert first_state.get("content0") == 0
    assert second_state.get("content1") == 1
    assert type(checkpoint_callback) in saved_states