ref: added hook base method (#3127)
* added hook base method * added hook base method
This commit is contained in:
parent
20018b2668
commit
a0997bb7a6
|
@ -272,12 +272,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
|
|||
"""
|
||||
rank_zero_warn('`training_step` must be implemented to be used with the Lightning Trainer')
|
||||
|
||||
def training_end(self, *args, **kwargs):
|
||||
"""
|
||||
Warnings:
|
||||
Deprecated in v0.7.0. Use :meth:`training_step_end` instead.
|
||||
"""
|
||||
|
||||
def training_step_end(self, *args, **kwargs):
|
||||
"""
|
||||
Use this when training with dp or ddp2 because :meth:`training_step`
|
||||
|
|
|
@ -1456,6 +1456,16 @@ class Trainer(
|
|||
self.amp_backend = None
|
||||
self._setup_amp_backend(amp_type)
|
||||
|
||||
def call_hook(self, hook_name, *args, **kwargs):
|
||||
output = None
|
||||
if self.is_overridden(hook_name):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile(hook_name):
|
||||
hook_fx = getattr(model_ref, hook_name)
|
||||
output = hook_fx(*args, **kwargs)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class _PatchDataLoader(object):
|
||||
r"""
|
||||
|
|
|
@ -323,6 +323,10 @@ class TrainerTrainLoopMixin(ABC):
|
|||
def reset_val_dataloader(self, model):
|
||||
"""Warning: this is just empty shell for code implemented in other class."""
|
||||
|
||||
@abstractmethod
|
||||
def call_hook(self, hook_name, *args, **kwargs):
|
||||
"""Warning: this is just empty shell for code implemented in other class."""
|
||||
|
||||
@abstractmethod
|
||||
def has_arg(self, *args):
|
||||
"""Warning: this is just empty shell for code implemented in other class."""
|
||||
|
@ -1202,24 +1206,10 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# allow any mode to define training_step_end
|
||||
# do something will all the dp outputs (like softmax)
|
||||
if self.is_overridden('training_step_end'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('training_step_end'):
|
||||
# TODO: modify when using result obj
|
||||
output = model_ref.training_step_end(output)
|
||||
|
||||
output = self.call_hook('training_step_end', output)
|
||||
elif is_result_obj and (self.use_dp or self.use_ddp2):
|
||||
output.dp_reduce()
|
||||
|
||||
# allow any mode to define training_end
|
||||
# TODO: remove in 1.0.0
|
||||
if self.is_overridden('training_end'):
|
||||
model_ref = self.get_model()
|
||||
with self.profiler.profile('training_end'):
|
||||
output = model_ref.training_end(output)
|
||||
|
||||
rank_zero_warn('`training_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
|
||||
' Use training_epoch_end instead', DeprecationWarning)
|
||||
|
||||
return output
|
||||
|
||||
def update_learning_rates(self, interval: str, monitor_metrics=None):
|
||||
|
|
Loading…
Reference in New Issue