Add a migration for the dataloader loops (#17125)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
4ce1b652c9
commit
0cd837f0da
|
@ -50,6 +50,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
|
|||
_drop_apex_amp_state,
|
||||
_migrate_loop_structure_after_tbptt_removal,
|
||||
_migrate_loop_structure_after_optimizer_loop_removal,
|
||||
_migrate_loop_structure_after_dataloader_loop_removal,
|
||||
],
|
||||
}
|
||||
|
||||
|
@ -236,7 +237,8 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE
|
|||
"""
|
||||
if "loops" not in checkpoint:
|
||||
return checkpoint
|
||||
|
||||
if "fit_loop" not in checkpoint["loops"]:
|
||||
return checkpoint
|
||||
fit_loop = checkpoint["loops"]["fit_loop"]
|
||||
|
||||
# remap `x.batch_loop.y` to `x.y`
|
||||
|
@ -273,8 +275,10 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT
|
|||
"""
|
||||
if "loops" not in checkpoint:
|
||||
return checkpoint
|
||||
|
||||
if "fit_loop" not in checkpoint["loops"]:
|
||||
return checkpoint
|
||||
fit_loop = checkpoint["loops"]["fit_loop"]
|
||||
|
||||
# optimizer_position is no longer used
|
||||
if "epoch_loop.optimizer_loop.optim_progress" in fit_loop:
|
||||
fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)
|
||||
|
@ -291,3 +295,25 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT
|
|||
"epoch_loop.manual_loop.optim_step_progress"
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
|
||||
def _migrate_loop_structure_after_dataloader_loop_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
|
||||
"""The dataloader loops (``_DataLoaderLoop``, ``_PredictionLoop`, and ``_EvaluationLoop``) were flattened into
|
||||
the ``_EvaluationEpochLoop`` (now ``_EvaluationLoop``) and ``_PredictionEpochLoop`` (now ``_PredictionLoop``).
|
||||
|
||||
Version: 2.0.0
|
||||
Commit: ec4f592ecfe238edd83185f6c6905fb1e2406d61
|
||||
PR: #16726
|
||||
"""
|
||||
if "loops" not in checkpoint:
|
||||
return checkpoint
|
||||
loops = checkpoint["loops"]
|
||||
for loop_key in ("predict_loop", "validate_loop", "test_loop"):
|
||||
if loop_key not in loops:
|
||||
continue
|
||||
loop = loops[loop_key]
|
||||
loop.pop("dataloader_progress", None) # no longer used
|
||||
epoch_loop_key = "epoch_loop."
|
||||
epoch_loop_dict = {k[len(epoch_loop_key) :]: loop.pop(k) for k in list(loop) if k.startswith(epoch_loop_key)}
|
||||
loop.update(epoch_loop_dict)
|
||||
return checkpoint
|
||||
|
|
|
@ -227,3 +227,52 @@ def test_migrate_loop_structure_after_optimizer_loop_removal():
|
|||
"epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_migrate_loop_structure_after_dataloader_loop_removal():
|
||||
"""Test the loop state migration after the dataloader loops were removed in 2.0.0."""
|
||||
old_dataloader_loop_state_dict = {
|
||||
"state_dict": {},
|
||||
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
|
||||
"epoch_loop.state_dict": {},
|
||||
"epoch_loop.batch_progress": {
|
||||
"total": {"ready": 123, "started": 0, "processed": 0, "completed": 0},
|
||||
"current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
|
||||
"is_last_batch": False,
|
||||
},
|
||||
}
|
||||
old_checkpoint = {
|
||||
"loops": {
|
||||
"predict_loop": old_dataloader_loop_state_dict,
|
||||
"validate_loop": dict(old_dataloader_loop_state_dict), # copy
|
||||
"test_loop": dict(old_dataloader_loop_state_dict), # copy
|
||||
}
|
||||
}
|
||||
_set_version(old_checkpoint, "1.9.0") # pretend a checkpoint prior to 2.0.0
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")
|
||||
assert updated_checkpoint["loops"] == {
|
||||
"predict_loop": {
|
||||
"batch_progress": {
|
||||
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
"is_last_batch": False,
|
||||
"total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
|
||||
},
|
||||
"state_dict": {},
|
||||
},
|
||||
"test_loop": {
|
||||
"batch_progress": {
|
||||
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
"is_last_batch": False,
|
||||
"total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
|
||||
},
|
||||
"state_dict": {},
|
||||
},
|
||||
"validate_loop": {
|
||||
"batch_progress": {
|
||||
"current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
|
||||
"is_last_batch": False,
|
||||
"total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
|
||||
},
|
||||
"state_dict": {},
|
||||
},
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue