55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
|
# Copyright The PyTorch Lightning team.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import glob
|
||
|
import os
|
||
|
import sys
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from pytorch_lightning import Trainer
|
||
|
from tests import LEGACY_PATH
|
||
|
|
||
|
LEGACY_CHECKPOINTS_PATH = os.path.join(LEGACY_PATH, 'checkpoints')
|
||
|
CHECKPOINT_EXTENSION = ".ckpt"
|
||
|
|
||
|
|
||
|
# todo: add more legacy checkpoints :]
|
||
|
@pytest.mark.parametrize("pl_version", [
|
||
|
"0.10.0", "1.0.0", "1.0.1", "1.0.2", "1.0.3", "1.0.4", "1.0.5", "1.0.6", "1.0.7", "1.0.8"
|
||
|
])
|
||
|
def test_resume_legacy_checkpoints(tmpdir, pl_version):
|
||
|
path_dir = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
|
||
|
|
||
|
# todo: make this as mock, so it is cleaner...
|
||
|
orig_sys_paths = list(sys.path)
|
||
|
sys.path.insert(0, path_dir)
|
||
|
from zero_training import DummyModel
|
||
|
|
||
|
path_ckpts = sorted(glob.glob(os.path.join(path_dir, f'*{CHECKPOINT_EXTENSION}')))
|
||
|
assert path_ckpts, 'No checkpoints found in folder "%s"' % path_dir
|
||
|
path_ckpt = path_ckpts[-1]
|
||
|
|
||
|
model = DummyModel.load_from_checkpoint(path_ckpt)
|
||
|
trainer = Trainer(default_root_dir=tmpdir, max_epochs=6)
|
||
|
result = trainer.fit(model)
|
||
|
assert result
|
||
|
|
||
|
# todo
|
||
|
# model = DummyModel()
|
||
|
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, resume_from_checkpoint=path_ckpt)
|
||
|
# result = trainer.fit(model)
|
||
|
# assert result
|
||
|
|
||
|
sys.path = orig_sys_paths
|