# lightning/pytorch_lightning/utilities/migration.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from types import ModuleType
import pytorch_lightning.utilities.argparse
class pl_legacy_patch:
    """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
    unpickling old checkpoints. The following patches apply.

        1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
           version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
        2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
           but still needs to be available for import for legacy checkpoints.

    Whatever was present before entering (an existing ``sys.modules`` entry for the legacy module name, or an
    existing ``_gpus_arg_default`` attribute) is restored on exit instead of being deleted, so the patch can be
    nested without the outer exit failing.

    Example:

        with pl_legacy_patch():
            torch.load("path/to/legacy/checkpoint.ckpt")
    """

    # Name under which the stand-in legacy module is registered in ``sys.modules``.
    _LEGACY_MODULE_NAME = "pytorch_lightning.utilities.argparse_utils"
    # Sentinel meaning "nothing was there before ``__enter__``" (``None`` could be a legitimate value).
    _MISSING = object()

    def __init__(self) -> None:
        # One saved ``(previous_module, previous_gpus_arg_default)`` pair per active ``__enter__``,
        # so the same instance can also be re-entered safely.
        self._saved_states = []

    def __enter__(self) -> "pl_legacy_patch":
        """Install the legacy module and ``_gpus_arg_default`` stub, remembering what they replaced."""
        argparse_module = pytorch_lightning.utilities.argparse

        # Snapshot prior state so ``__exit__`` restores it rather than blindly deleting it.
        self._saved_states.append(
            (
                sys.modules.get(self._LEGACY_MODULE_NAME, self._MISSING),
                getattr(argparse_module, "_gpus_arg_default", self._MISSING),
            )
        )

        # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
        legacy_argparse_module = ModuleType(self._LEGACY_MODULE_NAME)
        sys.modules[self._LEGACY_MODULE_NAME] = legacy_argparse_module

        # `_gpus_arg_default` used to be imported from these locations; an identity stub is enough
        # for unpickling to resolve the reference.
        legacy_argparse_module._gpus_arg_default = lambda x: x
        argparse_module._gpus_arg_default = lambda x: x
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        """Undo the patches applied by the matching ``__enter__``; exceptions are not suppressed."""
        prev_legacy_module, prev_gpus_arg_default = self._saved_states.pop()
        argparse_module = pytorch_lightning.utilities.argparse

        if prev_gpus_arg_default is self._MISSING:
            if hasattr(argparse_module, "_gpus_arg_default"):
                delattr(argparse_module, "_gpus_arg_default")
        else:
            argparse_module._gpus_arg_default = prev_gpus_arg_default

        if prev_legacy_module is self._MISSING:
            # ``pop`` with a default: tolerate the entry having been removed by someone else meanwhile.
            sys.modules.pop(self._LEGACY_MODULE_NAME, None)
        else:
            sys.modules[self._LEGACY_MODULE_NAME] = prev_legacy_module