From 2eca8a9ef27c55e246818ba7fcf00e77a62a4f0c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Fri, 3 Apr 2020 08:40:02 -0400 Subject: [PATCH] quick patch __code__ (#1352) * quick patch * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix --- pytorch_lightning/trainer/model_hooks.py | 13 +++++++++++-- pytorch_lightning/trainer/trainer.py | 6 ++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 177aff1faa..d4871ff215 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -20,8 +20,17 @@ class TrainerModelHooksMixin(ABC): # in case of calling deprecated method return False - # when code pointers are different, it was overriden - is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__ + instance_attr = getattr(model, method_name) + super_attr = getattr(super_object, method_name) + + # when code pointers are different, it was implemented + if hasattr(instance_attr, 'patch_loader_code'): + # cannot pickle __code__ so cannot verify if PatchDataloader + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__) + else: + is_overriden = instance_attr.__code__ is not super_attr.__code__ return is_overriden def has_arg(self, f_name, arg_name): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4854864413..d40b1fadfd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -970,8 +970,10 @@ class _PatchDataLoader(object): def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): self.dataloader = dataloader - # Assign __code__, needed for checking if method has been overriden - self.__code__ = self.__call__.__code__ + # cannot pickle __code__ so cannot verify if PatchDataloader + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + self.patch_loader_code = str(self.__call__.__code__) def __call__(self) -> Union[List[DataLoader], DataLoader]: return self.dataloader