Fix unpickle redirection (#15382)

This commit is contained in:
Justus Schock 2022-10-28 14:35:35 +02:00 committed by GitHub
parent 478ca8c3a0
commit 1c33d57b0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 38 additions and 12 deletions

View File

@ -1,20 +1,46 @@
import pickle import pickle
import warnings import warnings
from typing import Any from typing import Any, Callable
import torchmetrics
from lightning_utilities.core.imports import compare_version as _compare_version
class RedirectingUnpickler(pickle.Unpickler): def _patch_pl_to_mirror_if_necessary(module: str) -> str:
def find_class(self, module: str, name: str) -> Any:
pl = "pytorch_" + "lightning" # avoids replacement during mirror package generation pl = "pytorch_" + "lightning" # avoids replacement during mirror package generation
if module.startswith(pl): if module.startswith(pl):
# for the standalone package this won't do anything, # for the standalone package this won't do anything,
# for the unified mirror package it will redirect the imports # for the unified mirror package it will redirect the imports
old_module = module
module = "pytorch_lightning" + module[len(pl) :] module = "pytorch_lightning" + module[len(pl) :]
return module
class RedirectingUnpickler(pickle._Unpickler):
def find_class(self, module: str, name: str) -> Any:
new_module = _patch_pl_to_mirror_if_necessary(module)
# this warning won't trigger for standalone as these imports are identical # this warning won't trigger for standalone as these imports are identical
if module != old_module: if module != new_module:
warnings.warn(f"Redirecting import of {old_module}.{name} to {module}.{name}") warnings.warn(f"Redirecting import of {module}.{name} to {new_module}.{name}")
return super().find_class(module, name) return super().find_class(new_module, name)
def compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
new_package = _patch_pl_to_mirror_if_necessary(package)
return _compare_version(new_package, op, version, use_base_version)
# patching is necessary, since up to v.0.7.3 torchmetrics has a hardcoded reference to pytorch_lightning,
# which has to be redirected to the unified package:
# https://github.com/Lightning-AI/metrics/blob/v0.7.3/torchmetrics/metric.py#L96
try:
if hasattr(torchmetrics.utilities.imports, "_compare_version"):
torchmetrics.utilities.imports._compare_version = compare_version # type: ignore
except AttributeError:
pass
try:
if hasattr(torchmetrics.metric, "_compare_version"):
torchmetrics.metric._compare_version = compare_version # type: ignore
except AttributeError:
pass
pickle.Unpickler = RedirectingUnpickler # type: ignore pickle.Unpickler = RedirectingUnpickler # type: ignore