Allow user to specify 'step' key while logging metrics (#808)
* allow to specify 'step' key * add test * docs to log_metrics * fix test * rename * also rename
This commit is contained in:
parent
62e9963cf7
commit
06ca6428b6
|
@ -418,7 +418,8 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
|||
The outputs here are strictly for the progress bar.
|
||||
If you don't need to display anything, don't return anything.
|
||||
Any keys present in 'log', 'progress_bar' or the rest of the dictionary
|
||||
are available for callbacks to access.
|
||||
are available for callbacks to access. If you want to manually set current step, you can specify it with
|
||||
'step' key in the 'log' Dict.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
@ -468,7 +469,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
|
|||
# show val_loss and val_acc in progress bar but only log val_loss
|
||||
results = {
|
||||
'progress_bar': tqdm_dict,
|
||||
'log': {'val_loss': val_loss_mean.item()}
|
||||
'log': {'val_loss': val_loss_mean.item(), 'step': self.current_epoch}
|
||||
}
|
||||
return results
|
||||
|
||||
|
|
|
@ -39,13 +39,12 @@ class TrainerLoggingMixin(ABC):
|
|||
|
||||
def log_metrics(self, metrics, grad_norm_dic, step=None):
|
||||
"""Logs the metric dict passed in.
|
||||
|
||||
:param metrics:
|
||||
:param grad_norm_dic:
|
||||
If `step` parameter is None and `step` key is presented is metrics,
|
||||
uses metrics["step"] as a step
|
||||
:param metrics (dict): Metric values
|
||||
:param grad_norm_dic (dict): Gradient norms
|
||||
:param step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
|
||||
"""
|
||||
# added metrics by Lightning for convenience
|
||||
metrics['epoch'] = self.current_epoch
|
||||
|
||||
# add gpu memory
|
||||
if self.on_gpu and self.log_gpu_memory:
|
||||
mem_map = memory.get_memory_profile(self.log_gpu_memory)
|
||||
|
@ -57,6 +56,11 @@ class TrainerLoggingMixin(ABC):
|
|||
# turn all tensors to scalars
|
||||
scalar_metrics = self.metrics_to_scalars(metrics)
|
||||
|
||||
if "step" in scalar_metrics and step is None:
|
||||
step = scalar_metrics.pop("step")
|
||||
else:
|
||||
# added metrics by Lightning for convenience
|
||||
metrics['epoch'] = self.current_epoch
|
||||
step = step if step is not None else self.global_step
|
||||
# log actual metrics
|
||||
if self.proc_rank == 0 and self.logger is not None:
|
||||
|
|
|
@ -376,3 +376,33 @@ def test_custom_logger(tmpdir):
|
|||
assert logger.hparams_logged == hparams
|
||||
assert logger.metrics_logged != {}
|
||||
assert logger.finalized_status == "success"
|
||||
|
||||
|
||||
def test_adding_step_key(tmpdir):
|
||||
logged_step = 0
|
||||
|
||||
def _validation_end(outputs):
|
||||
nonlocal logged_step
|
||||
logged_step += 1
|
||||
return {"log": {"step": logged_step, "val_acc": logged_step / 10}}
|
||||
|
||||
def _log_metrics_decorator(log_metrics_fn):
|
||||
def decorated(metrics, step):
|
||||
if "val_acc" in metrics:
|
||||
assert step == logged_step
|
||||
return log_metrics_fn(metrics, step)
|
||||
|
||||
return decorated
|
||||
|
||||
model, hparams = tutils.get_model()
|
||||
model.validation_end = _validation_end
|
||||
trainer_options = dict(
|
||||
max_epochs=4,
|
||||
default_save_path=tmpdir,
|
||||
train_percent_check=0.001,
|
||||
val_percent_check=0.01,
|
||||
num_sanity_val_steps=0
|
||||
)
|
||||
trainer = Trainer(**trainer_options)
|
||||
trainer.logger.log_metrics = _log_metrics_decorator(trainer.logger.log_metrics)
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue