Further clean up aggregation logic (#12053)
This commit is contained in:
parent
1026ceb86d
commit
a8ee5cacb7
|
@ -580,6 +580,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Removed `get_mp_spawn_kwargs` from `DDPSpawnStrategy` and `TPUSpawnStrategy` in favor of configuration in the `_SpawnLauncher` ([#11966](https://github.com/PyTorchLightning/pytorch-lightning/pull/11966))
|
||||
|
||||
|
||||
- Removed `_aggregate_metrics`, `_reduce_agg_metrics`, and `_finalize_agg_metrics` from `LightningLoggerBase` ([#12053](https://github.com/PyTorchLightning/pytorch-lightning/pull/12053))
|
||||
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup [#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752)
|
||||
|
|
|
@ -19,7 +19,7 @@ import operator
|
|||
from abc import ABC, abstractmethod
|
||||
from argparse import Namespace
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Union
|
||||
from weakref import ReferenceType
|
||||
|
||||
import numpy as np
|
||||
|
@ -123,66 +123,6 @@ class LightningLoggerBase(ABC):
|
|||
"`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8."
|
||||
)
|
||||
|
||||
def _aggregate_metrics(
|
||||
self, metrics: Dict[str, float], step: Optional[int] = None
|
||||
) -> Tuple[int, Optional[Dict[str, float]]]:
|
||||
"""Aggregates metrics.
|
||||
|
||||
.. deprecated:: v1.6
|
||||
This method is deprecated in v1.6 and will be removed in v1.8.
|
||||
|
||||
Args:
|
||||
metrics: Dictionary with metric names as keys and measured quantities as values
|
||||
step: Step number at which the metrics should be recorded
|
||||
|
||||
Returns:
|
||||
Step and aggregated metrics. The return value could be ``None``. In such case, metrics
|
||||
are added to the aggregation list, but not aggregated yet.
|
||||
"""
|
||||
# if you still receiving metric from the same step, just accumulate it
|
||||
if step == self._prev_step:
|
||||
self._metrics_to_agg.append(metrics)
|
||||
return step, None
|
||||
|
||||
# compute the metrics
|
||||
agg_step, agg_mets = self._reduce_agg_metrics()
|
||||
|
||||
# as new step received reset accumulator
|
||||
self._metrics_to_agg = [metrics]
|
||||
self._prev_step = step
|
||||
return agg_step, agg_mets
|
||||
|
||||
def _reduce_agg_metrics(self):
|
||||
"""Aggregate accumulated metrics.
|
||||
|
||||
See deprecation warning below.
|
||||
|
||||
.. deprecated:: v1.6
|
||||
This method is deprecated in v1.6 and will be removed in v1.8.
|
||||
"""
|
||||
# compute the metrics
|
||||
if not self._metrics_to_agg:
|
||||
agg_mets = None
|
||||
elif len(self._metrics_to_agg) == 1:
|
||||
agg_mets = self._metrics_to_agg[0]
|
||||
else:
|
||||
agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
|
||||
return self._prev_step, agg_mets
|
||||
|
||||
def _finalize_agg_metrics(self):
|
||||
"""This shall be called before save/close.
|
||||
|
||||
See deprecation warning below.
|
||||
|
||||
.. deprecated:: v1.6
|
||||
This method is deprecated in v1.6 and will be removed in v1.8.
|
||||
"""
|
||||
agg_step, metrics_to_log = self._reduce_agg_metrics()
|
||||
self._metrics_to_agg = []
|
||||
|
||||
if metrics_to_log is not None:
|
||||
self.log_metrics(metrics=metrics_to_log, step=agg_step)
|
||||
|
||||
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
|
||||
"""Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead
|
||||
it aggregates them and logs only if metrics are ready to be logged.
|
||||
|
@ -195,10 +135,7 @@ class LightningLoggerBase(ABC):
|
|||
metrics: Dictionary with metric names as keys and measured quantities as values
|
||||
step: Step number at which the metrics should be recorded
|
||||
"""
|
||||
agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step)
|
||||
|
||||
if metrics_to_log:
|
||||
self.log_metrics(metrics=metrics_to_log, step=agg_step)
|
||||
self.log_metrics(metrics=metrics, step=step)
|
||||
|
||||
@abstractmethod
|
||||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
|
||||
|
@ -221,7 +158,7 @@ class LightningLoggerBase(ABC):
|
|||
Args:
|
||||
params: :class:`~argparse.Namespace` containing the hyperparameters
|
||||
args: Optional positional arguments, depends on the specific logger being used
|
||||
kwargs: Optional keywoard arguments, depends on the specific logger being used
|
||||
kwargs: Optional keyword arguments, depends on the specific logger being used
|
||||
"""
|
||||
|
||||
def log_graph(self, model: "pl.LightningModule", input_array=None) -> None:
|
||||
|
@ -235,7 +172,6 @@ class LightningLoggerBase(ABC):
|
|||
|
||||
def save(self) -> None:
|
||||
"""Save log data."""
|
||||
self._finalize_agg_metrics()
|
||||
|
||||
def finalize(self, status: str) -> None:
|
||||
"""Do any processing that is necessary to finalize an experiment.
|
||||
|
|
Loading…
Reference in New Issue