verified tfx support

This commit is contained in:
William Falcon 2019-06-29 17:45:26 -04:00
parent 0a03042bf7
commit 6ffb6fb010
1 changed files with 8 additions and 3 deletions

View File

@ -373,7 +373,7 @@ class Trainer(TrainerIO):
metrics.update(grad_norm_dic) metrics.update(grad_norm_dic)
# log metrics # log metrics
scalar_metrics = self.__metrics_to_scalars(metrics) scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
self.experiment.log(scalar_metrics, global_step=self.global_step) self.experiment.log(scalar_metrics, global_step=self.global_step)
self.experiment.save() self.experiment.save()
@ -401,7 +401,7 @@ class Trainer(TrainerIO):
if stop: if stop:
return return
def __metrics_to_scalars(self, metrics): def __metrics_to_scalars(self, metrics, blacklist=[]):
new_metrics = {} new_metrics = {}
for k, v in metrics.items(): for k, v in metrics.items():
if type(v) is torch.Tensor: if type(v) is torch.Tensor:
@ -410,10 +410,15 @@ class Trainer(TrainerIO):
if type(v) is dict: if type(v) is dict:
v = self.__metrics_to_scalars(v) v = self.__metrics_to_scalars(v)
new_metrics[k] = float(v) if k not in blacklist:
new_metrics[k] = float(v)
return new_metrics return new_metrics
def __log_vals_blacklist(self):
"""avoid logging some vals lightning uses to maintain state"""
blacklist = {'batch_nb', 'v_nb', 'epoch', 'gpu'}
return blacklist
def __run_tng_batch(self, data_batch, batch_nb): def __run_tng_batch(self, data_batch, batch_nb):
if data_batch is None: if data_batch is None: