Expose deprecated arguments from logger base interface (#12609)

This commit is contained in:
Carlos Mocholí 2022-04-06 18:47:35 +02:00 committed by lexierule
parent 6d682101a9
commit cd276fdb6d
5 changed files with 17 additions and 9 deletions

View File

@ -76,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Avoid calling `average_parameters` multiple times per optimizer step ([#12452](https://github.com/PyTorchLightning/pytorch-lightning/pull/12452)) - Avoid calling `average_parameters` multiple times per optimizer step ([#12452](https://github.com/PyTorchLightning/pytorch-lightning/pull/12452))
- - Properly pass some Logger's parent's arguments to `super().__init__()` ([#12609](https://github.com/PyTorchLightning/pytorch-lightning/pull/12609))
- -

View File

@ -19,7 +19,7 @@ Comet Logger
import logging import logging
import os import os
from argparse import Namespace from argparse import Namespace
from typing import Any, Dict, Optional, Union from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
import torch import torch
from torch import is_tensor from torch import is_tensor
@ -140,13 +140,15 @@ class CometLogger(LightningLoggerBase):
experiment_key: Optional[str] = None, experiment_key: Optional[str] = None,
offline: bool = False, offline: bool = False,
prefix: str = "", prefix: str = "",
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**kwargs, **kwargs,
): ):
if comet_ml is None: if comet_ml is None:
raise ModuleNotFoundError( raise ModuleNotFoundError(
"You want to use `comet_ml` logger which is not installed yet, install it with `pip install comet-ml`." "You want to use `comet_ml` logger which is not installed yet, install it with `pip install comet-ml`."
) )
super().__init__() super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
self._experiment = None self._experiment = None
# Determine online or offline mode based on which arguments were passed to CometLogger # Determine online or offline mode based on which arguments were passed to CometLogger

View File

@ -24,7 +24,7 @@ import os
import warnings import warnings
from argparse import Namespace from argparse import Namespace
from functools import reduce from functools import reduce
from typing import Any, Dict, Generator, Optional, Set, Union from typing import Any, Callable, Dict, Generator, Mapping, Optional, Sequence, Set, Union
from weakref import ReferenceType from weakref import ReferenceType
import torch import torch
@ -265,6 +265,8 @@ class NeptuneLogger(LightningLoggerBase):
run: Optional["Run"] = None, run: Optional["Run"] = None,
log_model_checkpoints: Optional[bool] = True, log_model_checkpoints: Optional[bool] = True,
prefix: str = "training", prefix: str = "training",
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**neptune_run_kwargs, **neptune_run_kwargs,
): ):
# verify if user passed proper init arguments # verify if user passed proper init arguments
@ -275,7 +277,7 @@ class NeptuneLogger(LightningLoggerBase):
" `pip install neptune-client`." " `pip install neptune-client`."
) )
super().__init__() super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
self._log_model_checkpoints = log_model_checkpoints self._log_model_checkpoints = log_model_checkpoints
self._prefix = prefix self._prefix = prefix
self._run_name = name self._run_name = name

View File

@ -19,7 +19,7 @@ TensorBoard Logger
import logging import logging
import os import os
from argparse import Namespace from argparse import Namespace
from typing import Any, Dict, Optional, Union from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
import numpy as np import numpy as np
import torch import torch
@ -94,9 +94,11 @@ class TensorBoardLogger(LightningLoggerBase):
default_hp_metric: bool = True, default_hp_metric: bool = True,
prefix: str = "", prefix: str = "",
sub_dir: Optional[str] = None, sub_dir: Optional[str] = None,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**kwargs, **kwargs,
): ):
super().__init__() super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
self._save_dir = save_dir self._save_dir = save_dir
self._name = name or "" self._name = name or ""
self._version = version self._version = version

View File

@ -18,7 +18,7 @@ Weights and Biases Logger
import os import os
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
from weakref import ReferenceType from weakref import ReferenceType
import torch.nn as nn import torch.nn as nn
@ -259,6 +259,8 @@ class WandbLogger(LightningLoggerBase):
log_model: Union[str, bool] = False, log_model: Union[str, bool] = False,
experiment=None, experiment=None,
prefix: Optional[str] = "", prefix: Optional[str] = "",
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
**kwargs, **kwargs,
): ):
if wandb is None: if wandb is None:
@ -281,7 +283,7 @@ class WandbLogger(LightningLoggerBase):
"Hint: Upgrade with `pip install --upgrade wandb`." "Hint: Upgrade with `pip install --upgrade wandb`."
) )
super().__init__() super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
self._offline = offline self._offline = offline
self._log_model = log_model self._log_model = log_model
self._prefix = prefix self._prefix = prefix