From 6ffb6fb01034a65f0e88b0fa2a2b591e0d14f296 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 29 Jun 2019 17:45:26 -0400 Subject: [PATCH] verified tfx support --- pytorch_lightning/models/trainer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 796533e963..cf2667a04d 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -373,7 +373,7 @@ class Trainer(TrainerIO): metrics.update(grad_norm_dic) # log metrics - scalar_metrics = self.__metrics_to_scalars(metrics) + scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist()) self.experiment.log(scalar_metrics, global_step=self.global_step) self.experiment.save() @@ -401,7 +401,7 @@ class Trainer(TrainerIO): if stop: return - def __metrics_to_scalars(self, metrics): + def __metrics_to_scalars(self, metrics, blacklist=[]): new_metrics = {} for k, v in metrics.items(): if type(v) is torch.Tensor: @@ -410,10 +410,15 @@ class Trainer(TrainerIO): if type(v) is dict: v = self.__metrics_to_scalars(v) - new_metrics[k] = float(v) + if k not in blacklist: + new_metrics[k] = float(v) return new_metrics + def __log_vals_blacklist(self): + """avoid logging some vals lightning uses to maintain state""" + blacklist = {'batch_nb', 'v_nb', 'epoch', 'gpu'} + return blacklist def __run_tng_batch(self, data_batch, batch_nb): if data_batch is None: