quick patch __code__ (#1352)
* quick patch * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix * testing fix
This commit is contained in:
parent
1576ad9963
commit
2eca8a9ef2
|
@ -20,8 +20,17 @@ class TrainerModelHooksMixin(ABC):
|
|||
# in case of calling deprecated method
|
||||
return False
|
||||
|
||||
# when code pointers are different, it was overriden
|
||||
is_overriden = getattr(model, method_name).__code__ is not getattr(super_object, method_name).__code__
|
||||
instance_attr = getattr(model, method_name)
|
||||
super_attr = getattr(super_object, method_name)
|
||||
|
||||
# when code pointers are different, it was implemented
|
||||
if hasattr(instance_attr, 'patch_loader_code'):
|
||||
# cannot pickle __code__ so cannot verify if PatchDataloader
|
||||
# exists which shows dataloader methods have been overwritten.
|
||||
# so, we hack it by using the string representation
|
||||
is_overriden = instance_attr.patch_loader_code != str(super_attr.__code__)
|
||||
else:
|
||||
is_overriden = instance_attr.__code__ is not super_attr.__code__
|
||||
return is_overriden
|
||||
|
||||
def has_arg(self, f_name, arg_name):
|
||||
|
|
|
@ -970,8 +970,10 @@ class _PatchDataLoader(object):
|
|||
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
|
||||
self.dataloader = dataloader
|
||||
|
||||
# Assign __code__, needed for checking if method has been overriden
|
||||
self.__code__ = self.__call__.__code__
|
||||
# cannot pickle __code__ so cannot verify if PatchDataloader
|
||||
# exists which shows dataloader methods have been overwritten.
|
||||
# so, we hack it by using the string representation
|
||||
self.patch_loader_code = str(self.__call__.__code__)
|
||||
|
||||
def __call__(self) -> Union[List[DataLoader], DataLoader]:
|
||||
return self.dataloader
|
||||
|
|
Loading…
Reference in New Issue