From fbd3873a0fad5627a4b4874b8a95bcfa52bc0cb1 Mon Sep 17 00:00:00 2001
From: Cinjon Resnick
Date: Tue, 16 Jul 2019 12:51:48 -0400
Subject: [PATCH] add a hook for on_tng_metrics so that users get access to
 the grad_norm and mem_map dicts.

---
 pytorch_lightning/models/trainer.py    | 11 ++++++-----
 pytorch_lightning/root_module/hooks.py |  3 +++
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 1203113286..5730c618f5 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -542,9 +542,11 @@ class Trainer(TrainerIO):
 
         if self.track_grad_norm > 0:
             model = self.__get_model()
             grad_norm_dic = model.grad_norm(self.track_grad_norm)
-            metrics.update(grad_norm_dic)
 
+        if self.__is_function_implemented('on_tng_metrics'):
+            model.on_tng_metrics(metrics)
+
         # log metrics
         scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
         if self.proc_rank == 0:
@@ -723,7 +725,6 @@ class Trainer(TrainerIO):
         self.prog_bar.set_postfix(**tqdm_metrics)
 
         # model checkpointing
-        if self.proc_rank == 0:
-            if self.checkpoint_callback:
-                print('save callback...')
-                self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)
+        if self.proc_rank == 0 and self.checkpoint_callback:
+            print('save callback...')
+            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)
diff --git a/pytorch_lightning/root_module/hooks.py b/pytorch_lightning/root_module/hooks.py
index 99155ab9a1..6d5c5dcd72 100644
--- a/pytorch_lightning/root_module/hooks.py
+++ b/pytorch_lightning/root_module/hooks.py
@@ -19,3 +19,6 @@ class ModelHooks(torch.nn.Module):
     def on_post_performance_check(self):
         pass
 
+    def on_tng_metrics(self, metrics):
+        pass
+
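
Usage sketch (not part of the patch): with this hook in place, a module can
override on_tng_metrics to inspect or extend the metrics dict right before the
Trainer converts it to scalars and logs it. The subclass below inherits
ModelHooks directly only to keep the example self-contained (a real model
would subclass the Lightning root module), and 'total_param_norm' is a
made-up example key, not something this commit adds:

    import torch
    from pytorch_lightning.root_module.hooks import ModelHooks

    class MyModel(ModelHooks):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 2)

        def on_tng_metrics(self, metrics):
            # `metrics` is the dict the Trainer is about to log; when enabled,
            # it already carries the mem_map and grad_norm entries. Mutate it
            # in place to log extra values alongside them.
            metrics['total_param_norm'] = sum(p.norm().item() for p in self.parameters())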