enable any logged metric to be accessible in callbacks (#3598)
* enable any logged or written metric to be accessible in callbacks
* clarify forward
This commit is contained in:
parent
d2a3d6aa8e
commit
c591013708
@@ -22,6 +22,7 @@ Automatically save model checkpoints during training.
 import os
 import re
+from copy import deepcopy
 from typing import Any, Dict, Optional
 
 import numpy as np
@@ -219,7 +220,8 @@ class ModelCheckpoint(Callback):
         monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
 
-        return monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        val = monitor_op(current, self.best_k_models[self.kth_best_model_path])
+        return val
 
     @classmethod
     def _format_checkpoint_name(
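For context, check_monitor_top_k picks a comparison op from the configured mode and compares the current value against the worst of the currently saved top-k models. A small self-contained sketch of that comparison (the values below are made up for illustration):

    import torch

    mode = "min"  # "min": lower is better, "max": higher is better
    monitor_op = {"min": torch.lt, "max": torch.gt}[mode]

    current = torch.tensor(0.21)   # monitored metric from the latest check
    kth_best = torch.tensor(0.30)  # worst value among the saved top-k checkpoints

    # True -> the new checkpoint should replace the current k-th best
    print(monitor_op(current, kth_best))  # tensor(True)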
@@ -351,7 +353,11 @@ class ModelCheckpoint(Callback):
         self.epoch_last_check = epoch
 
-        ckpt_name_metrics = trainer.logger_connector.logged_metrics
+        # anything logged or in callbacks can be in the name
+        ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
+        ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
+        ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
 
         filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
         while self._fs.exists(filepath):
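With the deepcopy/update calls above, anything present in the logged, callback, or progress-bar metric dicts can be referenced by the checkpoint filename template. A rough sketch of that idea, using a hypothetical format_name helper rather than the real format_checkpoint_name (metric names and values are illustrative):

    import re
    from copy import deepcopy

    logged_metrics = {"val_loss": 0.25}
    callback_metrics = {"checkpoint_on": 0.25, "val_acc": 0.91}
    progress_bar_metrics = {"train_loss": 0.31}

    # merge every metric source so any key can appear in the checkpoint name
    ckpt_name_metrics = deepcopy(logged_metrics)
    ckpt_name_metrics.update(callback_metrics)
    ckpt_name_metrics.update(progress_bar_metrics)

    def format_name(template: str, epoch: int, metrics: dict) -> str:
        # hypothetical helper: fill "{name}" groups from the merged metrics
        groups = re.findall(r"\{(\w+)", template)
        values = {"epoch": epoch, **metrics}
        return template.format(**{g: values.get(g, 0) for g in groups})

    print(format_name("epoch={epoch}-val_acc={val_acc:.2f}", 3, ckpt_name_metrics))
    # epoch=3-val_acc=0.91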
@@ -366,18 +372,19 @@ class ModelCheckpoint(Callback):
         if not isinstance(current, torch.Tensor):
             rank_zero_warn(
-                f"The metric you returned {current} must be a `torch.Tensor` instance, checkpoint not saved"
-                f" HINT: what is the value of {self.monitor} in validation_epoch_end()?",
+                f"The metric you returned {self.monitor}={current} must be a `torch.Tensor` "
+                f"instance, checkpoint not saved HINT: what is the value of {self.monitor}?",
                 RuntimeWarning,
             )
             if current is not None:
-                current = torch.tensor(current)
+                current = torch.tensor(current).to(pl_module.device)
 
         if current is None:
-            rank_zero_warn(
-                f"Can save best model only with {self.monitor} available, skipping.",
-                RuntimeWarning,
-            )
+            m = f"Can save best model only with {self.monitor} available, skipping."
+            if self.monitor == 'checkpoint_on':
+                m = f'No checkpoint_on found. Hint: Did you set it in EvalResult(checkpoint_on=tensor) or ' \
+                    f'TrainResult(checkpoint_on=tensor)?'
+            rank_zero_warn(m, RuntimeWarning)
         elif self.check_monitor_top_k(current):
             self._do_check_save(filepath, current, epoch, trainer, pl_module)
         elif self.verbose:
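One note on the .to(pl_module.device) change: wrapping a plain Python number and moving it onto the module's device keeps the later comparison against the stored best values on a single device. A minimal sketch of that conversion (device choice and values are illustrative, not the callback's real state):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    current = 0.123  # a raw float returned as the monitored metric
    current = torch.tensor(current).to(device)  # wrap and move next to the model

    kth_best = torch.tensor(0.456, device=device)
    print(torch.lt(current, kth_best))  # True when mode="min"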
@@ -172,8 +172,10 @@ class LoggerConnector:
             # log metrics
             self.trainer.logger_connector.log_metrics(log_metrics, {})
 
-            # track metrics for callbacks
+            # track metrics for callbacks (all prog bar, logged and callback metrics)
             self.trainer.logger_connector.callback_metrics.update(callback_metrics)
+            self.trainer.logger_connector.callback_metrics.update(log_metrics)
+            self.trainer.logger_connector.callback_metrics.update(prog_bar_metrics)
 
             if len(dataloader_result_metrics) > 0:
                 eval_loop_results.append(dataloader_result_metrics)
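This is the piece that makes the change visible to users: logged and progress-bar metrics are now funneled into logger_connector.callback_metrics alongside the existing callback metrics. A hedged sketch of how a user callback could then read them (the metric name "val_loss" is only an example; trainer.logger_connector.callback_metrics is the dict this diff populates):

    from pytorch_lightning.callbacks import Callback

    class PrintMonitored(Callback):
        """Example callback that reads any logged metric after validation."""

        def __init__(self, monitor: str = "val_loss"):
            self.monitor = monitor

        def on_validation_end(self, trainer, pl_module):
            metrics = trainer.logger_connector.callback_metrics
            if self.monitor in metrics:
                print(f"{self.monitor} = {metrics[self.monitor]}")
            else:
                print(f"{self.monitor} not found in callback metrics")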
@@ -263,16 +265,18 @@ class LoggerConnector:
         # --------------------------
         # track results
         # --------------------------
-        # add the metrics to the loggers
+        # add the metrics to the loggers and callbacks
         if epoch_log_metrics and len(epoch_log_metrics) > 0:
             self.log_metrics(epoch_log_metrics, {})
+            self.callback_metrics.update(epoch_log_metrics)
 
         # add metrics to callbacks
         self.callback_metrics.update(epoch_callback_metrics)
 
-        # add metrics to progress_bar
+        # add metrics to progress_bar and callbacks
         if len(epoch_progress_bar_metrics) > 0:
             self.add_progress_bar_metrics(epoch_progress_bar_metrics)
+            self.callback_metrics.update(epoch_progress_bar_metrics)
 
     def __auto_reduce_results_on_epoch_end(self, epoch_output):
         epoch_log_metrics = {}
@@ -326,3 +330,4 @@ class LoggerConnector:
             grad_norm_dic = batch_output.grad_norm_dic
             if len(metrics) > 0 or len(grad_norm_dic) > 0:
                 self.log_metrics(metrics, grad_norm_dic)
+                self.callback_metrics.update(metrics)
@@ -146,6 +146,8 @@ class TrainerLoggingMixin(ABC):
         # detach all metrics for callbacks to prevent memory leaks
         # no .item() because it will slow things down
         callback_metrics = recursive_detach(callback_metrics)
+        progress_bar_metrics = recursive_detach(progress_bar_metrics)
+        log_metrics = recursive_detach(log_metrics)
 
         # replace loss with checkpoint_on
         if 'loss' in callback_metrics:
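recursive_detach is applied here so tensors cached for callbacks stop referencing the autograd graph between steps, without calling .item() (which would force a device sync and slow training down). A minimal stand-in showing the intent; this is a sketch, not the library's own utility:

    import torch

    def detach_nested(metrics):
        """Detach every tensor in a (possibly nested) dict so no autograd graph is retained."""
        if isinstance(metrics, dict):
            return {k: detach_nested(v) for k, v in metrics.items()}
        if isinstance(metrics, torch.Tensor):
            return metrics.detach()
        return metrics

    callback_metrics = {
        "loss": torch.tensor(1.0, requires_grad=True),
        "nested": {"acc": torch.tensor(0.9)},
    }
    detached = detach_nested(callback_metrics)
    print(detached["loss"].requires_grad)  # False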
@@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
         assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 8
+    assert len(trainer.logger_connector.callback_metrics) in [8, 9]
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
         assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 9
+    assert len(trainer.logger_connector.callback_metrics) in [9, 10]
 
     # make sure correct steps were called
    assert model.validation_step_called
@@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 9
+    assert len(trainer.logger_connector.callback_metrics) in [9, 10]
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.logger_connector.callback_metrics) == 10
+    assert len(trainer.logger_connector.callback_metrics) in [10, 11]
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -58,6 +58,7 @@ def test_training_step_result_log_step_only(tmpdir):
     assert len(logged_metrics) == 4
 
     # make sure we are using the correct metrics for callbacks
+    assert len(trainer.logger_connector.callback_metrics) == 8
     assert trainer.logger_connector.callback_metrics['checkpoint_on'] == 171
 
     # make sure pbar metrics are correct and log metrics did not leak
@@ -123,6 +124,8 @@ def test_training_step_result_log_epoch_only(tmpdir):
     assert not model.training_step_end_called
     assert not model.training_epoch_end_called
 
+    assert len(trainer.logger_connector.callback_metrics) == 12
+
     # make sure correct metrics are logged (one per batch step as requested)
     assert len(trainer.dev_debugger.logged_metrics) == epochs
     epoch_metrics = trainer.dev_debugger.logged_metrics
@@ -198,6 +201,8 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
     assert not model.training_step_end_called
     assert not model.training_epoch_end_called
 
+    assert len(trainer.logger_connector.callback_metrics) == 8
+
     # make sure correct metrics are logged (one per batch step as requested)
     assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs
     epoch_metrics = trainer.dev_debugger.logged_metrics
@@ -323,6 +328,8 @@ def test_training_step_epoch_end_result(tmpdir):
     )
     trainer.fit(model)
 
+    assert len(trainer.logger_connector.callback_metrics) == 11
+
     # make sure correct steps were called
     assert model.training_step_called
     assert not model.training_step_end_called
@@ -403,6 +410,8 @@ def test_no_auto_callbacks_with_train_loop_only(tmpdir):
     )
     trainer.fit(model)
 
+    assert len(trainer.logger_connector.callback_metrics) == 2
+
     all_losses = trainer.dev_debugger.saved_train_losses
     assert len(all_losses) == batches * epochs
@@ -278,6 +278,8 @@ def test_val_step_epoch_step_metrics(tmpdir):
     )
     trainer.fit(model)
 
+    assert len(trainer.logger_connector.callback_metrics) == 7
+
     # make sure correct steps were called
     assert model.validation_step_called
     assert not model.validation_step_end_called
@@ -352,6 +354,8 @@ def test_val_step_epoch_end_result(tmpdir):
     )
     trainer.fit(model)
 
+    assert len(trainer.logger_connector.callback_metrics) == 6
+
     # make sure correct steps were called
     assert model.validation_step_called
     assert not model.validation_step_end_called