quick patch __code__ (#1352)

* quick patch

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix

* testing fix
This commit is contained in:
William Falcon 2020-04-03 08:40:02 -04:00 committed by GitHub
parent 1576ad9963
commit 2eca8a9ef2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 4 deletions

View File

@ -20,8 +20,17 @@ class TrainerModelHooksMixin(ABC):
# in case of calling deprecated method
return False
# when code pointers are different, it was overriden
is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__
instance_attr = getattr(model, method_name)
super_attr = getattr(super_object, method_name)
# when code pointers are different, it was implemented
if hasattr(instance_attr, 'patch_loader_code'):
# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__)
else:
is_overriden = instance_attr.__code__ is not super_attr.__code__
return is_overriden
def has_arg(self, f_name, arg_name):

View File

@ -970,8 +970,10 @@ class _PatchDataLoader(object):
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader
# Assign __code__, needed for checking if method has been overriden
self.__code__ = self.__call__.__code__
# cannot pickle __code__ so cannot verify if PatchDataloader
# exists which shows dataloader methods have been overwritten.
# so, we hack it by using the string representation
self.patch_loader_code = str(self.__call__.__code__)
def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader