lightning/tests/tests_pytorch/utilities/migration/test_migration.py

278 lines
13 KiB
Python

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import ANY, MagicMock
import lightning.pytorch as pl
import pytest
import torch
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
from lightning.pytorch.utilities.migration import migrate_checkpoint
from lightning.pytorch.utilities.migration.utils import _get_version, _set_legacy_version, _set_version
@pytest.mark.parametrize(
    ("old_checkpoint", "new_checkpoint"),
    [
        # legacy `checkpoint_callback_best` -> ModelCheckpoint `best_model_score`
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}},
        ),
        # legacy `checkpoint_callback_best_model_score` -> ModelCheckpoint `best_model_score`
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}},
        ),
        # legacy `checkpoint_callback_best_model_path` -> ModelCheckpoint `best_model_path`
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}},
        ),
        # legacy `early_stop_callback_*` entries -> EarlyStopping `wait_count` / `patience`
        (
            {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}},
            },
        ),
    ],
)
def test_migrate_model_checkpoint_early_stopping(old_checkpoint, new_checkpoint):
    """Test that legacy top-level ModelCheckpoint and EarlyStopping entries end up under the ``callbacks`` key, keyed
    by the callback class."""
    # The expectation carries both the legacy version marker and the current version stamp
    _set_legacy_version(new_checkpoint, "0.9.0")
    _set_version(new_checkpoint, pl.__version__)
    # The input pretends to come from 0.9.0
    _set_version(old_checkpoint, "0.9.0")

    migrated, _ = migrate_checkpoint(old_checkpoint, target_version="1.0.0")

    # After the call, the returned dict, the dict that was passed in, and the expectation must all be equal
    assert migrated == old_checkpoint
    assert old_checkpoint == new_checkpoint
    assert _get_version(migrated) == pl.__version__
def test_migrate_loop_global_step_to_progress_tracking():
    """Test that the legacy ``global_step`` value shows up in the fit loop's optimizer progress counters, for both
    automatic and manual optimization."""
    legacy_checkpoint = {"global_step": 15, "epoch": 2}
    _set_version(legacy_checkpoint, "1.5.9")  # pretend a checkpoint prior to 1.6.0
    migrated, _ = migrate_checkpoint(legacy_checkpoint, target_version="1.6.0")
    fit_loop_state = migrated["loops"]["fit_loop"]

    # automatic optimization
    automatic_progress = fit_loop_state["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
    assert automatic_progress["optimizer"]["step"]["total"]["completed"] == 15

    # for manual optimization
    manual_progress = fit_loop_state["epoch_loop.batch_loop.manual_loop.optim_step_progress"]
    assert manual_progress["total"]["completed"] == 15
def test_migrate_loop_current_epoch_to_progress_tracking():
    """Test that the legacy ``epoch`` value shows up as the fit loop's completed epoch count."""
    legacy_checkpoint = {"global_step": 15, "epoch": 2}
    _set_version(legacy_checkpoint, "1.5.9")  # pretend a checkpoint prior to 1.6.0
    # no target version is given here, unlike the sibling global-step test
    migrated, _ = migrate_checkpoint(legacy_checkpoint)
    epoch_progress = migrated["loops"]["fit_loop"]["epoch_progress"]
    assert epoch_progress["current"]["completed"] == 2
@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
def test_migrate_loop_batches_that_stepped(tmpdir, model_class):
    """Test that resuming from a checkpoint without the ``_batches_that_stepped`` entry still yields a counter that
    matches the global step."""
    # Produce a real checkpoint by training for a single step
    trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
    trainer.fit(model_class())
    ckpt_path = trainer.checkpoint_callback.best_model_path

    # pretend we have a checkpoint produced in < v1.6.5; the key "_batches_that_stepped" didn't exist back then
    legacy_ckpt = torch.load(ckpt_path)
    epoch_loop_state = legacy_ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]
    del epoch_loop_state["_batches_that_stepped"]
    _set_version(legacy_ckpt, "1.6.4")
    torch.save(legacy_ckpt, ckpt_path)

    class TestModel(model_class):
        def on_train_start(self) -> None:
            # Both counters must already be at 1 when training starts from the rewritten checkpoint
            assert self.trainer.global_step == 1
            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1

    # Resume from the rewritten checkpoint and train one more step
    trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
    trainer.fit(TestModel(), ckpt_path=ckpt_path)

    resumed_epoch_loop = trainer.fit_loop.epoch_loop
    assert resumed_epoch_loop.global_step == resumed_epoch_loop._batches_that_stepped == 2
@pytest.mark.parametrize("save_on_train_epoch_end", [None, True, False])
def test_migrate_model_checkpoint_save_on_train_epoch_end_default(save_on_train_epoch_end):
    """Test that the 'save_on_train_epoch_end' part of the ModelCheckpoint state key gets removed."""
    # Portion of the state key that the legacy and the new format have in common (closing brace excluded)
    shared_key_prefix = (
        "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None"
    )
    legacy_state_key = f"{shared_key_prefix}, 'save_on_train_epoch_end': {save_on_train_epoch_end}}}"
    new_state_key = shared_key_prefix + "}"

    callback_state = {"dummy": 0}
    old_checkpoint = {"callbacks": {legacy_state_key: callback_state}, "global_step": 0, "epoch": 1}
    _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0

    migrated, _ = migrate_checkpoint(old_checkpoint, target_version="1.9.0")

    # Only the key is expected to change; the state stored under it stays the same
    assert migrated["callbacks"] == {new_state_key: {"dummy": 0}}
def test_migrate_model_checkpoint_save_on_train_epoch_end_default_collision():
    """Test that the migration warns about collisions that would occur if the keys were modified.

    The checkpoint handed to the migration is a deep copy, so the final comparison against the untouched original
    cannot pass merely because both dicts share the same nested ``callbacks`` object.

    """
    from copy import deepcopy

    # The two keys only differ in the `save_on_train_epoch_end` value
    state_key1 = (
        "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
    )
    state_key2 = (
        "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None, 'save_on_train_epoch_end': False}"
    )
    old_checkpoint = {
        "callbacks": {state_key1: {"dummy": 0}, state_key2: {"dummy": 0}},
        "global_step": 0,
        "epoch": 1,
    }
    _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0
    with pytest.warns(PossibleUserWarning, match="callback states in this checkpoint.* colliding with each other"):
        # A shallow `.copy()` would share the nested "callbacks" dict with `old_checkpoint`; an in-place rewrite of
        # that dict would then go unnoticed by the assertion below.
        updated_checkpoint, _ = migrate_checkpoint(deepcopy(old_checkpoint), target_version="1.9.0")
    assert updated_checkpoint["callbacks"] == old_checkpoint["callbacks"]  # no migration was performed
def test_migrate_dropped_apex_amp_state(monkeypatch):
    """Test that the migration warns about apex AMP data in the checkpoint and removes the ``amp_scaling_state``
    entry."""
    monkeypatch.setattr(pl, "__version__", "2.0.0")  # pretend this version of Lightning is >= 2.0.0
    old_checkpoint = {"amp_scaling_state": {"scale": 1.23}}
    _set_version(old_checkpoint, "1.9.0")  # pretend a checkpoint prior to 2.0.0
    with pytest.warns(UserWarning, match="checkpoint contains apex AMP data"):
        updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy())
    assert "amp_scaling_state" not in updated_checkpoint
def test_migrate_loop_structure_after_tbptt_removal():
    """Test the loop state migration after truncated backpropagation support was removed in 2.0.0, and with it the
    training batch loop."""
    # automatic- and manual optimization state are combined into a single checkpoint to simplify testing
    batch_loop_state = MagicMock()
    automatic_state, automatic_progress = MagicMock(), MagicMock()
    manual_state, manual_progress = MagicMock(), MagicMock()

    legacy_fit_loop = {
        "epoch_loop.state_dict": {"any": "state"},
        "epoch_loop.batch_loop.state_dict": batch_loop_state,
        "epoch_loop.batch_loop.optimizer_loop.state_dict": automatic_state,
        "epoch_loop.batch_loop.optimizer_loop.optim_progress": automatic_progress,
        "epoch_loop.batch_loop.manual_loop.state_dict": manual_state,
        "epoch_loop.batch_loop.manual_loop.optim_step_progress": manual_progress,
    }
    old_checkpoint = {"loops": {"fit_loop": legacy_fit_loop}}
    _set_version(old_checkpoint, "1.8.0")  # pretend a checkpoint prior to 2.0.0

    migrated, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")

    # Expected: the batch loop level is gone from the keys, and the old batch loop state is stored inside the epoch
    # loop state under "old_batch_loop_state_dict"
    expected_fit_loop = {
        "epoch_loop.state_dict": {"any": "state", "old_batch_loop_state_dict": batch_loop_state},
        "epoch_loop.automatic_optimization.state_dict": automatic_state,
        "epoch_loop.automatic_optimization.optim_progress": automatic_progress,
        "epoch_loop.manual_optimization.state_dict": manual_state,
        "epoch_loop.manual_optimization.optim_step_progress": manual_progress,
    }
    assert migrated["loops"] == {"fit_loop": expected_fit_loop}
def test_migrate_loop_structure_after_optimizer_loop_removal():
    """Test the loop state migration after multiple optimizer support in automatic optimization was removed in
    2.0.0."""
    automatic_state = MagicMock()
    manual_state = MagicMock()
    manual_progress = MagicMock()
    # The automatic progress still carries the `optimizer_position` entry of the legacy format
    automatic_progress = {"optimizer": MagicMock(), "optimizer_position": 33}

    legacy_fit_loop = {
        "epoch_loop.state_dict": {"any": "state"},
        "epoch_loop.batch_loop.state_dict": MagicMock(),
        "epoch_loop.batch_loop.optimizer_loop.state_dict": automatic_state,
        "epoch_loop.batch_loop.optimizer_loop.optim_progress": automatic_progress,
        "epoch_loop.batch_loop.manual_loop.state_dict": manual_state,
        "epoch_loop.batch_loop.manual_loop.optim_step_progress": manual_progress,
    }
    old_checkpoint = {"loops": {"fit_loop": legacy_fit_loop}}
    _set_version(old_checkpoint, "1.9.0")  # pretend a checkpoint prior to 2.0.0

    migrated, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")

    expected_fit_loop = {
        "epoch_loop.state_dict": ANY,
        "epoch_loop.automatic_optimization.state_dict": automatic_state,
        # optimizer_position gets dropped
        "epoch_loop.automatic_optimization.optim_progress": {"optimizer": ANY},
        "epoch_loop.manual_optimization.state_dict": manual_state,
        "epoch_loop.manual_optimization.optim_step_progress": manual_progress,
    }
    assert migrated["loops"] == {"fit_loop": expected_fit_loop}
def test_migrate_loop_structure_after_dataloader_loop_removal():
    """Test the loop state migration after the dataloader loops were removed in 2.0.0.

    The expected result no longer contains ``dataloader_progress``, and the ``epoch_loop.``-prefixed entries appear
    without that prefix.

    """
    idle_counters = {"ready": 0, "started": 0, "processed": 0, "completed": 0}
    old_dataloader_loop_state_dict = {
        "state_dict": {},
        "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
        "epoch_loop.state_dict": {},
        "epoch_loop.batch_progress": {
            "total": {"ready": 123, "started": 0, "processed": 0, "completed": 0},
            "current": dict(idle_counters),
            "is_last_batch": False,
        },
    }
    old_checkpoint = {
        "loops": {
            "predict_loop": old_dataloader_loop_state_dict,
            "validate_loop": old_dataloader_loop_state_dict.copy(),  # shallow copy
            "test_loop": old_dataloader_loop_state_dict.copy(),  # shallow copy
        }
    }
    _set_version(old_checkpoint, "1.9.0")  # pretend a checkpoint prior to 2.0.0

    migrated, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")

    # All three loops started from the same legacy state, so they share the same expectation
    expected_loop_state = {
        "batch_progress": {
            "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
            "is_last_batch": False,
            "total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
        },
        "state_dict": {},
    }
    expected_loops = {name: expected_loop_state for name in ("predict_loop", "test_loop", "validate_loop")}
    assert migrated["loops"] == expected_loops