diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py
index 3cf42136d4..065add2352 100644
--- a/pytorch_lightning/trainer/model_hooks.py
+++ b/pytorch_lightning/trainer/model_hooks.py
@@ -6,8 +6,9 @@ from pytorch_lightning.core.lightning import LightningModule
 
 class TrainerModelHooksMixin(ABC):
 
-    def is_function_implemented(self, f_name):
-        model = self.get_model()
+    def is_function_implemented(self, f_name, model=None):
+        if model is None:
+            model = self.get_model()
         f_op = getattr(model, f_name, None)
         return callable(f_op)
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index cd77f1caa8..36664b5549 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -848,7 +848,7 @@ class Trainer(
 
         # callbacks
         self.on_fit_start()
-        if self.is_function_implemented('on_fit_start'):
+        if self.is_function_implemented('on_fit_start', model):
             model.on_fit_start()
 
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
@@ -860,7 +860,7 @@ class Trainer(
         self.barrier('fit_prepare_data')
 
         self.setup('fit')
-        if self.is_function_implemented('setup'):
+        if self.is_function_implemented('setup', model):
             model.setup('fit')
 
         # Run auto batch size scaling
@@ -1149,8 +1149,8 @@ class Trainer(
            trainer.test(model, test_dataloaders=test)
        """
        self.setup('test')
-        if self.is_function_implemented('setup'):
-            model_ref = self.model if model is None else model
+        model_ref = self.model if model is None else model
+        if self.is_function_implemented('setup', model_ref):
            model_ref.setup('test')
 
        self.barrier('test_setup')
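
The sketch below is a minimal, self-contained illustration (not the actual PyTorch Lightning classes) of why the hook check needs the optional `model` argument: when `test()` is called, the model passed in may not yet be attached as `self.model`, so `is_function_implemented` must be able to inspect the explicitly passed reference instead of falling back to `get_model()`. The `SketchModule`/`SketchTrainer` names are made up for the example.

```python
class SketchModule:
    """Stands in for a LightningModule that overrides ``setup``."""

    def setup(self, stage):
        print(f"setup called for stage={stage}")


class SketchTrainer:
    def __init__(self):
        # Nothing attached yet, mirroring the state at the start of ``test()``.
        self.model = None

    def get_model(self):
        return self.model

    def is_function_implemented(self, f_name, model=None):
        # Mirrors the patched mixin: fall back to the attached model only
        # when no explicit reference is supplied.
        if model is None:
            model = self.get_model()
        f_op = getattr(model, f_name, None)
        return callable(f_op)

    def test(self, model=None):
        # Resolve the model reference first, then check the hook on it,
        # matching the reordered lines in the trainer.py hunk at 1149.
        model_ref = self.model if model is None else model
        if self.is_function_implemented('setup', model_ref):
            model_ref.setup('test')


if __name__ == "__main__":
    SketchTrainer().test(SketchModule())  # prints: setup called for stage=test
```

With the pre-patch behavior, the same call would check `setup` on `self.model` (here `None`) and silently skip the hook; passing the resolved reference makes the check see the model actually being tested.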