217 lines
7.8 KiB
Python
217 lines
7.8 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
|
|
import torch
|
|
|
|
from pytorch_lightning import Callback, LightningModule, Trainer
|
|
from pytorch_lightning.callbacks import (
|
|
EarlyStopping,
|
|
GradientAccumulationScheduler,
|
|
LearningRateMonitor,
|
|
ModelCheckpoint,
|
|
ModelSummary,
|
|
ProgressBarBase,
|
|
TQDMProgressBar,
|
|
)
|
|
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
|
|
from tests.helpers import BoringModel
|
|
|
|
|
|
def test_checkpoint_callbacks_are_last(tmpdir):
    """Checkpoint callbacks must be moved to the end of ``trainer.callbacks`` while keeping their relative order.

    Three situations are exercised: a trainer without a model, a model contributing no callbacks, and a model
    whose callbacks substitute the ones that were passed to the ``Trainer``.
    """
    first_ckpt = ModelCheckpoint(tmpdir)
    second_ckpt = ModelCheckpoint(tmpdir)
    summary = ModelSummary()
    stopper = EarlyStopping(monitor="foo")
    lr_logger = LearningRateMonitor()
    bar = TQDMProgressBar()

    def _module_providing(callbacks):
        # LightningModule whose ``configure_callbacks`` hook hands back the given callbacks
        module = LightningModule()
        module.configure_callbacks = lambda: callbacks
        return module

    def _attach(trainer, module):
        # set the module on the trainer, then let the connector attach the module's callbacks
        trainer.model = module
        CallbackConnector(trainer)._attach_model_callbacks()

    def _trainer_only_order(trainer):
        # order expected when the model contributes nothing: the two checkpoints come last
        return [bar, lr_logger, summary, trainer.accumulation_scheduler, first_ckpt, second_ckpt]

    # scenario 1: no model reference
    trainer = Trainer(callbacks=[first_ckpt, bar, lr_logger, summary, second_ckpt])
    assert trainer.callbacks == _trainer_only_order(trainer)

    # scenario 2: the model returns no callbacks
    _attach(trainer, _module_providing([]))
    assert trainer.callbacks == _trainer_only_order(trainer)

    # scenario 3: model-specific callbacks that substitute ones in the Trainer
    module = _module_providing([first_ckpt, stopper, summary, second_ckpt])
    trainer = Trainer(callbacks=[bar, lr_logger, ModelCheckpoint(tmpdir)])
    _attach(trainer, module)
    assert trainer.callbacks == [
        bar,
        lr_logger,
        trainer.accumulation_scheduler,
        stopper,
        summary,
        first_ckpt,
        second_ckpt,
    ]
|
|
|
|
|
|
class StatefulCallback0(Callback):
    """Callback exposing a fixed state, so tests can look for that state in a saved checkpoint."""

    def state_dict(self) -> dict:
        """Return the constant state ``{"content0": 0}``."""
        return dict(content0=0)
|
|
|
|
|
|
class StatefulCallback1(Callback):
    """Callback whose state key is generated from ``unique`` only; ``other`` is stored but not part of the key."""

    def __init__(self, unique=None, other=None):
        self._unique, self._other = unique, other

    @property
    def state_key(self):
        """Key generated from the ``unique`` argument, so instances with different ``unique`` values differ."""
        return self._generate_state_key(unique=self._unique)

    def state_dict(self) -> dict:
        """Return the state, which holds the ``unique`` value under ``"content1"``."""
        return dict(content1=self._unique)
|
|
|
|
|
|
def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
    """Every callback state must be saved to the checkpoint, wherever the ``ModelCheckpoint`` sits in the list.

    Also covers several callbacks of the same type, which are looked up in the checkpoint by distinct keys.
    """
    plain_callback = StatefulCallback0()
    keyed_one = StatefulCallback1(unique="one")
    keyed_two = StatefulCallback1(unique="two", other=2)
    ckpt_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
    model = BoringModel()

    # the checkpoint callback is deliberately not last; ``keyed_one`` and ``keyed_two`` have the same type
    callbacks = [plain_callback, ckpt_callback, keyed_one, keyed_two]
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, limit_val_batches=1, callbacks=callbacks)
    trainer.fit(model)

    saved_states = torch.load(str(tmpdir / "all_states.ckpt"))["callbacks"]

    # (key in the checkpoint, field inside that state, value the field must hold)
    expectations = (
        ("StatefulCallback0", "content0", 0),
        ("StatefulCallback1{'unique': 'one'}", "content1", "one"),
        ("StatefulCallback1{'unique': 'two'}", "content1", "two"),
    )
    # fetch all states first so a missing key surfaces before any content check runs
    states = [saved_states[key] for key, _, _ in expectations]
    for state, (_, field, value) in zip(states, expectations):
        assert field in state and state[field] == value

    checkpoint_key = (
        "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
    )
    assert checkpoint_key in saved_states
|
|
|
|
|
|
def test_attach_model_callbacks():
    """Check how callbacks passed to the ``Trainer`` are merged with those returned by ``configure_callbacks``."""

    def _merged_trainer(trainer_callbacks, model_callbacks):
        # Build a Trainer with ``trainer_callbacks`` and attach a module that returns ``model_callbacks``.
        module = LightningModule()
        module.configure_callbacks = lambda: model_callbacks
        # the progress bar stays enabled only when one of the given callbacks is a progress bar
        wants_bar = any(isinstance(cb, ProgressBarBase) for cb in [*trainer_callbacks, *model_callbacks])
        trainer = Trainer(
            callbacks=trainer_callbacks,
            enable_checkpointing=False,
            enable_model_summary=False,
            enable_progress_bar=wants_bar,
        )
        trainer.model = module
        CallbackConnector(trainer)._attach_model_callbacks()
        return trainer

    stopper = EarlyStopping(monitor="foo")
    bar = TQDMProgressBar()
    lr_logger = LearningRateMonitor()
    accumulator = GradientAccumulationScheduler({1: 1})

    # nothing passed on either side
    trainer = _merged_trainer([], [])
    assert trainer.callbacks == [trainer.accumulation_scheduler]

    # trainer and model contribute callbacks of different types
    trainer = _merged_trainer([stopper], [bar])
    assert trainer.callbacks == [stopper, trainer.accumulation_scheduler, bar]

    # same callback type on both sides, but different instances
    trainer = _merged_trainer([bar, EarlyStopping(monitor="foo")], [stopper])
    assert trainer.callbacks == [bar, trainer.accumulation_scheduler, stopper]

    # the trainer holds several callbacks of the same type
    trainer_side = [
        LearningRateMonitor(),
        EarlyStopping(monitor="foo"),
        LearningRateMonitor(),
        EarlyStopping(monitor="foo"),
    ]
    trainer = _merged_trainer(trainer_side, [stopper, lr_logger])
    assert trainer.callbacks == [trainer.accumulation_scheduler, stopper, lr_logger]

    # several callbacks of the same type, in both trainer and model
    trainer_side = [
        LearningRateMonitor(),
        bar,
        EarlyStopping(monitor="foo"),
        LearningRateMonitor(),
        EarlyStopping(monitor="foo"),
    ]
    trainer = _merged_trainer(trainer_side, [stopper, lr_logger, accumulator, stopper])
    assert trainer.callbacks == [bar, stopper, lr_logger, accumulator, stopper]
|
|
|
|
|
|
def test_attach_model_callbacks_override_info(caplog):
    """The log must report which Trainer callbacks are overridden by the ones from ``configure_callbacks``."""
    module = LightningModule()
    module.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping(monitor="foo")]

    trainer_callbacks = [EarlyStopping(monitor="foo"), LearningRateMonitor(), TQDMProgressBar()]
    trainer = Trainer(enable_checkpointing=False, callbacks=trainer_callbacks)
    trainer.model = module
    connector = CallbackConnector(trainer)

    # capture from INFO level upwards while the model callbacks are attached
    with caplog.at_level(logging.INFO):
        connector._attach_model_callbacks()

    expected_fragment = "existing callbacks passed to Trainer: EarlyStopping, LearningRateMonitor"
    assert expected_fragment in caplog.text
|