added testing for metrics
This commit is contained in:
parent
23e7521300
commit
63a4af3ba7
|
@ -152,9 +152,6 @@ class LightningTemplateModel(LightningModule):
|
|||
tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
|
||||
return tqdm_dic
|
||||
|
||||
def update_tng_log_metrics(self, logs):
|
||||
return logs
|
||||
|
||||
# ---------------------
|
||||
# MODEL SAVING
|
||||
# ---------------------
|
||||
|
|
|
@ -678,7 +678,7 @@ class Trainer(TrainerIO):
|
|||
# nb_params, nb_tensors = count_mem_items()
|
||||
|
||||
model = self.__get_model()
|
||||
metrics = model.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
metrics = self.__tng_tqdm_dic
|
||||
|
||||
# add gpu memory
|
||||
if self.on_gpu:
|
||||
|
|
|
@ -71,15 +71,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update_tng_log_metrics(self, logs):
|
||||
"""
|
||||
Chance to update metrics to be logged for training step.
|
||||
For example, add music, images, etc... to log
|
||||
:param logs:
|
||||
:return:
|
||||
"""
|
||||
return logs
|
||||
|
||||
def loss(self, *args, **kwargs):
|
||||
"""
|
||||
Expand model_out into your components
|
||||
|
|
|
@ -166,8 +166,8 @@ class LightningTestModel(LightningModule):
|
|||
tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
|
||||
return tqdm_dic
|
||||
|
||||
def update_tng_log_metrics(self, logs):
|
||||
return logs
|
||||
def on_tng_metrics(self, logs):
|
||||
logs['some_tensor_to_test'] = torch.rand(1)
|
||||
|
||||
# ---------------------
|
||||
# MODEL SAVING
|
||||
|
|
Loading…
Reference in New Issue