# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import ANY, MagicMock

import lightning.pytorch as pl
import pytest
import torch
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
from lightning.pytorch.utilities.migration import migrate_checkpoint
from lightning.pytorch.utilities.migration.utils import _get_version, _set_legacy_version, _set_version
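
# Typical use of these utilities outside of the tests (a minimal sketch; the file name is a placeholder):
#
#     checkpoint = torch.load("legacy.ckpt", map_location="cpu")
#     migrated, applied_migrations = migrate_checkpoint(checkpoint)
#     torch.save(migrated, "legacy.ckpt")
#
# `migrate_checkpoint` updates the checkpoint dict in place and returns it together with the migrations that were
# applied; pass `target_version` to stop the upgrade at a specific Lightning version (as the tests below do).

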
@pytest.mark.parametrize(
    ("old_checkpoint", "new_checkpoint"),
    [
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}},
        ),
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}},
        ),
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}},
        ),
        (
            {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}},
            },
        ),
    ],
)
def test_migrate_model_checkpoint_early_stopping(old_checkpoint, new_checkpoint):
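    # pre-1.0.0 checkpoints stored callback state in flat top-level keys like "checkpoint_callback_best" and
    # "early_stop_callback_wait"; the migration nests that state under "callbacks", keyed by the callback class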
    _set_version(old_checkpoint, "0.9.0")
    _set_legacy_version(new_checkpoint, "0.9.0")
    _set_version(new_checkpoint, pl.__version__)
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.0.0")
    assert updated_checkpoint == old_checkpoint == new_checkpoint
    assert _get_version(updated_checkpoint) == pl.__version__


def test_migrate_loop_global_step_to_progress_tracking():
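    # 1.6.0 introduced progress tracking: the flat "global_step" counter is carried over into the nested step
    # counters under "loops", for both automatic and manual optimization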
    old_checkpoint = {"global_step": 15, "epoch": 2}
    _set_version(old_checkpoint, "1.5.9")  # pretend a checkpoint prior to 1.6.0
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.6.0")
    # automatic optimization
    assert (
        updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]["optimizer"][
            "step"
        ]["total"]["completed"]
        == 15
    )
    # manual optimization
    assert (
        updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]["total"][
            "completed"
        ]
        == 15
    )


def test_migrate_loop_current_epoch_to_progress_tracking():
    old_checkpoint = {"global_step": 15, "epoch": 2}
    _set_version(old_checkpoint, "1.5.9")  # pretend a checkpoint prior to 1.6.0
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
    assert updated_checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] == 2


@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel])
def test_migrate_loop_batches_that_stepped(tmpdir, model_class):
    trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir)
    model = model_class()
    trainer.fit(model)
    ckpt_path = trainer.checkpoint_callback.best_model_path

    # pretend we have a checkpoint produced in < v1.6.5; the key "_batches_that_stepped" didn't exist back then
    ckpt = torch.load(ckpt_path)
    del ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"]
    _set_version(ckpt, "1.6.4")
    torch.save(ckpt, ckpt_path)
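
    # resuming from this checkpoint triggers the migration, which restores "_batches_that_stepped" from "global_step"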
    class TestModel(model_class):
        def on_train_start(self) -> None:
            assert self.trainer.global_step == 1
            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1

    trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir)
    model = TestModel()
    trainer.fit(model, ckpt_path=ckpt_path)
    new_loop = trainer.fit_loop.epoch_loop
    assert new_loop.global_step == new_loop._batches_that_stepped == 2


@pytest.mark.parametrize("save_on_train_epoch_end", [None, True, False])
def test_migrate_model_checkpoint_save_on_train_epoch_end_default(save_on_train_epoch_end):
    """Test that the 'save_on_train_epoch_end' part of the ModelCheckpoint state key gets removed."""
    legacy_state_key = (
        f"ModelCheckpoint{{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        f" 'train_time_interval': None, 'save_on_train_epoch_end': {save_on_train_epoch_end}}}"
    )
    new_state_key = (
        "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None}"
    )
    old_checkpoint = {"callbacks": {legacy_state_key: {"dummy": 0}}, "global_step": 0, "epoch": 1}
    _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.9.0")
    assert updated_checkpoint["callbacks"] == {new_state_key: {"dummy": 0}}  # None -> None


def test_migrate_model_checkpoint_save_on_train_epoch_end_default_collision():
    """Test that the migration warns about collisions that would occur if the keys were modified."""
    # the two keys only differ in the `save_on_train_epoch_end` value
    state_key1 = (
        "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
    )
    state_key2 = (
        "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
        " 'train_time_interval': None, 'save_on_train_epoch_end': False}"
    )
    old_checkpoint = {
        "callbacks": {state_key1: {"dummy": 0}, state_key2: {"dummy": 0}},
        "global_step": 0,
        "epoch": 1,
    }
    _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0
    with pytest.warns(PossibleUserWarning, match="callback states in this checkpoint.* colliding with each other"):
        updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="1.9.0")
    assert updated_checkpoint["callbacks"] == old_checkpoint["callbacks"]  # no migration was performed


def test_migrate_dropped_apex_amp_state(monkeypatch):
    """Test that the migration warns about and drops the apex AMP state from checkpoints created before 2.0.0."""
    monkeypatch.setattr(pl, "__version__", "2.0.0")  # pretend this version of Lightning is >= 2.0.0
    old_checkpoint = {"amp_scaling_state": {"scale": 1.23}}
    _set_version(old_checkpoint, "1.9.0")  # pretend a checkpoint prior to 2.0.0
    with pytest.warns(UserWarning, match="checkpoint contains apex AMP data"):
        updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy())
    assert "amp_scaling_state" not in updated_checkpoint


def test_migrate_loop_structure_after_tbptt_removal():
    """Test the loop state migration after truncated backpropagation support was removed in 2.0.0, and with it the
    training batch loop."""
    # automatic and manual optimization state are combined into a single checkpoint to simplify testing
    state_automatic = MagicMock()
    state_manual = MagicMock()
    optim_progress_automatic = MagicMock()
    optim_progress_manual = MagicMock()
    old_batch_loop_state = MagicMock()
    old_checkpoint = {
        "loops": {
            "fit_loop": {
                "epoch_loop.state_dict": {"any": "state"},
                "epoch_loop.batch_loop.state_dict": old_batch_loop_state,
                "epoch_loop.batch_loop.optimizer_loop.state_dict": state_automatic,
                "epoch_loop.batch_loop.optimizer_loop.optim_progress": optim_progress_automatic,
                "epoch_loop.batch_loop.manual_loop.state_dict": state_manual,
                "epoch_loop.batch_loop.manual_loop.optim_step_progress": optim_progress_manual,
            }
        }
    }
    _set_version(old_checkpoint, "1.8.0")  # pretend a checkpoint prior to 2.0.0
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")
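    # the batch_loop level disappears: its state is folded into "epoch_loop.state_dict", and the optimizer/manual
    # sub-loops are renamed to "automatic_optimization" / "manual_optimization"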
    assert updated_checkpoint["loops"] == {
        "fit_loop": {
            "epoch_loop.state_dict": {"any": "state", "old_batch_loop_state_dict": old_batch_loop_state},
            "epoch_loop.automatic_optimization.state_dict": state_automatic,
            "epoch_loop.automatic_optimization.optim_progress": optim_progress_automatic,
            "epoch_loop.manual_optimization.state_dict": state_manual,
            "epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual,
        }
    }


def test_migrate_loop_structure_after_optimizer_loop_removal():
    """Test the loop state migration after support for multiple optimizers in automatic optimization was removed in
    2.0.0."""
    state_automatic = MagicMock()
    state_manual = MagicMock()
    optim_progress_automatic = {
        "optimizer": MagicMock(),
        "optimizer_position": 33,
    }
    optim_progress_manual = MagicMock()
    old_checkpoint = {
        "loops": {
            "fit_loop": {
                "epoch_loop.state_dict": {"any": "state"},
                "epoch_loop.batch_loop.state_dict": MagicMock(),
                "epoch_loop.batch_loop.optimizer_loop.state_dict": state_automatic,
                "epoch_loop.batch_loop.optimizer_loop.optim_progress": optim_progress_automatic,
                "epoch_loop.batch_loop.manual_loop.state_dict": state_manual,
                "epoch_loop.batch_loop.manual_loop.optim_step_progress": optim_progress_manual,
            }
        }
    }
    _set_version(old_checkpoint, "1.9.0")  # pretend a checkpoint prior to 2.0.0
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")
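    # with support for only a single optimizer, "optimizer_position" is meaningless and gets dropped; only the
    # "optimizer" entry of the progress dict survives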
    assert updated_checkpoint["loops"] == {
        "fit_loop": {
            "epoch_loop.state_dict": ANY,
            "epoch_loop.automatic_optimization.state_dict": state_automatic,
            "epoch_loop.automatic_optimization.optim_progress": {"optimizer": ANY},  # optimizer_position gets dropped
            "epoch_loop.manual_optimization.state_dict": state_manual,
            "epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual,
        }
    }


def test_migrate_loop_structure_after_dataloader_loop_removal():
    """Test the loop state migration after the dataloader loops were removed in 2.0.0."""
    old_dataloader_loop_state_dict = {
        "state_dict": {},
        "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
        "epoch_loop.state_dict": {},
        "epoch_loop.batch_progress": {
            "total": {"ready": 123, "started": 0, "processed": 0, "completed": 0},
            "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
            "is_last_batch": False,
        },
    }
    old_checkpoint = {
        "loops": {
            "predict_loop": old_dataloader_loop_state_dict,
            "validate_loop": dict(old_dataloader_loop_state_dict),  # copy
            "test_loop": dict(old_dataloader_loop_state_dict),  # copy
        }
    }
    _set_version(old_checkpoint, "1.9.0")  # pretend a checkpoint prior to 2.0.0
    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")
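    # the "dataloader_progress" counters and the "epoch_loop." nesting are dropped; each evaluation loop keeps only
    # its "state_dict" and a top-level "batch_progress"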
    assert updated_checkpoint["loops"] == {
        "predict_loop": {
            "batch_progress": {
                "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
                "is_last_batch": False,
                "total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
            },
            "state_dict": {},
        },
        "test_loop": {
            "batch_progress": {
                "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
                "is_last_batch": False,
                "total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
            },
            "state_dict": {},
        },
        "validate_loop": {
            "batch_progress": {
                "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
                "is_last_batch": False,
                "total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
            },
            "state_dict": {},
        },
    }