added testing for metrics

William Falcon 2019-07-24 20:33:31 -04:00
parent 23e7521300
commit 63a4af3ba7
4 changed files with 3 additions and 15 deletions

@@ -152,9 +152,6 @@ class LightningTemplateModel(LightningModule):
         tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
         return tqdm_dic
 
-    def update_tng_log_metrics(self, logs):
-        return logs
-
     # ---------------------
     # MODEL SAVING
     # ---------------------

@@ -678,7 +678,7 @@ class Trainer(TrainerIO):
         # nb_params, nb_tensors = count_mem_items()
         model = self.__get_model()
-        metrics = model.update_tng_log_metrics(self.__tng_tqdm_dic)
+        metrics = self.__tng_tqdm_dic
 
         # add gpu memory
         if self.on_gpu:
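
This hunk is the only Trainer-side change: the metrics that get logged now start directly from the progress-bar dict (self.__tng_tqdm_dic) instead of being routed through the removed update_tng_log_metrics hook. A minimal stand-alone sketch of that flow, using illustrative names rather than the library's actual Trainer internals (the GPU-memory value is a placeholder; the real query is outside this hunk):

def collect_logged_metrics(tng_tqdm_dic, on_gpu=False):
    # Logged metrics now start from the progress-bar dict itself
    # (e.g. {'loss': ..., 'val_acc': ...}); no module hook rewrites them here.
    metrics = tng_tqdm_dic

    # add gpu memory (placeholder value; the actual memory query is not shown in this diff)
    if on_gpu:
        metrics['gpu_mem'] = 0.0

    return metrics


print(collect_logged_metrics({'val_loss': 0.42, 'val_acc': 0.91}))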

@@ -71,15 +71,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         """
         raise NotImplementedError
 
-    def update_tng_log_metrics(self, logs):
-        """
-        Chance to update metrics to be logged for training step.
-        For example, add music, images, etc... to log
-        :param logs:
-        :return:
-        """
-        return logs
-
     def loss(self, *args, **kwargs):
         """
         Expand model_out into your components
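
This hunk removes the update_tng_log_metrics hook (and its pass-through default) from the LightningModule base class. A hedged migration sketch, assuming user code that previously overrode the removed hook; the class and the 'extra_metric' key below are illustrative stand-ins, not the library's actual classes:

import torch


class MyModel:  # stand-in for a LightningModule subclass; the real base class is omitted
    # Before this commit, extra values were added by overriding the now-removed hook
    # and returning the logs dict:
    #
    #     def update_tng_log_metrics(self, logs):
    #         logs['extra_metric'] = torch.rand(1)
    #         return logs

    # After this commit, the updated test below switches to on_tng_metrics, which
    # mutates the logs dict in place and returns nothing:
    def on_tng_metrics(self, logs):
        logs['extra_metric'] = torch.rand(1)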

@@ -166,8 +166,8 @@ class LightningTestModel(LightningModule):
         tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
         return tqdm_dic
 
-    def update_tng_log_metrics(self, logs):
-        return logs
+    def on_tng_metrics(self, logs):
+        logs['some_tensor_to_test'] = torch.rand(1)
 
     # ---------------------
     # MODEL SAVING
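
The test model now plants a random tensor through on_tng_metrics, presumably so the test can check that values added this way end up among the logged metrics. A small self-contained demonstration of the in-place pattern the new test lines rely on (plain Python, not the library's machinery; the hook is called by hand here):

import torch


def on_tng_metrics(logs):
    # Mirrors the test model: mutate the dict in place, return nothing.
    logs['some_tensor_to_test'] = torch.rand(1)


logs = {'val_loss': 0.42, 'val_acc': 0.91}   # stand-in for the progress-bar dict
on_tng_metrics(logs)                         # no return value is needed
assert 'some_tensor_to_test' in logs         # the planted tensor is now among the metrics
print(logs)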