From a0997bb7a6b5c6d39cfade6c8ae510f9a49c24b8 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 24 Aug 2020 13:46:46 -0400
Subject: [PATCH] ref: added hook base method (#3127)

* added hook base method

* added hook base method
---
 pytorch_lightning/core/lightning.py        |  6 ------
 pytorch_lightning/trainer/trainer.py       | 10 ++++++++++
 pytorch_lightning/trainer/training_loop.py | 20 +++++---------------
 3 files changed, 15 insertions(+), 21 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 895f65f409..a726cf31fe 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -272,12 +272,6 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         """
         rank_zero_warn('`training_step` must be implemented to be used with the Lightning Trainer')
 
-    def training_end(self, *args, **kwargs):
-        """
-        Warnings:
-            Deprecated in v0.7.0. Use :meth:`training_step_end` instead.
-        """
-
     def training_step_end(self, *args, **kwargs):
         """
         Use this when training with dp or ddp2 because :meth:`training_step`
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 3a47b8f374..62e03e52fa 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1456,6 +1456,16 @@ class Trainer(
         self.amp_backend = None
         self._setup_amp_backend(amp_type)
 
+    def call_hook(self, hook_name, *args, **kwargs):
+        output = None
+        if self.is_overridden(hook_name):
+            model_ref = self.get_model()
+            with self.profiler.profile(hook_name):
+                hook_fx = getattr(model_ref, hook_name)
+                output = hook_fx(*args, **kwargs)
+
+        return output
+
 
 class _PatchDataLoader(object):
     r"""
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 7399539213..5c238040a8 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -323,6 +323,10 @@ class TrainerTrainLoopMixin(ABC):
     def reset_val_dataloader(self, model):
         """Warning: this is just empty shell for code implemented in other class."""
 
+    @abstractmethod
+    def call_hook(self, hook_name, *args, **kwargs):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     @abstractmethod
     def has_arg(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -1202,24 +1206,10 @@ class TrainerTrainLoopMixin(ABC):
         # allow any mode to define training_step_end
         # do something will all the dp outputs (like softmax)
         if self.is_overridden('training_step_end'):
-            model_ref = self.get_model()
-            with self.profiler.profile('training_step_end'):
-                # TODO: modify when using result obj
-                output = model_ref.training_step_end(output)
-
+            output = self.call_hook('training_step_end', output)
         elif is_result_obj and (self.use_dp or self.use_ddp2):
             output.dp_reduce()
 
-        # allow any mode to define training_end
-        # TODO: remove in 1.0.0
-        if self.is_overridden('training_end'):
-            model_ref = self.get_model()
-            with self.profiler.profile('training_end'):
-                output = model_ref.training_end(output)
-
-            rank_zero_warn('`training_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
-                           ' Use training_epoch_end instead', DeprecationWarning)
-
         return output
 
     def update_learning_rates(self, interval: str, monitor_metrics=None):
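
Illustrative note (not part of the patch): the new Trainer.call_hook centralizes the
"is the hook overridden? then profile it and dispatch by name" logic that the training
loop previously repeated inline. The self-contained sketch below mirrors that pattern
under simplified assumptions; MiniTrainer, BaseModule, UserModule, and SimpleProfiler
are hypothetical stand-ins for the real Trainer, LightningModule, and profiler classes.

# Minimal sketch of the call_hook dispatch pattern (hypothetical class names).
from contextlib import contextmanager
import time


class SimpleProfiler:
    """Stand-in for the trainer's profiler: times a named block."""
    @contextmanager
    def profile(self, name):
        start = time.time()
        yield
        print(f"{name} took {time.time() - start:.6f}s")


class BaseModule:
    """Hook defaults live on the base class; users override only what they need."""
    def training_step_end(self, output):
        return output


class UserModule(BaseModule):
    def training_step_end(self, output):
        output["post_processed"] = True
        return output


class MiniTrainer:
    def __init__(self, model):
        self.model = model
        self.profiler = SimpleProfiler()

    def is_overridden(self, hook_name):
        # a hook counts as overridden when the user's class replaces the base default
        return getattr(type(self.model), hook_name) is not getattr(BaseModule, hook_name)

    def get_model(self):
        return self.model

    def call_hook(self, hook_name, *args, **kwargs):
        # same shape as the method added in this patch: check override,
        # profile the call, and dispatch to the model's hook by name
        output = None
        if self.is_overridden(hook_name):
            model_ref = self.get_model()
            with self.profiler.profile(hook_name):
                hook_fx = getattr(model_ref, hook_name)
                output = hook_fx(*args, **kwargs)
        return output


trainer = MiniTrainer(UserModule())
print(trainer.call_hook("training_step_end", {"loss": 0.5}))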