diff --git a/CHANGELOG.md b/CHANGELOG.md index c2b5ca5c7a..e818af4675 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Avoid calling `average_parameters` multiple times per optimizer step ([#12452](https://github.com/PyTorchLightning/pytorch-lightning/pull/12452)) -- +- Properly pass some Logger's parent's arguments to `super().__init__()` ([#12609](https://github.com/PyTorchLightning/pytorch-lightning/pull/12609)) - diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index a4fca47923..bbd870a15e 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -19,7 +19,7 @@ Comet Logger import logging import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union import torch from torch import is_tensor @@ -140,13 +140,15 @@ class CometLogger(LightningLoggerBase): experiment_key: Optional[str] = None, offline: bool = False, prefix: str = "", + agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, + agg_default_func: Optional[Callable[[Sequence[float]], float]] = None, **kwargs, ): if comet_ml is None: raise ModuleNotFoundError( "You want to use `comet_ml` logger which is not installed yet, install it with `pip install comet-ml`." ) - super().__init__() + super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func) self._experiment = None # Determine online or offline mode based on which arguments were passed to CometLogger diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 0544366151..05952f6bc5 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -24,7 +24,7 @@ import os import warnings from argparse import Namespace from functools import reduce -from typing import Any, Dict, Generator, Optional, Set, Union +from typing import Any, Callable, Dict, Generator, Mapping, Optional, Sequence, Set, Union from weakref import ReferenceType import torch @@ -265,6 +265,8 @@ class NeptuneLogger(LightningLoggerBase): run: Optional["Run"] = None, log_model_checkpoints: Optional[bool] = True, prefix: str = "training", + agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, + agg_default_func: Optional[Callable[[Sequence[float]], float]] = None, **neptune_run_kwargs, ): # verify if user passed proper init arguments @@ -275,7 +277,7 @@ class NeptuneLogger(LightningLoggerBase): " `pip install neptune-client`." ) - super().__init__() + super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func) self._log_model_checkpoints = log_model_checkpoints self._prefix = prefix self._run_name = name diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index fcd3500360..f2bca67372 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -19,7 +19,7 @@ TensorBoard Logger import logging import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union import numpy as np import torch @@ -94,9 +94,11 @@ class TensorBoardLogger(LightningLoggerBase): default_hp_metric: bool = True, prefix: str = "", sub_dir: Optional[str] = None, + agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, + agg_default_func: Optional[Callable[[Sequence[float]], float]] = None, **kwargs, ): - super().__init__() + super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func) self._save_dir = save_dir self._name = name or "" self._version = version diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 2a599e54bc..a9809c4449 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -18,7 +18,7 @@ Weights and Biases Logger import os from argparse import Namespace from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union from weakref import ReferenceType import torch.nn as nn @@ -259,6 +259,8 @@ class WandbLogger(LightningLoggerBase): log_model: Union[str, bool] = False, experiment=None, prefix: Optional[str] = "", + agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, + agg_default_func: Optional[Callable[[Sequence[float]], float]] = None, **kwargs, ): if wandb is None: @@ -281,7 +283,7 @@ class WandbLogger(LightningLoggerBase): "Hint: Upgrade with `pip install --upgrade wandb`." ) - super().__init__() + super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func) self._offline = offline self._log_model = log_model self._prefix = prefix