37 lines
1.4 KiB
Python
37 lines
1.4 KiB
Python
|
import pytest
|
||
|
import os
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||
|
from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize(
    ("old_checkpoint", "new_checkpoint"),
    [
        # Oldest ModelCheckpoint key: a bare "best" score.
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}},
            },
        ),
        # Flattened ModelCheckpoint best score.
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}},
            },
        ),
        # Flattened ModelCheckpoint best path.
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": 'path'},
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {ModelCheckpoint: {"best_model_path": 'path'}},
            },
        ),
        # Flattened EarlyStopping state ("wait" becomes "wait_count").
        (
            {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
            {
                "epoch": 1,
                "global_step": 23,
                "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}},
            },
        ),
    ],
)
def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
    """Check that ``upgrade_checkpoint`` rewrites a legacy checkpoint file.

    The legacy dictionary is written to ``model.ckpt`` inside ``tmpdir``.
    ``upgrade_checkpoint`` is then run on that same path, and the file is
    read back. The reloaded dictionary must equal ``new_checkpoint``
    exactly. In other words, the flat top-level callback keys must be
    replaced by per-callback-class entries under ``"callbacks"``.
    """
    ckpt_path = os.path.join(tmpdir, "model.ckpt")
    torch.save(old_checkpoint, ckpt_path)

    # The upgrade is expected to overwrite the file at ckpt_path.
    upgrade_checkpoint(ckpt_path)

    reloaded = torch.load(ckpt_path)
    assert reloaded == new_checkpoint
|