verified tfx support
This commit is contained in:
parent
0a03042bf7
commit
6ffb6fb010
|
@ -373,7 +373,7 @@ class Trainer(TrainerIO):
|
|||
metrics.update(grad_norm_dic)
|
||||
|
||||
# log metrics
|
||||
scalar_metrics = self.__metrics_to_scalars(metrics)
|
||||
scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
|
||||
self.experiment.log(scalar_metrics, global_step=self.global_step)
|
||||
self.experiment.save()
|
||||
|
||||
|
@ -401,7 +401,7 @@ class Trainer(TrainerIO):
|
|||
if stop:
|
||||
return
|
||||
|
||||
def __metrics_to_scalars(self, metrics):
|
||||
def __metrics_to_scalars(self, metrics, blacklist=[]):
|
||||
new_metrics = {}
|
||||
for k, v in metrics.items():
|
||||
if type(v) is torch.Tensor:
|
||||
|
@ -410,10 +410,15 @@ class Trainer(TrainerIO):
|
|||
if type(v) is dict:
|
||||
v = self.__metrics_to_scalars(v)
|
||||
|
||||
new_metrics[k] = float(v)
|
||||
if k not in blacklist:
|
||||
new_metrics[k] = float(v)
|
||||
|
||||
return new_metrics
|
||||
|
||||
def __log_vals_blacklist(self):
|
||||
"""avoid logging some vals lightning uses to maintain state"""
|
||||
blacklist = {'batch_nb', 'v_nb', 'epoch', 'gpu'}
|
||||
return blacklist
|
||||
|
||||
def __run_tng_batch(self, data_batch, batch_nb):
|
||||
if data_batch is None:
|
||||
|
|
Loading…
Reference in New Issue