From a8ee5cacb7b31064c8b73372ca944d43f7caadc6 Mon Sep 17 00:00:00 2001 From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Date: Wed, 23 Feb 2022 16:42:51 -0500 Subject: [PATCH] Further clean up aggregation logic (#12053) --- CHANGELOG.md | 3 ++ pytorch_lightning/loggers/base.py | 70 ++----------------------------- 2 files changed, 6 insertions(+), 67 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a2abb489d..bd552c8c42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -580,6 +580,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `get_mp_spawn_kwargs` from `DDPSpawnStrategy` and `TPUSpawnStrategy` in favor of configuration in the `_SpawnLauncher` ([#11966](https://github.com/PyTorchLightning/pytorch-lightning/pull/11966)) +- Removed `_aggregate_metrics`, `_reduce_agg_metrics`, and `_finalize_agg_metrics` from `LightningLoggerBase` ([#12053](https://github.com/PyTorchLightning/pytorch-lightning/pull/12053)) + + ### Fixed - Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup [#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 5ff7114835..f0a8ba13db 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -19,7 +19,7 @@ import operator from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Union from weakref import ReferenceType import numpy as np @@ -123,66 +123,6 @@ class LightningLoggerBase(ABC): "`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8." ) - def _aggregate_metrics( - self, metrics: Dict[str, float], step: Optional[int] = None - ) -> Tuple[int, Optional[Dict[str, float]]]: - """Aggregates metrics. - - .. deprecated:: v1.6 - This method is deprecated in v1.6 and will be removed in v1.8. - - Args: - metrics: Dictionary with metric names as keys and measured quantities as values - step: Step number at which the metrics should be recorded - - Returns: - Step and aggregated metrics. The return value could be ``None``. In such case, metrics - are added to the aggregation list, but not aggregated yet. - """ - # if you still receiving metric from the same step, just accumulate it - if step == self._prev_step: - self._metrics_to_agg.append(metrics) - return step, None - - # compute the metrics - agg_step, agg_mets = self._reduce_agg_metrics() - - # as new step received reset accumulator - self._metrics_to_agg = [metrics] - self._prev_step = step - return agg_step, agg_mets - - def _reduce_agg_metrics(self): - """Aggregate accumulated metrics. - - See deprecation warning below. - - .. deprecated:: v1.6 - This method is deprecated in v1.6 and will be removed in v1.8. - """ - # compute the metrics - if not self._metrics_to_agg: - agg_mets = None - elif len(self._metrics_to_agg) == 1: - agg_mets = self._metrics_to_agg[0] - else: - agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func) - return self._prev_step, agg_mets - - def _finalize_agg_metrics(self): - """This shall be called before save/close. - - See deprecation warning below. - - .. deprecated:: v1.6 - This method is deprecated in v1.6 and will be removed in v1.8. - """ - agg_step, metrics_to_log = self._reduce_agg_metrics() - self._metrics_to_agg = [] - - if metrics_to_log is not None: - self.log_metrics(metrics=metrics_to_log, step=agg_step) - def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead it aggregates them and logs only if metrics are ready to be logged. @@ -195,10 +135,7 @@ class LightningLoggerBase(ABC): metrics: Dictionary with metric names as keys and measured quantities as values step: Step number at which the metrics should be recorded """ - agg_step, metrics_to_log = self._aggregate_metrics(metrics=metrics, step=step) - - if metrics_to_log: - self.log_metrics(metrics=metrics_to_log, step=agg_step) + self.log_metrics(metrics=metrics, step=step) @abstractmethod def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): @@ -221,7 +158,7 @@ class LightningLoggerBase(ABC): Args: params: :class:`~argparse.Namespace` containing the hyperparameters args: Optional positional arguments, depends on the specific logger being used - kwargs: Optional keywoard arguments, depends on the specific logger being used + kwargs: Optional keyword arguments, depends on the specific logger being used """ def log_graph(self, model: "pl.LightningModule", input_array=None) -> None: @@ -235,7 +172,6 @@ class LightningLoggerBase(ABC): def save(self) -> None: """Save log data.""" - self._finalize_agg_metrics() def finalize(self, status: str) -> None: """Do any processing that is necessary to finalize an experiment.