49 lines
2.1 KiB
Python
49 lines
2.1 KiB
Python
|
# Copyright The PyTorch Lightning team.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import sys
|
||
|
from types import ModuleType
|
||
|
|
||
|
import pytorch_lightning.utilities.argparse
|
||
|
|
||
|
|
||
|
class pl_legacy_patch:
    """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for
    unpickling old checkpoints. The following patches apply.

        1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to
           version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
        2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4,
           but still needs to be available for import for legacy checkpoints.

    The patches are installed in ``__enter__`` and removed again in ``__exit__``. Removal tolerates the patched
    entries already being absent, so leaving the context never raises on its own account and never hides an
    exception coming out of the ``with`` body.

    Note:
        The patches are global (``sys.modules`` and a module attribute). When contexts are nested, the innermost
        exit already removes them, so the remainder of the outer block runs unpatched.

    Example:

        with pl_legacy_patch():
            torch.load("path/to/legacy/checkpoint.ckpt")
    """

    # Import path of the removed module that old checkpoints may still reference.
    _LEGACY_MODULE_NAME = "pytorch_lightning.utilities.argparse_utils"

    def __enter__(self) -> "pl_legacy_patch":
        """Install the legacy module and the legacy ``_gpus_arg_default`` attribute.

        Returns:
            This context manager instance.
        """
        # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse`
        legacy_argparse_module = ModuleType(self._LEGACY_MODULE_NAME)
        sys.modules[self._LEGACY_MODULE_NAME] = legacy_argparse_module

        # `_gpus_arg_default` used to be imported from these locations; an identity
        # function is registered under that name in both of them.
        legacy_argparse_module._gpus_arg_default = lambda x: x
        pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        """Remove the patches installed by ``__enter__``.

        Returns ``None``, so an exception raised inside the ``with`` body keeps propagating.
        """
        if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"):
            delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default")
        # `pop` with a default instead of `del`: the entry may already have been removed
        # (e.g. by a nested `pl_legacy_patch` that exited first). A `KeyError` raised here
        # would replace whatever exception is currently propagating out of the block.
        sys.modules.pop(self._LEGACY_MODULE_NAME, None)
|