From 63a4af3ba771707c4d3860bc3ba35e98bbf129fa Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 20:33:31 -0400 Subject: [PATCH] added testing for metrics --- .../new_project_templates/lightning_module_template.py | 3 --- pytorch_lightning/models/trainer.py | 2 +- pytorch_lightning/root_module/root_module.py | 9 --------- pytorch_lightning/testing_models/lm_test_module.py | 4 ++-- 4 files changed, 3 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py index 0f28e18129..c6b5ccd173 100644 --- a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py +++ b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py @@ -152,9 +152,6 @@ class LightningTemplateModel(LightningModule): tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} return tqdm_dic - def update_tng_log_metrics(self, logs): - return logs - # --------------------- # MODEL SAVING # --------------------- diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index a575a3167e..7f294a7140 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -678,7 +678,7 @@ class Trainer(TrainerIO): # nb_params, nb_tensors = count_mem_items() model = self.__get_model() - metrics = model.update_tng_log_metrics(self.__tng_tqdm_dic) + metrics = self.__tng_tqdm_dic # add gpu memory if self.on_gpu: diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index 945a15a61b..c0f401848d 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -71,15 +71,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks): """ raise NotImplementedError - def update_tng_log_metrics(self, logs): - """ - Chance to update metrics to be logged for training step. - For example, add music, images, etc... to log - :param logs: - :return: - """ - return logs - def loss(self, *args, **kwargs): """ Expand model_out into your components diff --git a/pytorch_lightning/testing_models/lm_test_module.py b/pytorch_lightning/testing_models/lm_test_module.py index 685bb30ba2..eba8882d13 100644 --- a/pytorch_lightning/testing_models/lm_test_module.py +++ b/pytorch_lightning/testing_models/lm_test_module.py @@ -166,8 +166,8 @@ class LightningTestModel(LightningModule): tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} return tqdm_dic - def update_tng_log_metrics(self, logs): - return logs + def on_tng_metrics(self, logs): + logs['some_tensor_to_test'] = torch.rand(1) # --------------------- # MODEL SAVING