Fix multithreading checkpoint loading (#17678)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
fd296e0605
commit
1307b605e8
|
@ -14,6 +14,7 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from types import ModuleType, TracebackType
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
|
@ -28,6 +29,7 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_warn
|
|||
|
||||
_log = logging.getLogger(__name__)
|
||||
_CHECKPOINT = Dict[str, Any]
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def migrate_checkpoint(
|
||||
|
@ -85,6 +87,7 @@ class pl_legacy_patch:
|
|||
"""
|
||||
|
||||
def __enter__(self) -> "pl_legacy_patch":
|
||||
_lock.acquire()
|
||||
# `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
|
||||
legacy_argparse_module = ModuleType("lightning.pytorch.utilities.argparse_utils")
|
||||
sys.modules["lightning.pytorch.utilities.argparse_utils"] = legacy_argparse_module
|
||||
|
@ -103,6 +106,7 @@ class pl_legacy_patch:
|
|||
if hasattr(pl.utilities.argparse, "_gpus_arg_default"):
|
||||
delattr(pl.utilities.argparse, "_gpus_arg_default")
|
||||
del sys.modules["lightning.pytorch.utilities.argparse_utils"]
|
||||
_lock.release()
|
||||
|
||||
|
||||
def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT:
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
import glob
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
@ -26,6 +25,7 @@ from tests_pytorch import _PATH_LEGACY
|
|||
from tests_pytorch.helpers.datamodules import ClassifDataModule
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
from tests_pytorch.helpers.simple_models import ClassificationModel
|
||||
from tests_pytorch.helpers.threading import ThreadExceptionHandler
|
||||
|
||||
LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints")
|
||||
CHECKPOINT_EXTENSION = ".ckpt"
|
||||
|
@ -68,18 +68,22 @@ class LimitNbEpochs(Callback):
|
|||
@pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
|
||||
@RunIf(sklearn=True)
|
||||
def test_legacy_ckpt_threading(tmpdir, pl_version: str):
|
||||
PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
|
||||
path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}")))
|
||||
assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"'
|
||||
path_ckpt = path_ckpts[-1]
|
||||
|
||||
def load_model():
|
||||
import torch
|
||||
|
||||
from lightning.pytorch.utilities.migration import pl_legacy_patch
|
||||
|
||||
with pl_legacy_patch():
|
||||
_ = torch.load(PATH_LEGACY)
|
||||
_ = torch.load(path_ckpt)
|
||||
|
||||
PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
|
||||
with patch("sys.path", [PATH_LEGACY] + sys.path):
|
||||
t1 = threading.Thread(target=load_model)
|
||||
t2 = threading.Thread(target=load_model)
|
||||
t1 = ThreadExceptionHandler(target=load_model)
|
||||
t2 = ThreadExceptionHandler(target=load_model)
|
||||
|
||||
t1.start()
|
||||
t2.start()
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
# Copyright The Lightning AI team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from threading import Thread
|
||||
|
||||
|
||||
class ThreadExceptionHandler(Thread):
    """A ``Thread`` subclass that propagates exceptions from the thread to the joiner.

    A plain ``threading.Thread`` swallows any exception raised in its target: the
    thread dies and the caller never finds out. This subclass records the exception
    raised inside ``run`` and re-raises it in the thread that calls ``join``, so a
    test using it fails loudly instead of passing vacuously.

    Adapted from https://stackoverflow.com/a/67022927.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Exception raised inside ``run`` (if any); re-raised by ``join``.
        self.exception = None

    def run(self):
        try:
            super().run()
        except Exception as e:
            # Capture rather than let the thread die silently. ``raise`` in
            # ``join`` preserves the original traceback via ``e.__traceback__``.
            self.exception = e

    def join(self, timeout=None):
        # Bug fix: the base ``Thread.join(timeout=None)`` accepts a timeout;
        # overriding without it would raise ``TypeError`` for callers that pass
        # one. Forward it to stay signature-compatible.
        # NOTE: if ``timeout`` expires before the thread finishes, the exception
        # may not have been recorded yet — same best-effort semantics as before.
        super().join(timeout)
        if self.exception:
            raise self.exception
|
Loading…
Reference in New Issue