quick patch __code__ (#1352)
* quick patch * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix
This commit is contained in:
parent
1576ad9963
commit
2eca8a9ef2
|
@ -20,8 +20,17 @@ class TrainerModelHooksMixin(ABC):
|
||||||
# in case of calling deprecated method
|
# in case of calling deprecated method
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# when code pointers are different, it was overriden
|
instance_attr = getattr(model, method_name)
|
||||||
is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__
|
super_attr = getattr(super_object, method_name)
|
||||||
|
|
||||||
|
# when code pointers are different, it was implemented
|
||||||
|
if hasattr(instance_attr, 'patch_loader_code'):
|
||||||
|
# cannot pickle __code__ so cannot verify if PatchDataloader
|
||||||
|
# exists which shows dataloader methods have been overwritten.
|
||||||
|
# so, we hack it by using the string representation
|
||||||
|
is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__)
|
||||||
|
else:
|
||||||
|
is_overriden = instance_attr.__code__ is not super_attr.__code__
|
||||||
return is_overriden
|
return is_overriden
|
||||||
|
|
||||||
def has_arg(self, f_name, arg_name):
|
def has_arg(self, f_name, arg_name):
|
||||||
|
|
|
@ -970,8 +970,10 @@ class _PatchDataLoader(object):
|
||||||
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
|
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
|
||||||
self.dataloader = dataloader
|
self.dataloader = dataloader
|
||||||
|
|
||||||
# Assign __code__, needed for checking if method has been overriden
|
# cannot pickle __code__ so cannot verify if PatchDataloader
|
||||||
self.__code__ = self.__call__.__code__
|
# exists which shows dataloader methods have been overwritten.
|
||||||
|
# so, we hack it by using the string representation
|
||||||
|
self.patch_loader_code = str(self.__call__.__code__)
|
||||||
|
|
||||||
def __call__(self) -> Union[List[DataLoader], DataLoader]:
|
def __call__(self) -> Union[List[DataLoader], DataLoader]:
|
||||||
return self.dataloader
|
return self.dataloader
|
||||||
|
|
Loading…
Reference in New Issue