Add a hook for `on_tng_metrics` so that users get access to the `grad_norm` and `mem_map` dicts.
This commit is contained in:
parent
28cfddbe65
commit
fbd3873a0f
|
@@ -542,9 +542,11 @@ class Trainer(TrainerIO):
|
|||
if self.track_grad_norm > 0:
|
||||
model = self.__get_model()
|
||||
grad_norm_dic = model.grad_norm(self.track_grad_norm)
|
||||
|
||||
metrics.update(grad_norm_dic)
|
||||
|
||||
if self.__is_function_implemented('on_tng_metrics'):
|
||||
model.on_tng_metrics(metrics)
|
||||
|
||||
# log metrics
|
||||
scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
|
||||
if self.proc_rank == 0:
|
||||
|
@@ -723,7 +725,6 @@ class Trainer(TrainerIO):
|
|||
self.prog_bar.set_postfix(**tqdm_metrics)
|
||||
|
||||
# model checkpointing
|
||||
if self.proc_rank == 0:
|
||||
if self.checkpoint_callback:
|
||||
print('save callback...')
|
||||
self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)
|
||||
if self.proc_rank == 0 and self.checkpoint_callback:
|
||||
print('save callback...')
|
||||
self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)
|
||||
|
|
|
@@ -19,3 +19,6 @@ class ModelHooks(torch.nn.Module):
|
|||
def on_post_performance_check(self):
    """Hook called after the validation/performance check completes.

    No-op by default; subclasses may override to react after each
    performance check. Returns ``None``.
    """
|
||||
|
||||
def on_tng_metrics(self, metrics):
    """Hook called with the training metrics dict before logging.

    ``metrics`` may be mutated in place by subclasses (e.g. to add or
    rename entries such as grad-norm or memory-map values — assumed from
    the commit message; confirm against the Trainer call site). The base
    implementation does nothing and returns ``None``.
    """
|
||||
|
||||
|
|
Loading…
Reference in New Issue