Improvements to checkpoint migration (#16233)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Adrian Wälchli 2023-01-04 11:43:52 +01:00 committed by GitHub
parent 4a73fb8eee
commit 15536bf4dd
3 changed files with 69 additions and 60 deletions


@@ -83,8 +83,8 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
     PR: #13645, #11805
     """
     global_step = checkpoint["global_step"]
-    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
-    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()})
+    checkpoint["loops"].setdefault("fit_loop", _get_fit_loop_initial_state_1_6_0())
     # for automatic optimization
     optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
     optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
@@ -103,8 +103,8 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
     PR: #11805
     """
    epoch = checkpoint["epoch"]
-    checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
-    checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
+    checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()})
+    checkpoint["loops"].setdefault("fit_loop", _get_fit_loop_initial_state_1_6_0())
     checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch
     return checkpoint
@@ -121,48 +121,52 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
     return checkpoint
 
 
-_FIT_LOOP_INITIAL_STATE_1_6_0 = {
-    "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
-        "current": {"completed": 0, "ready": 0},
-        "total": {"completed": 0, "ready": 0},
-    },
-    "epoch_loop.batch_loop.manual_loop.state_dict": {},
-    "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
-        "optimizer": {
-            "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
-            "zero_grad": {
-                "current": {"completed": 0, "ready": 0, "started": 0},
-                "total": {"completed": 0, "ready": 0, "started": 0},
-            },
-        },
-        "optimizer_position": 0,
-    },
-    "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
-    "epoch_loop.batch_loop.state_dict": {},
-    "epoch_loop.batch_progress": {
-        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
-        "is_last_batch": False,
-        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
-    },
-    "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
-    "epoch_loop.state_dict": {"_batches_that_stepped": 0},
-    "epoch_loop.val_loop.dataloader_progress": {
-        "current": {"completed": 0, "ready": 0},
-        "total": {"completed": 0, "ready": 0},
-    },
-    "epoch_loop.val_loop.epoch_loop.batch_progress": {
-        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
-        "is_last_batch": False,
-        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
-    },
-    "epoch_loop.val_loop.epoch_loop.state_dict": {},
-    "epoch_loop.val_loop.state_dict": {},
-    "epoch_progress": {
-        "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
-        "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
-    },
-    "state_dict": {},
-}
+def _get_fit_loop_initial_state_1_6_0() -> Dict:
+    return {
+        "epoch_loop.batch_loop.manual_loop.optim_step_progress": {
+            "current": {"completed": 0, "ready": 0},
+            "total": {"completed": 0, "ready": 0},
+        },
+        "epoch_loop.batch_loop.manual_loop.state_dict": {},
+        "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
+            "optimizer": {
+                "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
+                "zero_grad": {
+                    "current": {"completed": 0, "ready": 0, "started": 0},
+                    "total": {"completed": 0, "ready": 0, "started": 0},
+                },
+            },
+            "optimizer_position": 0,
+        },
+        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
+        "epoch_loop.batch_loop.state_dict": {},
+        "epoch_loop.batch_progress": {
+            "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+            "is_last_batch": False,
+            "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        },
+        "epoch_loop.scheduler_progress": {
+            "current": {"completed": 0, "ready": 0},
+            "total": {"completed": 0, "ready": 0},
+        },
+        "epoch_loop.state_dict": {"_batches_that_stepped": 0},
+        "epoch_loop.val_loop.dataloader_progress": {
+            "current": {"completed": 0, "ready": 0},
+            "total": {"completed": 0, "ready": 0},
+        },
+        "epoch_loop.val_loop.epoch_loop.batch_progress": {
+            "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+            "is_last_batch": False,
+            "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        },
+        "epoch_loop.val_loop.epoch_loop.state_dict": {},
+        "epoch_loop.val_loop.state_dict": {},
+        "epoch_progress": {
+            "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+            "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
+        },
+        "state_dict": {},
+    }
 
 
 def _migrate_model_checkpoint_save_on_train_epoch_end_default(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
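Note on the refactor above: the module-level `_FIT_LOOP_INITIAL_STATE_1_6_0` dict was shared, mutable state. `setdefault` would install the very same object into every checkpoint passing through the migration, and the in-place updates that follow (writing `global_step` and `epoch` into the progress counters) would then leak from one migrated checkpoint into the next. Returning a fresh dict from `_get_fit_loop_initial_state_1_6_0()` removes that aliasing. A minimal sketch of the failure mode, using hypothetical names (`_DEFAULT_STATE`, `_migrate_shared`, `_migrate_fresh`) rather than the Lightning internals:

```python
# Hypothetical reproduction of the aliasing hazard (not the Lightning code itself).
_DEFAULT_STATE = {"optim_progress": {"step": {"completed": 0}}}  # shared module-level default


def _migrate_shared(checkpoint: dict) -> dict:
    # setdefault installs the *same* dict object into every checkpoint it sees
    checkpoint.setdefault("loops", {"fit_loop": _DEFAULT_STATE})
    checkpoint["loops"]["fit_loop"]["optim_progress"]["step"]["completed"] = checkpoint["global_step"]
    return checkpoint


def _get_default_state() -> dict:
    # factory in the style of _get_fit_loop_initial_state_1_6_0(): a brand-new dict on every call
    return {"optim_progress": {"step": {"completed": 0}}}


def _migrate_fresh(checkpoint: dict) -> dict:
    checkpoint.setdefault("loops", {"fit_loop": _get_default_state()})
    checkpoint["loops"]["fit_loop"]["optim_progress"]["step"]["completed"] = checkpoint["global_step"]
    return checkpoint


first, second = _migrate_shared({"global_step": 10}), _migrate_shared({"global_step": 99})
print(first["loops"]["fit_loop"]["optim_progress"]["step"]["completed"])  # 99 -- first result was overwritten

first, second = _migrate_fresh({"global_step": 10}), _migrate_fresh({"global_step": 99})
print(first["loops"]["fit_loop"]["optim_progress"]["step"]["completed"])  # 10 -- results stay independent
```

`setdefault` still evaluates its default argument eagerly, so the factory runs on every migration call even when the key already exists; building the small default dict each time is cheap compared with the aliasing it prevents.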


@@ -30,9 +30,16 @@ _log = logging.getLogger(__name__)
 _CHECKPOINT = Dict[str, Any]
 
 
-def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]:
+def migrate_checkpoint(
+    checkpoint: _CHECKPOINT, target_version: Optional[str] = None
+) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]:
     """Applies Lightning version migrations to a checkpoint dictionary.
 
+    Args:
+        checkpoint: A dictionary with the loaded state from the checkpoint file.
+        target_version: Run migrations only up to this version (inclusive), even if migration index contains
+            migration functions for newer versions than this target. Mainly useful for testing.
+
     Note:
         The migration happens in-place. We specifically avoid copying the dict to avoid memory spikes for large
         checkpoints and objects that do not support being deep-copied.
@@ -49,7 +56,7 @@ def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]:
     index = _migration_index()
     applied_migrations = {}
     for migration_version, migration_functions in index.items():
-        if not _should_upgrade(checkpoint, migration_version):
+        if not _should_upgrade(checkpoint, migration_version, target_version):
             continue
         for migration_function in migration_functions:
             checkpoint = migration_function(checkpoint)
@@ -139,6 +146,7 @@ def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None:
     checkpoint.setdefault("legacy_pytorch-lightning_version", version)
 
 
-def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool:
+def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: Optional[str] = None) -> bool:
     """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
-    return Version(_get_version(checkpoint)) < Version(target)
+    is_lte_max_version = max_version is None or Version(target) <= Version(max_version)
+    return Version(_get_version(checkpoint)) < Version(target) and is_lte_max_version
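The new `target_version` argument is threaded from `migrate_checkpoint` into `_should_upgrade` as `max_version`, so a migration now runs only if the checkpoint's recorded version is older than the migration's version and the migration's version does not exceed the optional cap. Below is a standalone sketch of that version gate using `packaging.version.Version`, as the diff above does; the `pytorch-lightning_version` key stands in for what the `_get_version` helper reads and is an assumption of this sketch:

```python
from typing import Any, Dict, Optional

from packaging.version import Version


def _should_upgrade_sketch(checkpoint: Dict[str, Any], target: str, max_version: Optional[str] = None) -> bool:
    # Assumption of this sketch: the checkpoint stores its writer's version under "pytorch-lightning_version".
    checkpoint_version = checkpoint.get("pytorch-lightning_version", "0.0.0")
    is_lte_max_version = max_version is None or Version(target) <= Version(max_version)
    return Version(checkpoint_version) < Version(target) and is_lte_max_version


ckpt = {"pytorch-lightning_version": "1.5.9"}
assert _should_upgrade_sketch(ckpt, target="1.6.0")                           # older checkpoint: migrate
assert not _should_upgrade_sketch(ckpt, target="1.9.0", max_version="1.6.0")  # newer than the cap: skipped
assert not _should_upgrade_sketch({"pytorch-lightning_version": "1.6.0"}, target="1.6.0")  # already current
```

In `migrate_checkpoint` itself the cap defaults to `None`, so regular checkpoint loading still applies every migration; the cap is mainly for tests that pin their expected output to a specific migration version, as the updated tests below do.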


@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import ANY
-
 import pytest
 import torch
@@ -30,15 +28,15 @@ from pytorch_lightning.utilities.migration.utils import _get_version, _set_legac
     [
         (
             {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
-            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}, "loops": ANY},
+            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}},
         ),
         (
             {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
-            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}, "loops": ANY},
+            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}},
         ),
         (
             {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
-            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}, "loops": ANY},
+            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}},
         ),
         (
             {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
@@ -46,16 +44,15 @@ from pytorch_lightning.utilities.migration.utils import _get_version, _set_legac
                 "epoch": 1,
                 "global_step": 23,
                 "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}},
-                "loops": ANY,
             },
         ),
     ],
 )
-def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_checkpoint):
+def test_migrate_model_checkpoint_early_stopping(old_checkpoint, new_checkpoint):
     _set_version(old_checkpoint, "0.9.0")
     _set_legacy_version(new_checkpoint, "0.9.0")
     _set_version(new_checkpoint, pl.__version__)
-    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
+    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.0.0")
     assert updated_checkpoint == old_checkpoint == new_checkpoint
     assert _get_version(updated_checkpoint) == pl.__version__
@@ -63,7 +60,7 @@ def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_checkpoint):
 def test_migrate_loop_global_step_to_progress_tracking():
     old_checkpoint = {"global_step": 15, "epoch": 2}
     _set_version(old_checkpoint, "1.5.9")  # pretend a checkpoint prior to 1.6.0
-    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
+    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.6.0")
     # automatic optimization
     assert (
         updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]["optimizer"][
@@ -125,7 +122,7 @@ def test_migrate_model_checkpoint_save_on_train_epoch_end_default(save_on_train_
     )
     old_checkpoint = {"callbacks": {legacy_state_key: {"dummy": 0}}, "global_step": 0, "epoch": 1}
     _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0
-    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
+    updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.9.0")
     assert updated_checkpoint["callbacks"] == {new_state_key: {"dummy": 0}}  # None -> None
@@ -147,5 +144,5 @@ def test_migrate_model_checkpoint_save_on_train_epoch_end_default_collision():
     }
     _set_version(old_checkpoint, "1.8.9")  # pretend a checkpoint prior to 1.9.0
     with pytest.warns(PossibleUserWarning, match="callback states in this checkpoint.* colliding with each other"):
-        updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy())
+        updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="1.9.0")
     assert updated_checkpoint["callbacks"] == old_checkpoint["callbacks"]  # no migration was performed