ref: refactor eval loop to use hooks. use test_mode for if so we can split later (#3129)
* moved eval hooks * moved eval hooks * moved eval hooks * moved eval hooks
This commit is contained in:
parent
a0997bb7a6
commit
c556ee6119
|
@ -225,27 +225,9 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
def reset_val_dataloader(self, *args):
|
||||
"""Warning: this is just empty shell for code implemented in other class."""
|
||||
|
||||
def __call_eval_loop_hook_start(self, test_mode):
|
||||
"""on_validation/test_epoch_start"""
|
||||
self.__call_eval_loop_hook_evt(test_mode, 'start')
|
||||
|
||||
def __call_eval_loop_hook_end(self, test_mode):
|
||||
"""on_validation/test_epoch_end"""
|
||||
self.__call_eval_loop_hook_evt(test_mode, 'end')
|
||||
|
||||
def __call_eval_loop_hook_evt(self, test_mode, epoch_event):
|
||||
model = self.get_model()
|
||||
|
||||
# on_[train/validation]_epoch_start hook
|
||||
hook_root_name = 'test' if test_mode else 'validation'
|
||||
hook_name = f'on_{hook_root_name}_epoch_{epoch_event}'
|
||||
with self.profiler.profile(hook_name):
|
||||
# call hook
|
||||
getattr(self, hook_name)()
|
||||
|
||||
# model hooks
|
||||
if self.is_function_implemented(hook_name):
|
||||
getattr(model, hook_name)()
|
||||
@abstractmethod
|
||||
def call_hook(self, hook_name, *args, **kwargs):
|
||||
"""Warning: this is just empty shell for code implemented in other class."""
|
||||
|
||||
def _evaluate(
|
||||
self,
|
||||
|
@ -284,7 +266,10 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
# --------------------------
|
||||
# ON_EVAL_EPOCH_START hook
|
||||
# --------------------------
|
||||
self.__call_eval_loop_hook_start(test_mode)
|
||||
if test_mode:
|
||||
self.call_hook('on_test_epoch_start')
|
||||
else:
|
||||
self.call_hook('on_validation_epoch_start')
|
||||
|
||||
# run validation
|
||||
for dataloader_idx, dataloader in enumerate(dataloaders):
|
||||
|
@ -309,17 +294,10 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
|
||||
# callbacks
|
||||
if test_mode:
|
||||
self.on_test_batch_start(batch, batch_idx, dataloader_idx)
|
||||
if self.is_overridden('on_test_batch_start'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('on_test_batch_start'):
|
||||
model_ref.on_test_batch_start(batch, batch_idx, dataloader_idx)
|
||||
self.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
|
||||
else:
|
||||
self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
|
||||
if self.is_overridden('on_validation_batch_start'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('on_validation_batch_start'):
|
||||
model_ref.on_validation_batch_start(batch, batch_idx, dataloader_idx)
|
||||
self.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)
|
||||
|
||||
# -----------------
|
||||
# RUN EVALUATION STEP
|
||||
# -----------------
|
||||
|
@ -353,12 +331,11 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
# EVAL STEP END
|
||||
# ------------------
|
||||
# on dp / ddp2 might still want to do something with the batch parts
|
||||
eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end'
|
||||
if self.is_overridden(eval_step_end_hook_name):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile(eval_step_end_hook_name):
|
||||
eval_step_end = getattr(model_ref, eval_step_end_hook_name)
|
||||
output = eval_step_end(output)
|
||||
if self.is_overridden('test_step_end') or self.is_overridden('validation_step_end'):
|
||||
if test_mode:
|
||||
output = self.call_hook('test_step_end', output)
|
||||
else:
|
||||
output = self.call_hook('validation_step_end', output)
|
||||
|
||||
elif is_result_obj and (self.use_dp or self.use_ddp2):
|
||||
# result auto reduce
|
||||
|
@ -366,17 +343,9 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
|
||||
# callbacks (on __batch_end)
|
||||
if test_mode:
|
||||
self.on_test_batch_end(batch, batch_idx, dataloader_idx)
|
||||
if self.is_overridden('on_test_batch_end'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('on_test_batch_end'):
|
||||
model_ref.on_test_batch_end(batch, batch_idx, dataloader_idx)
|
||||
self.call_hook('on_test_batch_end', batch, batch_idx, dataloader_idx)
|
||||
else:
|
||||
self.on_validation_batch_end(batch, batch_idx, dataloader_idx)
|
||||
if self.is_overridden('on_validation_batch_end'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('on_validation_batch_end'):
|
||||
model_ref.on_validation_batch_end(batch, batch_idx, dataloader_idx)
|
||||
self.call_hook('on_validation_batch_end', batch, batch_idx, dataloader_idx)
|
||||
|
||||
# track outputs for collation
|
||||
if output is not None:
|
||||
|
@ -416,7 +385,10 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
# --------------------------
|
||||
# ON_EVAL_EPOCH_END hook
|
||||
# --------------------------
|
||||
self.__call_eval_loop_hook_end(test_mode)
|
||||
if test_mode:
|
||||
self.call_hook('on_test_epoch_end')
|
||||
else:
|
||||
self.call_hook('on_validation_epoch_end')
|
||||
|
||||
return eval_results
|
||||
|
||||
|
|
|
@ -1457,10 +1457,18 @@ class Trainer(
|
|||
self._setup_amp_backend(amp_type)
|
||||
|
||||
def call_hook(self, hook_name, *args, **kwargs):
|
||||
# always profile hooks
|
||||
with self.profiler.profile(hook_name):
|
||||
|
||||
# first call trainer hook
|
||||
if hasattr(self, hook_name):
|
||||
trainer_hook = getattr(self, hook_name)
|
||||
trainer_hook(*args, **kwargs)
|
||||
|
||||
# next call hook in lightningModule
|
||||
output = None
|
||||
if self.is_overridden(hook_name):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile(hook_name):
|
||||
hook_fx = getattr(model_ref, hook_name)
|
||||
output = hook_fx(*args, **kwargs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue