fix setup and on fit calls (#2252)

This commit is contained in:
William Falcon 2020-06-18 21:45:09 -04:00 committed by GitHub
parent b7fc092bf4
commit b5a2f1ec44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 6 deletions

View File

@@ -6,8 +6,9 @@ from pytorch_lightning.core.lightning import LightningModule
class TrainerModelHooksMixin(ABC):
def is_function_implemented(self, f_name):
model = self.get_model()
def is_function_implemented(self, f_name, model=None):
if model is None:
model = self.get_model()
f_op = getattr(model, f_name, None)
return callable(f_op)

View File

@@ -848,7 +848,7 @@ class Trainer(
# callbacks
self.on_fit_start()
if self.is_function_implemented('on_fit_start'):
if self.is_function_implemented('on_fit_start', model):
model.on_fit_start()
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
@@ -860,7 +860,7 @@ class Trainer(
self.barrier('fit_prepare_data')
self.setup('fit')
if self.is_function_implemented('setup'):
if self.is_function_implemented('setup', model):
model.setup('fit')
# Run auto batch size scaling
@@ -1149,8 +1149,8 @@ class Trainer(
trainer.test(model, test_dataloaders=test)
"""
self.setup('test')
if self.is_function_implemented('setup'):
model_ref = self.model if model is None else model
model_ref = self.model if model is None else model
if self.is_function_implemented('setup', model_ref):
model_ref.setup('test')
self.barrier('test_setup')