fix setup and on fit calls (#2252)
parent b7fc092bf4
commit b5a2f1ec44
@@ -6,7 +6,8 @@ from pytorch_lightning.core.lightning import LightningModule
 class TrainerModelHooksMixin(ABC):
 
-    def is_function_implemented(self, f_name):
-        model = self.get_model()
+    def is_function_implemented(self, f_name, model=None):
+        if model is None:
+            model = self.get_model()
         f_op = getattr(model, f_name, None)
         return callable(f_op)
 
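The helper above is a plain duck-typing check: a hook counts as implemented when the target object exposes a callable attribute of that name, and the new optional `model` argument lets callers check an explicitly passed module instead of whatever `self.get_model()` returns. A minimal standalone sketch of the same pattern (the function and classes below are illustrative, not the library's code):

def is_hook_implemented(obj, f_name):
    # "Implemented" means the object has a callable attribute with this name.
    return callable(getattr(obj, f_name, None))

class WithHook:
    def on_fit_start(self):
        print("on_fit_start called")

class WithoutHook:
    pass

print(is_hook_implemented(WithHook(), 'on_fit_start'))     # True
print(is_hook_implemented(WithoutHook(), 'on_fit_start'))  # False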
@@ -848,7 +848,7 @@ class Trainer(
 
         # callbacks
         self.on_fit_start()
-        if self.is_function_implemented('on_fit_start'):
+        if self.is_function_implemented('on_fit_start', model):
             model.on_fit_start()
 
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
@@ -860,7 +860,7 @@ class Trainer(
         self.barrier('fit_prepare_data')
 
         self.setup('fit')
-        if self.is_function_implemented('setup'):
+        if self.is_function_implemented('setup', model):
             model.setup('fit')
 
         # Run auto batch size scaling
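Both fit-path checks now pass the `model` argument given to `fit()` rather than relying on `self.get_model()`, so `setup('fit')` and `on_fit_start` overrides defined on that module are detected and called. A hedged sketch of a module defining the two hooks involved (training_step, configure_optimizers and dataloaders are omitted for brevity):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def setup(self, stage):
        # Receives 'fit' before training and 'test' before testing.
        print(f"setup for stage: {stage}")

    def on_fit_start(self):
        # Runs once at the start of fit().
        print("fit is starting")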
@@ -1149,8 +1149,8 @@ class Trainer(
             trainer.test(model, test_dataloaders=test)
         """
         self.setup('test')
-        if self.is_function_implemented('setup'):
-            model_ref = self.model if model is None else model
+        model_ref = self.model if model is None else model
+        if self.is_function_implemented('setup', model_ref):
             model_ref.setup('test')
 
         self.barrier('test_setup')
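On the test path the model reference is resolved first (`self.model` when no model is passed, otherwise the argument) and only then checked, so an explicitly passed module gets its `setup('test')` hook as well, mirroring the docstring example above. A brief usage sketch, assuming `MyModel` from the sketch above and a `test_loader` DataLoader:

model = MyModel()
trainer = pl.Trainer(max_epochs=1)
trainer.test(model, test_dataloaders=test_loader)  # runs model.setup('test') before testing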