Add an `on_tng_metrics` hook so that users get access to the `grad_norm` and `mem_map` dicts.

This commit is contained in:
Cinjon Resnick 2019-07-16 12:51:48 -04:00
parent 28cfddbe65
commit fbd3873a0f
2 changed files with 9 additions and 5 deletions

View File

@ -542,9 +542,11 @@ class Trainer(TrainerIO):
if self.track_grad_norm > 0:
model = self.__get_model()
grad_norm_dic = model.grad_norm(self.track_grad_norm)
metrics.update(grad_norm_dic)
if self.__is_function_implemented('on_tng_metrics'):
model.on_tng_metrics(metrics)
# log metrics
scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
if self.proc_rank == 0:
@ -723,7 +725,6 @@ class Trainer(TrainerIO):
self.prog_bar.set_postfix(**tqdm_metrics)
# model checkpointing
if self.proc_rank == 0:
if self.checkpoint_callback:
print('save callback...')
self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)
if self.proc_rank == 0 and self.checkpoint_callback:
print('save callback...')
self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)

View File

@ -19,3 +19,6 @@ class ModelHooks(torch.nn.Module):
def on_post_performance_check(self):
    """Hook run after a performance check; default implementation is a no-op.

    Override in a subclass to react once the check has finished.
    """
def on_tng_metrics(self, metrics):
    """Hook run with the training metrics dict; default implementation is a no-op.

    Override in a subclass to inspect or mutate ``metrics`` (e.g. the
    grad-norm entries) before they are logged.
    """