Add `@override` for subclasses of PyTorch `Logger` (#18948)
This commit is contained in:
parent
3b05d833cc
commit
b30f1a995d
|
@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
|
|||
from lightning_utilities.core.imports import RequirementCache
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
|
||||
|
@ -302,12 +303,14 @@ class CometLogger(Logger):
|
|||
|
||||
return self._experiment
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
|
||||
params = _convert_params(params)
|
||||
params = _flatten_dict(params)
|
||||
self.experiment.log_parameters(params)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
|
||||
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
|
||||
|
@ -324,6 +327,7 @@ class CometLogger(Logger):
|
|||
def reset_experiment(self) -> None:
|
||||
self._experiment = None
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str) -> None:
|
||||
r"""When calling ``self.experiment.end()``, that experiment won't log any more data to Comet. That's why, if you
|
||||
|
@ -342,6 +346,7 @@ class CometLogger(Logger):
|
|||
self.reset_experiment()
|
||||
|
||||
@property
|
||||
@override
|
||||
def save_dir(self) -> Optional[str]:
|
||||
"""Gets the save directory.
|
||||
|
||||
|
@ -352,6 +357,7 @@ class CometLogger(Logger):
|
|||
return self._save_dir
|
||||
|
||||
@property
|
||||
@override
|
||||
def name(self) -> str:
|
||||
"""Gets the project name.
|
||||
|
||||
|
@ -369,6 +375,7 @@ class CometLogger(Logger):
|
|||
return "comet-default"
|
||||
|
||||
@property
|
||||
@override
|
||||
def version(self) -> str:
|
||||
"""Gets the version.
|
||||
|
||||
|
@ -417,6 +424,7 @@ class CometLogger(Logger):
|
|||
state["_experiment"] = None
|
||||
return state
|
||||
|
||||
@override
|
||||
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
|
||||
if self._experiment is not None:
|
||||
self._experiment.set_model_graph(model)
|
||||
|
|
|
@ -23,6 +23,8 @@ import os
|
|||
from argparse import Namespace
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger
|
||||
from lightning.fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter
|
||||
from lightning.fabric.loggers.logger import rank_zero_experiment
|
||||
|
@ -58,6 +60,7 @@ class ExperimentWriter(_FabricExperimentWriter):
|
|||
"""Record hparams."""
|
||||
self.hparams.update(params)
|
||||
|
||||
@override
|
||||
def save(self) -> None:
|
||||
"""Save recorded hparams and metrics into files."""
|
||||
hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
|
||||
|
@ -106,6 +109,7 @@ class CSVLogger(Logger, FabricCSVLogger):
|
|||
self._save_dir = os.fspath(save_dir)
|
||||
|
||||
@property
|
||||
@override
|
||||
def root_dir(self) -> str:
|
||||
"""Parent directory for all checkpoint subdirectories.
|
||||
|
||||
|
@ -116,6 +120,7 @@ class CSVLogger(Logger, FabricCSVLogger):
|
|||
return os.path.join(self.save_dir, self.name)
|
||||
|
||||
@property
|
||||
@override
|
||||
def log_dir(self) -> str:
|
||||
"""The log directory for this run.
|
||||
|
||||
|
@ -128,6 +133,7 @@ class CSVLogger(Logger, FabricCSVLogger):
|
|||
return os.path.join(self.root_dir, version)
|
||||
|
||||
@property
|
||||
@override
|
||||
def save_dir(self) -> str:
|
||||
"""The current directory where logs are saved.
|
||||
|
||||
|
@ -137,12 +143,14 @@ class CSVLogger(Logger, FabricCSVLogger):
|
|||
"""
|
||||
return self._save_dir
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
|
||||
params = _convert_params(params)
|
||||
self.experiment.log_hparams(params)
|
||||
|
||||
@property
|
||||
@override
|
||||
@rank_zero_experiment
|
||||
def experiment(self) -> _FabricExperimentWriter:
|
||||
r"""Actual _ExperimentWriter object. To use _ExperimentWriter features in your
|
||||
|
|
|
@ -21,6 +21,7 @@ from collections import defaultdict
|
|||
from typing import Any, Callable, Dict, Mapping, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.loggers import Logger as FabricLogger
|
||||
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment # for backward compatibility
|
||||
|
@ -63,18 +64,22 @@ class DummyLogger(Logger):
|
|||
"""Return the experiment object associated with this logger."""
|
||||
return self._experiment
|
||||
|
||||
@override
|
||||
def log_metrics(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
@override
|
||||
def log_hyperparams(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
@override
|
||||
def name(self) -> str:
|
||||
"""Return the experiment name."""
|
||||
return ""
|
||||
|
||||
@property
|
||||
@override
|
||||
def version(self) -> str:
|
||||
"""Return the experiment version."""
|
||||
return ""
|
||||
|
|
|
@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, O
|
|||
import yaml
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
from torch import Tensor
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
|
||||
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
|
||||
|
@ -217,6 +218,7 @@ class MLFlowLogger(Logger):
|
|||
_ = self.experiment
|
||||
return self._experiment_id
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
|
||||
params = _convert_params(params)
|
||||
|
@ -232,6 +234,7 @@ class MLFlowLogger(Logger):
|
|||
for idx in range(0, len(params_list), 100):
|
||||
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
|
||||
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
|
||||
|
@ -259,6 +262,7 @@ class MLFlowLogger(Logger):
|
|||
|
||||
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str = "success") -> None:
|
||||
if not self._initialized:
|
||||
|
@ -278,6 +282,7 @@ class MLFlowLogger(Logger):
|
|||
self.experiment.set_terminated(self.run_id, status)
|
||||
|
||||
@property
|
||||
@override
|
||||
def save_dir(self) -> Optional[str]:
|
||||
"""The root file directory in which MLflow experiments are saved.
|
||||
|
||||
|
@ -291,6 +296,7 @@ class MLFlowLogger(Logger):
|
|||
return None
|
||||
|
||||
@property
|
||||
@override
|
||||
def name(self) -> Optional[str]:
|
||||
"""Get the experiment id.
|
||||
|
||||
|
@ -301,6 +307,7 @@ class MLFlowLogger(Logger):
|
|||
return self.experiment_id
|
||||
|
||||
@property
|
||||
@override
|
||||
def version(self) -> Optional[str]:
|
||||
"""Get the run id.
|
||||
|
||||
|
@ -310,6 +317,7 @@ class MLFlowLogger(Logger):
|
|||
"""
|
||||
return self.run_id
|
||||
|
||||
@override
|
||||
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
|
||||
# log checkpoints as artifacts
|
||||
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
|
||||
|
|
|
@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Uni
|
|||
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
from torch import Tensor
|
||||
from typing_extensions import override
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
|
||||
|
@ -387,6 +388,7 @@ class NeptuneLogger(Logger):
|
|||
|
||||
return self._run_instance
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
|
||||
r"""Log hyperparameters to the run.
|
||||
|
@ -436,6 +438,7 @@ class NeptuneLogger(Logger):
|
|||
|
||||
self.run[parameters_key] = stringify_unsupported(params)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_metrics( # type: ignore[override]
|
||||
self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
|
||||
|
@ -457,6 +460,7 @@ class NeptuneLogger(Logger):
|
|||
# Lightning does not always guarantee.
|
||||
self.run[key].append(val)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str) -> None:
|
||||
if not self._run_instance:
|
||||
|
@ -469,6 +473,7 @@ class NeptuneLogger(Logger):
|
|||
super().finalize(status)
|
||||
|
||||
@property
|
||||
@override
|
||||
def save_dir(self) -> Optional[str]:
|
||||
"""Gets the save directory of the experiment which in this case is ``None`` because Neptune does not save
|
||||
locally.
|
||||
|
@ -491,6 +496,7 @@ class NeptuneLogger(Logger):
|
|||
content=model_str, extension="txt"
|
||||
)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
|
||||
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
|
||||
|
@ -578,11 +584,13 @@ class NeptuneLogger(Logger):
|
|||
yield from cls._dict_paths(v, path)
|
||||
|
||||
@property
|
||||
@override
|
||||
def name(self) -> Optional[str]:
|
||||
"""Return the experiment name or 'offline-name' when exp is run in offline mode."""
|
||||
return self._run_name
|
||||
|
||||
@property
|
||||
@override
|
||||
def version(self) -> Optional[str]:
|
||||
"""Return the experiment version.
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ from argparse import Namespace
|
|||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from torch import Tensor
|
||||
from typing_extensions import override
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
|
||||
|
@ -112,6 +113,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
|
|||
self.hparams: Union[Dict[str, Any], Namespace] = {}
|
||||
|
||||
@property
|
||||
@override
|
||||
def root_dir(self) -> str:
|
||||
"""Parent directory for all tensorboard checkpoint subdirectories.
|
||||
|
||||
|
@ -122,6 +124,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
|
|||
return os.path.join(super().root_dir, self.name)
|
||||
|
||||
@property
|
||||
@override
|
||||
def log_dir(self) -> str:
|
||||
"""The directory for this run's tensorboard checkpoint.
|
||||
|
||||
|
@ -139,6 +142,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
|
|||
return log_dir
|
||||
|
||||
@property
|
||||
@override
|
||||
def save_dir(self) -> str:
|
||||
"""Gets the save directory where the TensorBoard experiments are saved.
|
||||
|
||||
|
@ -148,6 +152,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
|
|||
"""
|
||||
return self._root_dir
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_hyperparams( # type: ignore[override]
|
||||
self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None
|
||||
|
@ -174,6 +179,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
|
|||
|
||||
return super().log_hyperparams(params=params, metrics=metrics)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_graph( # type: ignore[override]
|
||||
self, model: "pl.LightningModule", input_array: Optional[Tensor] = None
|
||||
|
@ -200,6 +206,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
|
|||
with pl.core.module._jit_is_scripting():
|
||||
self.experiment.add_graph(model, input_array)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def save(self) -> None:
|
||||
super().save()
|
||||
|
@ -212,6 +219,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
|
|||
if _is_dir(self._fs, dir_path) and not self._fs.isfile(hparams_file):
|
||||
save_hparams_to_yaml(hparams_file, self.hparams)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str) -> None:
|
||||
super().finalize(status)
|
||||
|
@ -219,6 +227,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
|
|||
# saving hparams happens independent of experiment manager
|
||||
self.save()
|
||||
|
||||
@override
|
||||
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
|
||||
"""Called after model checkpoint callback saves a new checkpoint.
|
||||
|
||||
|
@ -228,6 +237,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
|
|||
"""
|
||||
pass
|
||||
|
||||
@override
|
||||
def _get_next_version(self) -> int:
|
||||
root_dir = self.root_dir
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, U
|
|||
import torch.nn as nn
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
from torch import Tensor
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
|
||||
from lightning.fabric.utilities.types import _PATH
|
||||
|
@ -410,12 +411,14 @@ class WandbLogger(Logger):
|
|||
def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None:
|
||||
self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
|
||||
params = _convert_params(params)
|
||||
params = _sanitize_callable_params(params)
|
||||
self.experiment.config.update(params, allow_val_change=True)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
|
||||
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
|
||||
|
@ -483,6 +486,7 @@ class WandbLogger(Logger):
|
|||
self.log_metrics(metrics, step) # type: ignore[arg-type]
|
||||
|
||||
@property
|
||||
@override
|
||||
def save_dir(self) -> Optional[str]:
|
||||
"""Gets the save directory.
|
||||
|
||||
|
@ -493,6 +497,7 @@ class WandbLogger(Logger):
|
|||
return self._save_dir
|
||||
|
||||
@property
|
||||
@override
|
||||
def name(self) -> Optional[str]:
|
||||
"""The project name of this experiment.
|
||||
|
||||
|
@ -504,6 +509,7 @@ class WandbLogger(Logger):
|
|||
return self._project
|
||||
|
||||
@property
|
||||
@override
|
||||
def version(self) -> Optional[str]:
|
||||
"""Gets the id of the experiment.
|
||||
|
||||
|
@ -514,6 +520,7 @@ class WandbLogger(Logger):
|
|||
# don't create an experiment if we don't have one
|
||||
return self._experiment.id if self._experiment else self._id
|
||||
|
||||
@override
|
||||
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
|
||||
# log checkpoints as artifacts
|
||||
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
|
||||
|
@ -565,6 +572,7 @@ class WandbLogger(Logger):
|
|||
"""
|
||||
return self.experiment.use_artifact(artifact, type=artifact_type)
|
||||
|
||||
@override
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str) -> None:
|
||||
if status != "success":
|
||||
|
|
Loading…
Reference in New Issue