144 lines
6.0 KiB
Python
144 lines
6.0 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import sys
|
|
|
|
import pytest
|
|
|
|
import pytorch_lightning as pl
|
|
from lightning_lite.utilities.warnings import PossibleUserWarning
|
|
from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch
|
|
from pytorch_lightning.utilities.migration.utils import _pl_migrate_checkpoint
|
|
|
|
|
|
def test_patch_legacy_argparse_utils():
    """Check that the legacy ``argparse_utils`` module only exists while ``pl_legacy_patch`` is active.

    Inside the context the module must be importable, expose a callable ``_gpus_arg_default`` and be registered
    in ``sys.modules``; once the context exits, the ``sys.modules`` entry must be gone again.
    """
    legacy_module_name = "pytorch_lightning.utilities.argparse_utils"

    with pl_legacy_patch():
        from pytorch_lightning.utilities import argparse_utils as patched_module

        assert callable(patched_module._gpus_arg_default)
        assert legacy_module_name in sys.modules

    # leaving the context must remove the module from the import cache
    assert legacy_module_name not in sys.modules
|
|
|
|
|
|
def test_patch_legacy_gpus_arg_default():
    """Check that ``_gpus_arg_default`` is only attached to ``pl.utilities.argparse`` while the patch is active.

    Inside ``pl_legacy_patch`` the name must be importable and callable; after the context exits, the attribute
    must no longer be present on the module.
    """
    with pl_legacy_patch():
        from pytorch_lightning.utilities.argparse import _gpus_arg_default

        assert callable(_gpus_arg_default)

    # the patched attribute must be cleaned up on exit (the original repeated this identical assertion twice)
    assert not hasattr(pl.utilities.argparse, "_gpus_arg_default")
|
|
|
|
|
|
def test_migrate_checkpoint(monkeypatch):
    """Check which dummy migrations run, and in what order, depending on the version stored in the checkpoint.

    Every scenario also asserts that the dict passed in compares equal to the returned dict and to the expected
    content after migration.
    """
    scenarios = [
        # A checkpoint that is older than any migration point in the index
        (
            "0.0.0",
            ["one", "two", "three", "four"],
            {"legacy_pytorch-lightning_version": "0.0.0", "pytorch-lightning_version": pl.__version__, "content": 123},
        ),
        # A checkpoint that is newer, but not the newest
        (
            "1.0.3",
            ["four"],
            {"legacy_pytorch-lightning_version": "1.0.3", "pytorch-lightning_version": pl.__version__, "content": 123},
        ),
        # A checkpoint newer than any migration point in the index
        (
            pl.__version__,
            [],
            {"pytorch-lightning_version": pl.__version__, "content": 123},
        ),
    ]

    for stored_version, expected_calls, expected_content in scenarios:
        checkpoint = {"pytorch-lightning_version": stored_version, "content": 123}
        migrated, calls = _run_simple_migration(monkeypatch, checkpoint)
        assert calls == expected_calls
        assert migrated == checkpoint == expected_content
|
|
|
|
|
|
def _run_simple_migration(monkeypatch, old_checkpoint):
    """Migrate ``old_checkpoint`` against a fake migration index and record which upgrades were invoked.

    The real ``_migration_index`` in ``pl.utilities.migration.utils`` is monkeypatched with a small index of
    dummy upgrade functions keyed by version ("0.0.1", "0.0.2", "1.2.3"). Each dummy appends its tag to a list
    and returns the checkpoint it was given.

    Returns:
        A tuple of the checkpoint returned by ``migrate_checkpoint`` and the list of tags in call order.
    """
    recorded_tags = []

    def _make_recorder(label):
        # build an upgrade function that only logs its own label
        def _record(checkpoint):
            recorded_tags.append(label)
            return checkpoint

        return _record

    fake_index = {
        "0.0.1": [_make_recorder("one")],
        "0.0.2": [_make_recorder("two"), _make_recorder("three")],
        "1.2.3": [_make_recorder("four")],
    }
    monkeypatch.setattr(pl.utilities.migration.utils, "_migration_index", lambda: fake_index)

    migrated, _ = migrate_checkpoint(old_checkpoint)
    return migrated, recorded_tags
|
|
|
|
|
|
def test_migrate_checkpoint_too_new():
    """Test checkpoint migration is a no-op with a warning when attempting to migrate a checkpoint from newer
    version of Lightning than installed."""
    import re  # function-scope import so this test does not depend on a file-level change

    super_new_checkpoint = {"pytorch-lightning_version": "99.0.0", "content": 123}

    # `pytest.warns(match=...)` interprets the pattern as a regular expression. Escape the message so that the
    # dots are matched literally and a version string containing regex metacharacters (e.g. a "+local" suffix)
    # cannot break the match.
    expected_message = re.escape(
        f"v99.0.0, which is newer than your current Lightning version: v{pl.__version__}"
    )
    with pytest.warns(PossibleUserWarning, match=expected_message):
        # migrate a copy so the original dict stays available for the comparison below
        new_checkpoint, migrations = migrate_checkpoint(super_new_checkpoint.copy())

    # no version modification
    assert not migrations
    assert new_checkpoint == super_new_checkpoint
|
|
|
|
|
|
def test_migrate_checkpoint_for_pl(caplog):
    """Check ``_pl_migrate_checkpoint`` for an up-to-date and an outdated checkpoint.

    An up-to-date checkpoint must come back with the same content. An outdated one must come back upgraded, and
    an info message about the automatic upgrade must be present in the captured log text.
    """
    checkpoint_path = "path/to/ckpt"

    # simulate a very recent checkpoint, no migrations needed
    recent_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123}
    migrated = _pl_migrate_checkpoint(recent_checkpoint, checkpoint_path)
    assert migrated == {"pytorch-lightning_version": pl.__version__, "content": 123}

    # simulate an old checkpoint that needed an upgrade
    outdated_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}
    with caplog.at_level(logging.INFO, logger="pytorch_lightning.utilities.migration.utils"):
        migrated = _pl_migrate_checkpoint(outdated_checkpoint, checkpoint_path)

    expected_upgraded = {
        "legacy_pytorch-lightning_version": "0.0.1",
        "pytorch-lightning_version": pl.__version__,
        "callbacks": {},
        "content": 123,
    }
    assert migrated == expected_upgraded

    expected_log = f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}"
    assert expected_log in caplog.text
|
|
|
|
|
|
def test_migrate_checkpoint_legacy_version(monkeypatch):
    """Check that the legacy version key is set on first migration and kept on a repeated migration.

    The installed version is faked via ``monkeypatch`` as "2.0.0" and then "3.0.0"; the output of the first
    migration is fed into the second one.
    """
    checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123}

    # pretend the current pl version is 2.0, then an even newer one for a second migration pass
    for pretended_version in ("2.0.0", "3.0.0"):
        monkeypatch.setattr(pl, "__version__", pretended_version)
        checkpoint, _ = migrate_checkpoint(checkpoint)
        assert checkpoint["pytorch-lightning_version"] == pretended_version
        # the legacy version remains the one the checkpoint originally had
        assert checkpoint["legacy_pytorch-lightning_version"] == "0.0.1"
|