37 lines
1.4 KiB
Python
37 lines
1.4 KiB
Python
import pytest
|
|
import os
|
|
|
|
import torch
|
|
|
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
|
from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint
|
|
|
|
|
|
@pytest.mark.parametrize(
    "old_checkpoint, new_checkpoint",
    [
        # Legacy top-level "checkpoint_callback_best" becomes the
        # ModelCheckpoint entry's "best_model_score".
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}},
            },
        ),
        # Legacy "checkpoint_callback_best_model_score" maps to the same
        # "best_model_score" field.
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}},
            },
        ),
        # Legacy "checkpoint_callback_best_model_path" becomes
        # "best_model_path" under ModelCheckpoint.
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"},
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {ModelCheckpoint: {"best_model_path": "path"}},
            },
        ),
        # Both legacy early-stopping keys are gathered under EarlyStopping;
        # "early_stop_callback_wait" is expected to be renamed "wait_count".
        (
            {
                "epoch": 1,
                "global_step": 23,
                "early_stop_callback_wait": 2,
                "early_stop_callback_patience": 4,
            },
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}},
            },
        ),
    ],
)
def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
    """Check that ``upgrade_checkpoint`` migrates a legacy checkpoint file.

    The legacy dict is saved to ``model.ckpt`` inside ``tmpdir``, the upgrade
    utility is run on that path, and the file is loaded again. The test
    therefore expects the utility to rewrite the file in place so that its
    contents equal ``new_checkpoint``.
    """
    ckpt_file = os.path.join(tmpdir, "model.ckpt")
    torch.save(old_checkpoint, ckpt_file)

    upgrade_checkpoint(ckpt_file)

    # Reload from the same path: the upgraded content must be what is on disk.
    assert torch.load(ckpt_file) == new_checkpoint
|