diff --git a/src/lightning/pytorch/utilities/migration/utils.py b/src/lightning/pytorch/utilities/migration/utils.py index f24ad3d035..56f018a9fe 100644 --- a/src/lightning/pytorch/utilities/migration/utils.py +++ b/src/lightning/pytorch/utilities/migration/utils.py @@ -14,6 +14,7 @@ import logging import os import sys +import threading from types import ModuleType, TracebackType from typing import Any, Dict, List, Optional, Tuple, Type @@ -28,6 +29,7 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_warn _log = logging.getLogger(__name__) _CHECKPOINT = Dict[str, Any] +_lock = threading.Lock() def migrate_checkpoint( @@ -85,6 +87,7 @@ class pl_legacy_patch: """ def __enter__(self) -> "pl_legacy_patch": + _lock.acquire() # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("lightning.pytorch.utilities.argparse_utils") sys.modules["lightning.pytorch.utilities.argparse_utils"] = legacy_argparse_module @@ -103,6 +106,7 @@ class pl_legacy_patch: if hasattr(pl.utilities.argparse, "_gpus_arg_default"): delattr(pl.utilities.argparse, "_gpus_arg_default") del sys.modules["lightning.pytorch.utilities.argparse_utils"] + _lock.release() def _pl_migrate_checkpoint(checkpoint: _CHECKPOINT, checkpoint_path: Optional[_PATH] = None) -> _CHECKPOINT: diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index a171e92e8b..3e69659461 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -14,7 +14,6 @@ import glob import os import sys -import threading from unittest.mock import patch import pytest @@ -26,6 +25,7 @@ from tests_pytorch import _PATH_LEGACY from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel +from tests_pytorch.helpers.threading import ThreadExceptionHandler LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints") CHECKPOINT_EXTENSION = ".ckpt" @@ -68,18 +68,22 @@ class LimitNbEpochs(Callback): @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @RunIf(sklearn=True) def test_legacy_ckpt_threading(tmpdir, pl_version: str): + PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) + path_ckpts = sorted(glob.glob(os.path.join(PATH_LEGACY, f"*{CHECKPOINT_EXTENSION}"))) + assert path_ckpts, f'No checkpoints found in folder "{PATH_LEGACY}"' + path_ckpt = path_ckpts[-1] + def load_model(): import torch from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): - _ = torch.load(PATH_LEGACY) + _ = torch.load(path_ckpt) - PATH_LEGACY = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version) with patch("sys.path", [PATH_LEGACY] + sys.path): - t1 = threading.Thread(target=load_model) - t2 = threading.Thread(target=load_model) + t1 = ThreadExceptionHandler(target=load_model) + t2 = ThreadExceptionHandler(target=load_model) t1.start() t2.start() diff --git a/tests/tests_pytorch/helpers/threading.py b/tests/tests_pytorch/helpers/threading.py new file mode 100644 index 0000000000..6447bec303 --- /dev/null +++ b/tests/tests_pytorch/helpers/threading.py @@ -0,0 +1,33 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from threading import Thread + + +class ThreadExceptionHandler(Thread): + """Adopted from https://stackoverflow.com/a/67022927.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.exception = None + + def run(self): + try: + super().run() + except Exception as e: + self.exception = e + + def join(self): + super().join() + if self.exception: + raise self.exception