Improvements to checkpoint migration (#16233)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
This commit is contained in:
parent
4a73fb8eee
commit
15536bf4dd
|
@ -83,8 +83,8 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _
|
|||
PR: #13645, #11805
|
||||
"""
|
||||
global_step = checkpoint["global_step"]
|
||||
checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
|
||||
checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
|
||||
checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()})
|
||||
checkpoint["loops"].setdefault("fit_loop", _get_fit_loop_initial_state_1_6_0())
|
||||
# for automatic optimization
|
||||
optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]
|
||||
optim_progress["optimizer"]["step"]["total"]["completed"] = global_step
|
||||
|
@ -103,8 +103,8 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) ->
|
|||
PR: #11805
|
||||
"""
|
||||
epoch = checkpoint["epoch"]
|
||||
checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0})
|
||||
checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0)
|
||||
checkpoint.setdefault("loops", {"fit_loop": _get_fit_loop_initial_state_1_6_0()})
|
||||
checkpoint["loops"].setdefault("fit_loop", _get_fit_loop_initial_state_1_6_0())
|
||||
checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch
|
||||
return checkpoint
|
||||
|
||||
|
@ -121,48 +121,52 @@ def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
|
|||
return checkpoint
|
||||
|
||||
|
||||
_FIT_LOOP_INITIAL_STATE_1_6_0 = {
|
||||
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
|
||||
"current": {"completed": 0, "ready": 0},
|
||||
"total": {"completed": 0, "ready": 0},
|
||||
},
|
||||
"epoch_loop.batch_loop.manual_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
|
||||
"optimizer": {
|
||||
"step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
|
||||
"zero_grad": {
|
||||
"current": {"completed": 0, "ready": 0, "started": 0},
|
||||
"total": {"completed": 0, "ready": 0, "started": 0},
|
||||
},
|
||||
def _get_fit_loop_initial_state_1_6_0() -> Dict:
|
||||
return {
|
||||
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
|
||||
"current": {"completed": 0, "ready": 0},
|
||||
"total": {"completed": 0, "ready": 0},
|
||||
},
|
||||
"optimizer_position": 0,
|
||||
},
|
||||
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.state_dict": {},
|
||||
"epoch_loop.batch_progress": {
|
||||
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
"is_last_batch": False,
|
||||
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
},
|
||||
"epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
|
||||
"epoch_loop.state_dict": {"_batches_that_stepped": 0},
|
||||
"epoch_loop.val_loop.dataloader_progress": {
|
||||
"current": {"completed": 0, "ready": 0},
|
||||
"total": {"completed": 0, "ready": 0},
|
||||
},
|
||||
"epoch_loop.val_loop.epoch_loop.batch_progress": {
|
||||
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
"is_last_batch": False,
|
||||
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
},
|
||||
"epoch_loop.val_loop.epoch_loop.state_dict": {},
|
||||
"epoch_loop.val_loop.state_dict": {},
|
||||
"epoch_progress": {
|
||||
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
},
|
||||
"state_dict": {},
|
||||
}
|
||||
"epoch_loop.batch_loop.manual_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
|
||||
"optimizer": {
|
||||
"step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}},
|
||||
"zero_grad": {
|
||||
"current": {"completed": 0, "ready": 0, "started": 0},
|
||||
"total": {"completed": 0, "ready": 0, "started": 0},
|
||||
},
|
||||
},
|
||||
"optimizer_position": 0,
|
||||
},
|
||||
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.state_dict": {},
|
||||
"epoch_loop.batch_progress": {
|
||||
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
"is_last_batch": False,
|
||||
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
},
|
||||
"epoch_loop.scheduler_progress": {
|
||||
"current": {"completed": 0, "ready": 0},
|
||||
"total": {"completed": 0, "ready": 0},
|
||||
},
|
||||
"epoch_loop.state_dict": {"_batches_that_stepped": 0},
|
||||
"epoch_loop.val_loop.dataloader_progress": {
|
||||
"current": {"completed": 0, "ready": 0},
|
||||
"total": {"completed": 0, "ready": 0},
|
||||
},
|
||||
"epoch_loop.val_loop.epoch_loop.batch_progress": {
|
||||
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
"is_last_batch": False,
|
||||
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
},
|
||||
"epoch_loop.val_loop.epoch_loop.state_dict": {},
|
||||
"epoch_loop.val_loop.state_dict": {},
|
||||
"epoch_progress": {
|
||||
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
"total": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
},
|
||||
"state_dict": {},
|
||||
}
|
||||
|
||||
|
||||
def _migrate_model_checkpoint_save_on_train_epoch_end_default(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
|
||||
|
|
|
@ -30,9 +30,16 @@ _log = logging.getLogger(__name__)
|
|||
_CHECKPOINT = Dict[str, Any]
|
||||
|
||||
|
||||
def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]:
|
||||
def migrate_checkpoint(
|
||||
checkpoint: _CHECKPOINT, target_version: Optional[str] = None
|
||||
) -> Tuple[_CHECKPOINT, Dict[str, List[str]]]:
|
||||
"""Applies Lightning version migrations to a checkpoint dictionary.
|
||||
|
||||
Args:
|
||||
checkpoint: A dictionary with the loaded state from the checkpoint file.
|
||||
target_version: Run migrations only up to this version (inclusive), even if migration index contains
|
||||
migration functions for newer versions than this target. Mainly useful for testing.
|
||||
|
||||
Note:
|
||||
The migration happens in-place. We specifically avoid copying the dict to avoid memory spikes for large
|
||||
checkpoints and objects that do not support being deep-copied.
|
||||
|
@ -49,7 +56,7 @@ def migrate_checkpoint(checkpoint: _CHECKPOINT) -> Tuple[_CHECKPOINT, Dict[str,
|
|||
index = _migration_index()
|
||||
applied_migrations = {}
|
||||
for migration_version, migration_functions in index.items():
|
||||
if not _should_upgrade(checkpoint, migration_version):
|
||||
if not _should_upgrade(checkpoint, migration_version, target_version):
|
||||
continue
|
||||
for migration_function in migration_functions:
|
||||
checkpoint = migration_function(checkpoint)
|
||||
|
@ -139,6 +146,7 @@ def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None:
|
|||
checkpoint.setdefault("legacy_pytorch-lightning_version", version)
|
||||
|
||||
|
||||
def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool:
|
||||
def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: Optional[str] = None) -> bool:
|
||||
"""Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
|
||||
return Version(_get_version(checkpoint)) < Version(target)
|
||||
is_lte_max_version = max_version is None or Version(target) <= Version(max_version)
|
||||
return Version(_get_version(checkpoint)) < Version(target) and is_lte_max_version
|
||||
|
|
|
@ -11,8 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
@ -30,15 +28,15 @@ from pytorch_lightning.utilities.migration.utils import _get_version, _set_legac
|
|||
[
|
||||
(
|
||||
{"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
|
||||
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}, "loops": ANY},
|
||||
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}},
|
||||
),
|
||||
(
|
||||
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
|
||||
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}, "loops": ANY},
|
||||
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}},
|
||||
),
|
||||
(
|
||||
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
|
||||
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}, "loops": ANY},
|
||||
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}},
|
||||
),
|
||||
(
|
||||
{"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
|
||||
|
@ -46,16 +44,15 @@ from pytorch_lightning.utilities.migration.utils import _get_version, _set_legac
|
|||
"epoch": 1,
|
||||
"global_step": 23,
|
||||
"callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}},
|
||||
"loops": ANY,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_checkpoint):
|
||||
def test_migrate_model_checkpoint_early_stopping(old_checkpoint, new_checkpoint):
|
||||
_set_version(old_checkpoint, "0.9.0")
|
||||
_set_legacy_version(new_checkpoint, "0.9.0")
|
||||
_set_version(new_checkpoint, pl.__version__)
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.0.0")
|
||||
assert updated_checkpoint == old_checkpoint == new_checkpoint
|
||||
assert _get_version(updated_checkpoint) == pl.__version__
|
||||
|
||||
|
@ -63,7 +60,7 @@ def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_che
|
|||
def test_migrate_loop_global_step_to_progress_tracking():
|
||||
old_checkpoint = {"global_step": 15, "epoch": 2}
|
||||
_set_version(old_checkpoint, "1.5.9") # pretend a checkpoint prior to 1.6.0
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.6.0")
|
||||
# automatic optimization
|
||||
assert (
|
||||
updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]["optimizer"][
|
||||
|
@ -125,7 +122,7 @@ def test_migrate_model_checkpoint_save_on_train_epoch_end_default(save_on_train_
|
|||
)
|
||||
old_checkpoint = {"callbacks": {legacy_state_key: {"dummy": 0}}, "global_step": 0, "epoch": 1}
|
||||
_set_version(old_checkpoint, "1.8.9") # pretend a checkpoint prior to 1.9.0
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint)
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint, target_version="1.9.0")
|
||||
assert updated_checkpoint["callbacks"] == {new_state_key: {"dummy": 0}} # None -> None
|
||||
|
||||
|
||||
|
@ -147,5 +144,5 @@ def test_migrate_model_checkpoint_save_on_train_epoch_end_default_collision():
|
|||
}
|
||||
_set_version(old_checkpoint, "1.8.9") # pretend a checkpoint prior to 1.9.0
|
||||
with pytest.warns(PossibleUserWarning, match="callback states in this checkpoint.* colliding with each other"):
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy())
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="1.9.0")
|
||||
assert updated_checkpoint["callbacks"] == old_checkpoint["callbacks"] # no migration was performed
|
||||
|
|
Loading…
Reference in New Issue