diff --git a/src/pytorch_lightning/utilities/migration/migration.py b/src/pytorch_lightning/utilities/migration/migration.py index f937f86fcd..7bd42d7a5d 100644 --- a/src/pytorch_lightning/utilities/migration/migration.py +++ b/src/pytorch_lightning/utilities/migration/migration.py @@ -83,8 +83,8 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _ PR: #13645, #11805 """ global_step = checkpoint["global_step"] - checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) - checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) + checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()}) + checkpoint["loops"].setdefault("fit_loop", _get_fit_loop_initial_state_1_6_0()) # for automatic optimization optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"] optim_progress["optimizer"]["step"]["total"]["completed"] = global_step @@ -103,8 +103,8 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> PR: #11805 """ epoch = checkpoint["epoch"] - checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) - checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) + checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()}) + checkpoint["loops"].setdefault("fit_loop", _get_fit_loop_initial_state_1_6_0()) checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch return checkpoint @@ -121,48 +121,52 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT: return checkpoint -_FIT_LOOP_INITIAL_STATE_1_6_0 = { - "epoch_loop.batch_loop.manual_loop.optim_step_progress": { - "current": {"completed": 0, "ready": 0}, - "total": {"completed": 0, "ready": 0}, - }, - "epoch_loop.batch_loop.manual_loop.state_dict": {}, - "epoch_loop.batch_loop.optimizer_loop.optim_progress": { - "optimizer": { - "step": {"current": 
{"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, - "zero_grad": { - "current": {"completed": 0, "ready": 0, "started": 0}, - "total": {"completed": 0, "ready": 0, "started": 0}, - }, +def _get_fit_loop_initial_state_1_6_0() -> Dict: + return { + "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "current": {"completed": 0, "ready": 0}, + "total": {"completed": 0, "ready": 0}, }, - "optimizer_position": 0, - }, - "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, - "epoch_loop.batch_loop.state_dict": {}, - "epoch_loop.batch_progress": { - "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - "is_last_batch": False, - "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - }, - "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, - "epoch_loop.state_dict": {"_batches_that_stepped": 0}, - "epoch_loop.val_loop.dataloader_progress": { - "current": {"completed": 0, "ready": 0}, - "total": {"completed": 0, "ready": 0}, - }, - "epoch_loop.val_loop.epoch_loop.batch_progress": { - "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - "is_last_batch": False, - "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - }, - "epoch_loop.val_loop.epoch_loop.state_dict": {}, - "epoch_loop.val_loop.state_dict": {}, - "epoch_progress": { - "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - }, - "state_dict": {}, -} + "epoch_loop.batch_loop.manual_loop.state_dict": {}, + "epoch_loop.batch_loop.optimizer_loop.optim_progress": { + "optimizer": { + "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, + "zero_grad": { + "current": {"completed": 0, "ready": 0, "started": 0}, + "total": {"completed": 0, "ready": 0, "started": 0}, + }, + }, + "optimizer_position": 0, + }, + 
"epoch_loop.batch_loop.optimizer_loop.state_dict": {}, + "epoch_loop.batch_loop.state_dict": {}, + "epoch_loop.batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "epoch_loop.scheduler_progress": { + "current": {"completed": 0, "ready": 0}, + "total": {"completed": 0, "ready": 0}, + }, + "epoch_loop.state_dict": {"_batches_that_stepped": 0}, + "epoch_loop.val_loop.dataloader_progress": { + "current": {"completed": 0, "ready": 0}, + "total": {"completed": 0, "ready": 0}, + }, + "epoch_loop.val_loop.epoch_loop.batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "epoch_loop.val_loop.epoch_loop.state_dict": {}, + "epoch_loop.val_loop.state_dict": {}, + "epoch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "state_dict": {}, + } def _migrate_model_checkpoint_save_on_train_epoch_end_default(checkpoint: _CHECKPOINT) -> _CHECKPOINT: diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index f231445fc8..5c5fda7c90 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -30,9 +30,16 @@ _log = logging.getLogger(__name__) _CHECKPOINT = Dict[str, Any] -def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]: +def migrate_checkpoint( + checkpoint: _CHECKPOINT, target_version: Optional[str] = None +) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]: """Applies Lightning version migrations to a checkpoint dictionary. + Args: + checkpoint: A dictionary with the loaded state from the checkpoint file. 
+        target_version: Run migrations only up to this version (inclusive), even if the migration index contains +            migration functions for versions newer than this target. Mainly useful for testing. + Note: The migration happens in-place. We specifically avoid copying the dict to avoid memory spikes for large checkpoints and objects that do not support being deep-copied. @@ -49,7 +56,7 @@ def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, index = _migration_index() applied_migrations = {} for migration_version, migration_functions in index.items(): - if not _should_upgrade(checkpoint, migration_version): + if not _should_upgrade(checkpoint, migration_version, target_version): continue for migration_function in migration_functions: checkpoint = migration_function(checkpoint) @@ -139,6 +146,7 @@ def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None: checkpoint.setdefault("legacy_pytorch-lightning_version", version) -def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: +def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: Optional[str] = None) -> bool: """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" - return Version(_get_version(checkpoint)) < Version(target) + is_lte_max_version = max_version is None or Version(target) <= Version(max_version) + return Version(_get_version(checkpoint)) < Version(target) and is_lte_max_version diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index d73a6f0d83..5c127b4903 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
-from unittest.mock import ANY - import pytest import torch @@ -30,15 +28,15 @@ from pytorch_lightning.utilities.migration.utils import _get_version, _set_legac [ ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}, "loops": ANY}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}, "loops": ANY}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}, "loops": ANY}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}}, ), ( {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, @@ -46,16 +44,15 @@ from pytorch_lightning.utilities.migration.utils import _get_version, _set_legac "epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}, - "loops": ANY, }, ), ], ) -def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_checkpoint): +def test_migrate_model_checkpoint_early_stopping(old_checkpoint, new_checkpoint): _set_version(old_checkpoint, "0.9.0") _set_legacy_version(new_checkpoint, "0.9.0") _set_version(new_checkpoint, pl.__version__) - updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) + updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.0.0") assert updated_checkpoint == old_checkpoint == new_checkpoint assert _get_version(updated_checkpoint) == pl.__version__ @@ -63,7 +60,7 @@ def test_migrate_model_checkpoint_early_stopping(tmpdir, 
old_checkpoint, new_che def test_migrate_loop_global_step_to_progress_tracking(): old_checkpoint = {"global_step": 15, "epoch": 2} _set_version(old_checkpoint, "1.5.9") # pretend a checkpoint prior to 1.6.0 - updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) + updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.6.0") # automatic optimization assert ( updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]["optimizer"][ @@ -125,7 +122,7 @@ def test_migrate_model_checkpoint_save_on_train_epoch_end_default(save_on_train_ ) old_checkpoint = {"callbacks": {legacy_state_key: {"dummy": 0}}, "global_step": 0, "epoch": 1} _set_version(old_checkpoint, "1.8.9") # pretend a checkpoint prior to 1.9.0 - updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) + updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.9.0") assert updated_checkpoint["callbacks"] == {new_state_key: {"dummy": 0}} # None -> None @@ -147,5 +144,5 @@ def test_migrate_model_checkpoint_save_on_train_epoch_end_default_collision(): } _set_version(old_checkpoint, "1.8.9") # pretend a checkpoint prior to 1.9.0 with pytest.warns(PossibleUserWarning, match="callback states in this checkpoint.* colliding with each other"): - updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy()) + updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="1.9.0") assert updated_checkpoint["callbacks"] == old_checkpoint["callbacks"] # no migration was performed