ref: added hook base method (#3127)

* added hook base method

* added hook base method
This commit is contained in:
William Falcon 2020-08-24 13:46:46 -04:00 committed by GitHub
parent 20018b2668
commit a0997bb7a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 21 deletions

View File

@ -272,12 +272,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
"""
rank_zero_warn('`training_step` must be implemented to be used with the Lightning Trainer')
def training_end(self, *args, **kwargs):
    """
    Deprecated hook kept only for backward compatibility; intentionally a no-op here.

    Warnings:
        Deprecated in v0.7.0. Use :meth:`training_step_end` instead.
    """
def training_step_end(self, *args, **kwargs):
"""
Use this when training with dp or ddp2 because :meth:`training_step`

View File

@ -1456,6 +1456,16 @@ class Trainer(
self.amp_backend = None
self._setup_amp_backend(amp_type)
def call_hook(self, hook_name, *args, **kwargs):
    """Invoke ``hook_name`` on the user's model if the user overrode it.

    Args:
        hook_name: name of the hook method looked up on the model reference.
        *args: positional arguments forwarded to the hook.
        **kwargs: keyword arguments forwarded to the hook.

    Returns:
        Whatever the hook returns, or ``None`` when the hook is not overridden.
    """
    if not self.is_overridden(hook_name):
        return None
    model = self.get_model()
    # Time the hook invocation under the profiler, keyed by the hook's name.
    with self.profiler.profile(hook_name):
        return getattr(model, hook_name)(*args, **kwargs)
class _PatchDataLoader(object):
r"""

View File

@ -323,6 +323,10 @@ class TrainerTrainLoopMixin(ABC):
def reset_val_dataloader(self, model):
    """Warning: this is just an empty shell; the real implementation lives in another class of the mixin hierarchy."""
@abstractmethod
def call_hook(self, hook_name, *args, **kwargs):
    """Warning: this is just an empty shell; the real implementation lives in another class of the mixin hierarchy."""
@abstractmethod
def has_arg(self, *args):
    """Warning: this is just an empty shell; the real implementation lives in another class of the mixin hierarchy."""
@ -1202,24 +1206,10 @@ class TrainerTrainLoopMixin(ABC):
# allow any mode to define training_step_end
# do something with all the dp outputs (like softmax)
if self.is_overridden('training_step_end'):
model_ref = self.get_model()
with self.profiler.profile('training_step_end'):
# TODO: modify when using result obj
output = model_ref.training_step_end(output)
output = self.call_hook('training_step_end', output)
elif is_result_obj and (self.use_dp or self.use_ddp2):
output.dp_reduce()
# allow any mode to define training_end
# TODO: remove in 1.0.0
if self.is_overridden('training_end'):
model_ref = self.get_model()
with self.profiler.profile('training_end'):
output = model_ref.training_end(output)
rank_zero_warn('`training_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use training_epoch_end instead', DeprecationWarning)
return output
def update_learning_rates(self, interval: str, monitor_metrics=None):