fix setup and on fit calls (#2252)
This commit is contained in:
parent
b7fc092bf4
commit
b5a2f1ec44
|
@ -6,8 +6,9 @@ from pytorch_lightning.core.lightning import LightningModule
|
||||||
|
|
||||||
class TrainerModelHooksMixin(ABC):
|
class TrainerModelHooksMixin(ABC):
|
||||||
|
|
||||||
def is_function_implemented(self, f_name):
|
def is_function_implemented(self, f_name, model=None):
|
||||||
model = self.get_model()
|
if model is None:
|
||||||
|
model = self.get_model()
|
||||||
f_op = getattr(model, f_name, None)
|
f_op = getattr(model, f_name, None)
|
||||||
return callable(f_op)
|
return callable(f_op)
|
||||||
|
|
||||||
|
|
|
@ -848,7 +848,7 @@ class Trainer(
|
||||||
|
|
||||||
# callbacks
|
# callbacks
|
||||||
self.on_fit_start()
|
self.on_fit_start()
|
||||||
if self.is_function_implemented('on_fit_start'):
|
if self.is_function_implemented('on_fit_start', model):
|
||||||
model.on_fit_start()
|
model.on_fit_start()
|
||||||
|
|
||||||
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
|
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
|
||||||
|
@ -860,7 +860,7 @@ class Trainer(
|
||||||
self.barrier('fit_prepare_data')
|
self.barrier('fit_prepare_data')
|
||||||
|
|
||||||
self.setup('fit')
|
self.setup('fit')
|
||||||
if self.is_function_implemented('setup'):
|
if self.is_function_implemented('setup', model):
|
||||||
model.setup('fit')
|
model.setup('fit')
|
||||||
|
|
||||||
# Run auto batch size scaling
|
# Run auto batch size scaling
|
||||||
|
@ -1149,8 +1149,8 @@ class Trainer(
|
||||||
trainer.test(model, test_dataloaders=test)
|
trainer.test(model, test_dataloaders=test)
|
||||||
"""
|
"""
|
||||||
self.setup('test')
|
self.setup('test')
|
||||||
if self.is_function_implemented('setup'):
|
model_ref = self.model if model is None else model
|
||||||
model_ref = self.model if model is None else model
|
if self.is_function_implemented('setup', model_ref):
|
||||||
model_ref.setup('test')
|
model_ref.setup('test')
|
||||||
|
|
||||||
self.barrier('test_setup')
|
self.barrier('test_setup')
|
||||||
|
|
Loading…
Reference in New Issue