enable any logged metric to be accessible in callbacks (#3598)

* enable any logged or written metric to be accessible in callbacks

* clarify forward
William Falcon 2020-09-22 18:00:23 -04:00 committed by GitHub
parent d2a3d6aa8e
commit c591013708
6 changed files with 43 additions and 16 deletions
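
For a sense of what "accessible in callbacks" means in practice, here is a minimal, hypothetical sketch of a callback-style object that reads the aggregated metrics. The `trainer.logger_connector.callback_metrics` path is the one the tests in this diff inspect; `MetricPrinter` is an illustrative name, not part of Lightning.

```python
class MetricPrinter:
    """Illustrative callback: after this change, anything logged, tracked for
    callbacks, or sent to the progress bar shows up in one metrics dict."""

    def on_validation_end(self, trainer, pl_module):
        # same attribute the tests below read; values are detached tensors
        # or plain numbers collected by the logger connector
        metrics = trainer.logger_connector.callback_metrics
        for name, value in metrics.items():
            print(f"{name}: {value}")
```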


@@ -22,6 +22,7 @@ Automatically save model checkpoints during training.
 import os
 import re
+from copy import deepcopy
 from typing import Any, Dict, Optional

 import numpy as np
@@ -219,7 +220,8 @@ class ModelCheckpoint(Callback):
         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
-        return monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        val = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        return val

     @classmethod
     def _format_checkpoint_name(
@@ -351,7 +353,11 @@ class ModelCheckpoint(Callback):
         self.epoch_last_check = epoch

-        ckpt_name_metrics = trainer.logger_connector.logged_metrics
+        # anything logged or in callbacks can be in the name
+        ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
+        ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
+        ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
         filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
         while self._fs.exists(filepath):
@@ -366,18 +372,19 @@
             if not isinstance(current, torch.Tensor):
                 rank_zero_warn(
-                    f"The metric you returned {current} must be a `torch.Tensor` instance, checkpoint not saved"
-                    f" HINT: what is the value of {self.monitor} in validation_epoch_end()?",
+                    f"The metric you returned {self.monitor}={current} must be a `torch.Tensor` "
+                    f"instance, checkpoint not saved HINT: what is the value of {self.monitor}?",
                     RuntimeWarning,
                 )
                 if current is not None:
-                    current = torch.tensor(current)
+                    current = torch.tensor(current).to(pl_module.device)

             if current is None:
-                rank_zero_warn(
-                    f"Can save best model only with {self.monitor} available, skipping.",
-                    RuntimeWarning,
-                )
+                m = f"Can save best model only with {self.monitor} available, skipping."
+                if self.monitor == 'checkpoint_on':
+                    m = f'No checkpoint_on found. Hint: Did you set it in EvalResult(checkpoint_on=tensor) or ' \
+                        f'TrainResult(checkpoint_on=tensor)?'
+                rank_zero_warn(m, RuntimeWarning)
             elif self.check_monitor_top_k(current):
                 self._do_check_save(filepath, current, epoch, trainer, pl_module)
             elif self.verbose:
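
The checkpoint-name hunk above builds one dict out of three metric sources. Here is a standalone sketch of that merge order, using made-up keys and values rather than Lightning defaults:

```python
from copy import deepcopy

# toy stand-ins for the three sources merged in the ModelCheckpoint change
logged_metrics = {"epoch": 3, "val_loss": 0.41}
callback_metrics = {"val_acc": 0.88, "checkpoint_on": 0.41}
progress_bar_metrics = {"train_loss": 0.52}

# same order as the diff: copy the logged metrics first (deepcopy so the
# originals are untouched), then layer callback and progress-bar metrics on top
ckpt_name_metrics = deepcopy(logged_metrics)
ckpt_name_metrics.update(callback_metrics)
ckpt_name_metrics.update(progress_bar_metrics)

# any of these keys can now be referenced by the checkpoint filename template
print(sorted(ckpt_name_metrics))
```

Because plain `dict.update` is used, later sources win on key collisions, so a progress-bar value overrides a logged value of the same name.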


@@ -172,8 +172,10 @@ class LoggerConnector:
             # log metrics
             self.trainer.logger_connector.log_metrics(log_metrics, {})

-            # track metrics for callbacks
+            # track metrics for callbacks (all prog bar, logged and callback metrics)
             self.trainer.logger_connector.callback_metrics.update(callback_metrics)
+            self.trainer.logger_connector.callback_metrics.update(log_metrics)
+            self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)

             if len(dataloader_result_metrics) > 0:
                 eval_loop_results.append(dataloader_result_metrics)
@@ -263,16 +265,18 @@ class LoggerConnector:
         # --------------------------
         # track results
         # --------------------------
-        # add the metrics to the loggers
+        # add the metrics to the loggers and callbacks
        if epoch_log_metrics and len(epoch_log_metrics) > 0:
             self.log_metrics(epoch_log_metrics, {})
+            self.callback_metrics.update(epoch_log_metrics)

         # add metrics to callbacks
         self.callback_metrics.update(epoch_callback_metrics)

-        # add metrics to progress_bar
+        # add metrics to progress_bar and callbacks
         if len(epoch_progress_bar_metrics) > 0:
             self.add_progress_bar_metrics(epoch_progress_bar_metrics)
+            self.callback_metrics.update(epoch_progress_bar_metrics)

     def __auto_reduce_results_on_epoch_end(self, epoch_output):
         epoch_log_metrics = {}
@@ -326,3 +330,4 @@ class LoggerConnector:
         grad_norm_dic = batch_output.grad_norm_dic
         if len(metrics) > 0 or len(grad_norm_dic) > 0:
             self.log_metrics(metrics, grad_norm_dic)
+            self.callback_metrics.update(metrics)
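
To make the LoggerConnector changes above easier to follow, here is a tiny, self-contained stand-in (not Lightning's class; the class and method names are invented) that mirrors how `callback_metrics` becomes the union of logged, callback, and progress-bar metrics at epoch end:

```python
class TinyLoggerConnector:
    """Illustrative stand-in for the bookkeeping in the diff above."""

    def __init__(self):
        self.callback_metrics = {}

    def track_epoch_end_metrics(self, epoch_log_metrics, epoch_callback_metrics,
                                epoch_progress_bar_metrics):
        # mirrors the three `callback_metrics.update(...)` calls added above
        if epoch_log_metrics:
            self.callback_metrics.update(epoch_log_metrics)
        self.callback_metrics.update(epoch_callback_metrics)
        if epoch_progress_bar_metrics:
            self.callback_metrics.update(epoch_progress_bar_metrics)


connector = TinyLoggerConnector()
connector.track_epoch_end_metrics(
    epoch_log_metrics={"val_loss": 0.41},
    epoch_callback_metrics={"checkpoint_on": 0.41},
    epoch_progress_bar_metrics={"val_acc": 0.88},
)
print(connector.callback_metrics)  # all three keys are now callback-visible
```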


@@ -146,6 +146,8 @@ class TrainerLoggingMixin(ABC):
         # detach all metrics for callbacks to prevent memory leaks
         # no .item() because it will slow things down
         callback_metrics = recursive_detach(callback_metrics)
+        progress_bar_metrics = recursive_detach(progress_bar_metrics)
+        log_metrics = recursive_detach(log_metrics)

         # replace loss with checkpoint_on
         if 'loss' in callback_metrics:
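
The two added `recursive_detach` calls extend the existing memory-leak guard to progress-bar and logged metrics. A simplified sketch of what such a helper does (an assumption about its behaviour, not Lightning's actual implementation; the function name here is invented):

```python
import torch


def detach_metrics(metrics):
    """Recursively detach tensors in a metrics dict so callbacks holding
    references do not keep the autograd graph alive. No .item() calls,
    since that would slow things down."""
    if isinstance(metrics, dict):
        return {key: detach_metrics(value) for key, value in metrics.items()}
    if isinstance(metrics, torch.Tensor):
        return metrics.detach()
    return metrics


callback_metrics = {"loss": torch.tensor(0.7, requires_grad=True), "epoch": 3}
detached = detach_metrics(callback_metrics)
print(detached["loss"].requires_grad)  # False: safe to keep around
```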


@@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
         assert k in eval_results[1]

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 8
+    assert len(trainer.logger_connector.callback_metrics) in [8, 9]

     # make sure correct steps were called
     assert model.validation_step_called
@@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
         assert k in eval_results[1]

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 9
+    assert len(trainer.logger_connector.callback_metrics) in [9, 10]

     # make sure correct steps were called
     assert model.validation_step_called
@@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
         assert k in eval_results

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 9
+    assert len(trainer.logger_connector.callback_metrics) in [9, 10]

     # make sure correct steps were called
     assert model.validation_step_called
@@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
         assert k in eval_results

     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 10
+    assert len(trainer.logger_connector.callback_metrics) in [10, 11]

     # make sure correct steps were called
     assert model.validation_step_called
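
The switch from exact counts to `in [n, n + 1]` allows for one extra key that may or may not appear now that logged and progress-bar metrics also land in `callback_metrics`. A rough, dict-only illustration of how candidates accumulate from a dict-style validation step (the keys here are made up, not the test fixtures):

```python
# the dict returned from a validation step, plus whatever it nests under
# 'log' and 'progress_bar', all become candidates for callbacks
step_output = {
    "val_loss": 0.41,
    "some_metric": 0.9,
    "log": {"val_loss": 0.41, "val_acc": 0.88},
    "progress_bar": {"val_acc": 0.88},
}

candidates = {k: v for k, v in step_output.items() if k not in ("log", "progress_bar")}
candidates.update(step_output["log"])
candidates.update(step_output["progress_bar"])

# an auto-generated key (for example 'checkpoint_on') may or may not be added
# on top of these, hence assertions like `len(...) in [8, 9]`
print(len(candidates))  # 3 here
```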


@@ -58,6 +58,7 @@ def test_training_step_result_log_step_only(tmpdir):
         assert len(logged_metrics) == 4

     # make sure we are using the correct metrics for callbacks
     assert len(trainer.logger_connector.callback_metrics) == 8
+    assert trainer.logger_connector.callback_metrics['checkpoint_on'] == 171

     # make sure pbar metrics are correct ang log metrics did not leak
@@ -123,6 +124,8 @@ def test_training_step_result_log_epoch_only(tmpdir):
     assert not model.training_step_end_called
     assert not model.training_epoch_end_called

+    assert len(trainer.logger_connector.callback_metrics) == 12
+
     # make sure correct metrics are logged (one per batch step as requested)
     assert len(trainer.dev_debugger.logged_metrics) == epochs
     epoch_metrics = trainer.dev_debugger.logged_metrics
@@ -198,6 +201,8 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
     assert not model.training_step_end_called
     assert not model.training_epoch_end_called

+    assert len(trainer.logger_connector.callback_metrics) == 8
+
     # make sure correct metrics are logged (one per batch step as requested)
     assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs
     epoch_metrics = trainer.dev_debugger.logged_metrics
@@ -323,6 +328,8 @@ def test_training_step_epoch_end_result(tmpdir):
     )
     trainer.fit(model)

+    assert len(trainer.logger_connector.callback_metrics) == 11
+
     # make sure correct steps were called
     assert model.training_step_called
     assert not model.training_step_end_called
@@ -403,6 +410,8 @@ def test_no_auto_callbacks_with_train_loop_only(tmpdir):
     )
     trainer.fit(model)

+    assert len(trainer.logger_connector.callback_metrics) == 2
+
     all_losses = trainer.dev_debugger.saved_train_losses
     assert len(all_losses) == batches * epochs


@@ -278,6 +278,8 @@ def test_val_step_epoch_step_metrics(tmpdir):
     )
     trainer.fit(model)

+    assert len(trainer.logger_connector.callback_metrics) == 7
+
     # make sure correct steps were called
     assert model.validation_step_called
     assert not model.validation_step_end_called
@@ -352,6 +354,8 @@ def test_val_step_epoch_end_result(tmpdir):
     )
     trainer.fit(model)

+    assert len(trainer.logger_connector.callback_metrics) == 6
+
     # make sure correct steps were called
     assert model.validation_step_called
     assert not model.validation_step_end_called
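
As a closing usage note: because `ModelCheckpoint` now formats names from the merged metric dict, a metric that is only logged or only shown in the progress bar can drive the checkpoint filename. A hedged sketch against the Lightning API of this era, where `ModelCheckpoint` still takes a `filepath` template; the path and the `val_acc` key are illustrative, so substitute whatever your module actually logs:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# any key in the merged logged/callback/progress-bar metrics can appear in the
# filename template; `epoch` is always available
checkpoint_callback = ModelCheckpoint(
    filepath="checkpoints/{epoch}-{val_acc:.2f}",
)
```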