Fix unpickle redirection (#15382)
This commit is contained in:
parent
478ca8c3a0
commit
1c33d57b0a
|
@ -1,20 +1,46 @@
|
||||||
import pickle
|
import pickle
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
import torchmetrics
|
||||||
|
from lightning_utilities.core.imports import compare_version as _compare_version
|
||||||
|
|
||||||
|
|
||||||
class RedirectingUnpickler(pickle.Unpickler):
|
def _patch_pl_to_mirror_if_necessary(module: str) -> str:
|
||||||
def find_class(self, module: str, name: str) -> Any:
|
|
||||||
pl = "pytorch_" + "lightning" # avoids replacement during mirror package generation
|
pl = "pytorch_" + "lightning" # avoids replacement during mirror package generation
|
||||||
if module.startswith(pl):
|
if module.startswith(pl):
|
||||||
# for the standalone package this won't do anything,
|
# for the standalone package this won't do anything,
|
||||||
# for the unified mirror package it will redirect the imports
|
# for the unified mirror package it will redirect the imports
|
||||||
old_module = module
|
|
||||||
module = "pytorch_lightning" + module[len(pl) :]
|
module = "pytorch_lightning" + module[len(pl) :]
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
class RedirectingUnpickler(pickle._Unpickler):
|
||||||
|
def find_class(self, module: str, name: str) -> Any:
|
||||||
|
new_module = _patch_pl_to_mirror_if_necessary(module)
|
||||||
# this warning won't trigger for standalone as these imports are identical
|
# this warning won't trigger for standalone as these imports are identical
|
||||||
if module != old_module:
|
if module != new_module:
|
||||||
warnings.warn(f"Redirecting import of {old_module}.{name} to {module}.{name}")
|
warnings.warn(f"Redirecting import of {module}.{name} to {new_module}.{name}")
|
||||||
return super().find_class(module, name)
|
return super().find_class(new_module, name)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
|
||||||
|
new_package = _patch_pl_to_mirror_if_necessary(package)
|
||||||
|
return _compare_version(new_package, op, version, use_base_version)
|
||||||
|
|
||||||
|
|
||||||
|
# patching is necessary, since up to v.0.7.3 torchmetrics has a hardcoded reference to pytorch_lightning,
|
||||||
|
# which has to be redirected to the unified package:
|
||||||
|
# https://github.com/Lightning-AI/metrics/blob/v0.7.3/torchmetrics/metric.py#L96
|
||||||
|
try:
|
||||||
|
if hasattr(torchmetrics.utilities.imports, "_compare_version"):
|
||||||
|
torchmetrics.utilities.imports._compare_version = compare_version # type: ignore
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(torchmetrics.metric, "_compare_version"):
|
||||||
|
torchmetrics.metric._compare_version = compare_version # type: ignore
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
pickle.Unpickler = RedirectingUnpickler # type: ignore
|
pickle.Unpickler = RedirectingUnpickler # type: ignore
|
||||||
|
|
Loading…
Reference in New Issue