Further clean up aggregation logic (#12053)

This commit is contained in:
Danielle Pintz 2022-02-23 16:42:51 -05:00 committed by GitHub
parent 1026ceb86d
commit a8ee5cacb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 67 deletions

View File

@ -580,6 +580,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `get_mp_spawn_kwargs` from `DDPSpawnStrategy` and `TPUSpawnStrategy` in favor of configuration in the `_SpawnLauncher` ([#11966](https://github.com/PyTorchLightning/pytorch-lightning/pull/11966))
- Removed `_aggregate_metrics`, `_reduce_agg_metrics`, and `_finalize_agg_metrics` from `LightningLoggerBase` ([#12053](https://github.com/PyTorchLightning/pytorch-lightning/pull/12053))
### Fixed
- Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup [#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752)

View File

@ -19,7 +19,7 @@ import operator
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Union
from weakref import ReferenceType
import numpy as np
@ -123,66 +123,6 @@ class LightningLoggerBase(ABC):
"`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8."
)
def _aggregate_metrics(
self, metrics: Dict[str, float], step: Optional[int] = None
) -> Tuple[int, Optional[Dict[str, float]]]:
"""Aggregates metrics.
.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
Returns:
Step and aggregated metrics. The return value could be ``None``. In such case, metrics
are added to the aggregation list, but not aggregated yet.
"""
# if you still receiving metric from the same step, just accumulate it
if step == self._prev_step:
self._metrics_to_agg.append(metrics)
return step, None
# compute the metrics
agg_step, agg_mets = self._reduce_agg_metrics()
# as new step received reset accumulator
self._metrics_to_agg = [metrics]
self._prev_step = step
return agg_step, agg_mets
def _reduce_agg_metrics(self):
"""Aggregate accumulated metrics.
See deprecation warning below.
.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
"""
# compute the metrics
if not self._metrics_to_agg:
agg_mets = None
elif len(self._metrics_to_agg) == 1:
agg_mets = self._metrics_to_agg[0]
else:
agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
return self._prev_step, agg_mets
def _finalize_agg_metrics(self):
"""This shall be called before save/close.
See deprecation warning below.
.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
"""
agg_step, metrics_to_log = self._reduce_agg_metrics()
self._metrics_to_agg = []
if metrics_to_log is not None:
self.log_metrics(metrics=metrics_to_log, step=agg_step)
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead
it aggregates them and logs only if metrics are ready to be logged.
@ -195,10 +135,7 @@ class LightningLoggerBase(ABC):
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)
if metrics_to_log:
self.log_metrics(metrics=metrics_to_log, step=agg_step)
self.log_metrics(metrics=metrics, step=step)
@abstractmethod
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
@ -221,7 +158,7 @@ class LightningLoggerBase(ABC):
Args:
params: :class:`~argparse.Namespace` containing the hyperparameters
args: Optional positional arguments, depends on the specific logger being used
kwargs: Optional keywoard arguments, depends on the specific logger being used
kwargs: Optional keyword arguments, depends on the specific logger being used
"""
def log_graph(self, model: "pl.LightningModule", input_array=None) -> None:
@ -235,7 +172,6 @@ class LightningLoggerBase(ABC):
def save(self) -> None:
"""Save log data."""
self._finalize_agg_metrics()
def finalize(self, status: str) -> None:
"""Do any processing that is necessary to finalize an experiment.