verified tfx support
This commit is contained in:
parent
0a03042bf7
commit
6ffb6fb010
|
@ -373,7 +373,7 @@ class Trainer(TrainerIO):
|
||||||
metrics.update(grad_norm_dic)
|
metrics.update(grad_norm_dic)
|
||||||
|
|
||||||
# log metrics
|
# log metrics
|
||||||
scalar_metrics = self.__metrics_to_scalars(metrics)
|
scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
|
||||||
self.experiment.log(scalar_metrics, global_step=self.global_step)
|
self.experiment.log(scalar_metrics, global_step=self.global_step)
|
||||||
self.experiment.save()
|
self.experiment.save()
|
||||||
|
|
||||||
|
@ -401,7 +401,7 @@ class Trainer(TrainerIO):
|
||||||
if stop:
|
if stop:
|
||||||
return
|
return
|
||||||
|
|
||||||
def __metrics_to_scalars(self, metrics):
|
def __metrics_to_scalars(self, metrics, blacklist=[]):
|
||||||
new_metrics = {}
|
new_metrics = {}
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
if type(v) is torch.Tensor:
|
if type(v) is torch.Tensor:
|
||||||
|
@ -410,10 +410,15 @@ class Trainer(TrainerIO):
|
||||||
if type(v) is dict:
|
if type(v) is dict:
|
||||||
v = self.__metrics_to_scalars(v)
|
v = self.__metrics_to_scalars(v)
|
||||||
|
|
||||||
|
if k not in blacklist:
|
||||||
new_metrics[k] = float(v)
|
new_metrics[k] = float(v)
|
||||||
|
|
||||||
return new_metrics
|
return new_metrics
|
||||||
|
|
||||||
|
def __log_vals_blacklist(self):
|
||||||
|
"""avoid logging some vals lightning uses to maintain state"""
|
||||||
|
blacklist = {'batch_nb', 'v_nb', 'epoch', 'gpu'}
|
||||||
|
return blacklist
|
||||||
|
|
||||||
def __run_tng_batch(self, data_batch, batch_nb):
|
def __run_tng_batch(self, data_batch, batch_nb):
|
||||||
if data_batch is None:
|
if data_batch is None:
|
||||||
|
|
Loading…
Reference in New Issue