Expose deprecated arguments from logger base interface (#12609)
This commit is contained in:
parent
6d682101a9
commit
cd276fdb6d
|
@ -76,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Avoid calling `average_parameters` multiple times per optimizer step ([#12452](https://github.com/PyTorchLightning/pytorch-lightning/pull/12452))
|
||||
|
||||
|
||||
-
|
||||
- Properly pass some Logger's parent's arguments to `super().__init__()` ([#12609](https://github.com/PyTorchLightning/pytorch-lightning/pull/12609))
|
||||
|
||||
|
||||
-
|
||||
|
|
|
@ -19,7 +19,7 @@ Comet Logger
|
|||
import logging
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from torch import is_tensor
|
||||
|
@ -140,13 +140,15 @@ class CometLogger(LightningLoggerBase):
|
|||
experiment_key: Optional[str] = None,
|
||||
offline: bool = False,
|
||||
prefix: str = "",
|
||||
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
|
||||
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if comet_ml is None:
|
||||
raise ModuleNotFoundError(
|
||||
"You want to use `comet_ml` logger which is not installed yet, install it with `pip install comet-ml`."
|
||||
)
|
||||
super().__init__()
|
||||
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
|
||||
self._experiment = None
|
||||
|
||||
# Determine online or offline mode based on which arguments were passed to CometLogger
|
||||
|
|
|
@ -24,7 +24,7 @@ import os
|
|||
import warnings
|
||||
from argparse import Namespace
|
||||
from functools import reduce
|
||||
from typing import Any, Dict, Generator, Optional, Set, Union
|
||||
from typing import Any, Callable, Dict, Generator, Mapping, Optional, Sequence, Set, Union
|
||||
from weakref import ReferenceType
|
||||
|
||||
import torch
|
||||
|
@ -265,6 +265,8 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
run: Optional["Run"] = None,
|
||||
log_model_checkpoints: Optional[bool] = True,
|
||||
prefix: str = "training",
|
||||
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
|
||||
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
|
||||
**neptune_run_kwargs,
|
||||
):
|
||||
# verify if user passed proper init arguments
|
||||
|
@ -275,7 +277,7 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
" `pip install neptune-client`."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
|
||||
self._log_model_checkpoints = log_model_checkpoints
|
||||
self._prefix = prefix
|
||||
self._run_name = name
|
||||
|
|
|
@ -19,7 +19,7 @@ TensorBoard Logger
|
|||
import logging
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -94,9 +94,11 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
default_hp_metric: bool = True,
|
||||
prefix: str = "",
|
||||
sub_dir: Optional[str] = None,
|
||||
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
|
||||
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
|
||||
self._save_dir = save_dir
|
||||
self._name = name or ""
|
||||
self._version = version
|
||||
|
|
|
@ -18,7 +18,7 @@ Weights and Biases Logger
|
|||
import os
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
|
||||
from weakref import ReferenceType
|
||||
|
||||
import torch.nn as nn
|
||||
|
@ -259,6 +259,8 @@ class WandbLogger(LightningLoggerBase):
|
|||
log_model: Union[str, bool] = False,
|
||||
experiment=None,
|
||||
prefix: Optional[str] = "",
|
||||
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
|
||||
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if wandb is None:
|
||||
|
@ -281,7 +283,7 @@ class WandbLogger(LightningLoggerBase):
|
|||
"Hint: Upgrade with `pip install --upgrade wandb`."
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
|
||||
self._offline = offline
|
||||
self._log_model = log_model
|
||||
self._prefix = prefix
|
||||
|
|
Loading…
Reference in New Issue