Add `@override` for subclasses of PyTorch `Logger` (#18948)

Victor Prins authored on 2023-11-06 16:18:39 +01:00; committed by GitHub
commit b30f1a995d (parent 3b05d833cc)
7 changed files with 55 additions and 0 deletions
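For context, here is a minimal sketch of the pattern this commit applies across the loggers below. It is not code from the commit itself: DemoLogger and its method bodies are hypothetical, while the imports mirror the ones added in the diffs (`override` from `typing_extensions`, `Logger` from `lightning.pytorch.loggers.logger`, `rank_zero_only` from `lightning.pytorch.utilities`).

from argparse import Namespace
from typing import Any, Dict, Mapping, Optional, Union

from typing_extensions import override  # PEP 698 decorator; essentially a no-op at runtime

from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.utilities import rank_zero_only


class DemoLogger(Logger):
    """Hypothetical Logger subclass showing where @override sits in the decorator stack."""

    @override
    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
        print("hparams:", params)

    @override
    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        print("metrics:", metrics, "step:", step)

    @property
    @override
    def name(self) -> Optional[str]:
        # For properties, @override decorates the getter before @property wraps it,
        # matching the @property / @override ordering used throughout the diff.
        return "demo"

    @property
    @override
    def version(self) -> Optional[str]:
        return "0"

At runtime `override` only tags the decorated function with `__override__ = True`; the benefit is static: a PEP 698-aware type checker (recent mypy or pyright) reports an error when a method marked `@override` does not actually override a base-class member, for example after a rename in `Logger`.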


@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from torch.nn import Module
from typing_extensions import override
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
@@ -302,12 +303,14 @@ class CometLogger(Logger):
return self._experiment
@override
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
params = _convert_params(params)
params = _flatten_dict(params)
self.experiment.log_parameters(params)
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
@@ -324,6 +327,7 @@ class CometLogger(Logger):
def reset_experiment(self) -> None:
self._experiment = None
@override
@rank_zero_only
def finalize(self, status: str) -> None:
r"""When calling ``self.experiment.end()``, that experiment won't log any more data to Comet. That's why, if you
@@ -342,6 +346,7 @@ class CometLogger(Logger):
self.reset_experiment()
@property
@override
def save_dir(self) -> Optional[str]:
"""Gets the save directory.
@@ -352,6 +357,7 @@ class CometLogger(Logger):
return self._save_dir
@property
@override
def name(self) -> str:
"""Gets the project name.
@@ -369,6 +375,7 @@ class CometLogger(Logger):
return "comet-default"
@property
@override
def version(self) -> str:
"""Gets the version.
@@ -417,6 +424,7 @@ class CometLogger(Logger):
state["_experiment"] = None
return state
@override
def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
if self._experiment is not None:
self._experiment.set_model_graph(model)


@@ -23,6 +23,8 @@ import os
from argparse import Namespace
from typing import Any, Dict, Optional, Union
from typing_extensions import override
from lightning.fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger
from lightning.fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter
from lightning.fabric.loggers.logger import rank_zero_experiment
@@ -58,6 +60,7 @@ class ExperimentWriter(_FabricExperimentWriter):
"""Record hparams."""
self.hparams.update(params)
@override
def save(self) -> None:
"""Save recorded hparams and metrics into files."""
hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
@@ -106,6 +109,7 @@ class CSVLogger(Logger, FabricCSVLogger):
self._save_dir = os.fspath(save_dir)
@property
@override
def root_dir(self) -> str:
"""Parent directory for all checkpoint subdirectories.
@@ -116,6 +120,7 @@ class CSVLogger(Logger, FabricCSVLogger):
return os.path.join(self.save_dir, self.name)
@property
@override
def log_dir(self) -> str:
"""The log directory for this run.
@@ -128,6 +133,7 @@ class CSVLogger(Logger, FabricCSVLogger):
return os.path.join(self.root_dir, version)
@property
@override
def save_dir(self) -> str:
"""The current directory where logs are saved.
@@ -137,12 +143,14 @@
"""
return self._save_dir
@override
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
params = _convert_params(params)
self.experiment.log_hparams(params)
@property
@override
@rank_zero_experiment
def experiment(self) -> _FabricExperimentWriter:
r"""Actual _ExperimentWriter object. To use _ExperimentWriter features in your


@@ -21,6 +21,7 @@ from collections import defaultdict
from typing import Any, Callable, Dict, Mapping, Optional, Sequence
import numpy as np
from typing_extensions import override
from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment # for backward compatibility
@@ -63,18 +64,22 @@ class DummyLogger(Logger):
"""Return the experiment object associated with this logger."""
return self._experiment
@override
def log_metrics(self, *args: Any, **kwargs: Any) -> None:
pass
@override
def log_hyperparams(self, *args: Any, **kwargs: Any) -> None:
pass
@property
@override
def name(self) -> str:
"""Return the experiment name."""
return ""
@property
@override
def version(self) -> str:
"""Return the experiment version."""
return ""


@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, O
import yaml
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import override
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
@@ -217,6 +218,7 @@ class MLFlowLogger(Logger):
_ = self.experiment
return self._experiment_id
@override
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
params = _convert_params(params)
@@ -232,6 +234,7 @@ class MLFlowLogger(Logger):
for idx in range(0, len(params_list), 100):
self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
@@ -259,6 +262,7 @@ class MLFlowLogger(Logger):
self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)
@override
@rank_zero_only
def finalize(self, status: str = "success") -> None:
if not self._initialized:
@@ -278,6 +282,7 @@ class MLFlowLogger(Logger):
self.experiment.set_terminated(self.run_id, status)
@property
@override
def save_dir(self) -> Optional[str]:
"""The root file directory in which MLflow experiments are saved.
@@ -291,6 +296,7 @@ class MLFlowLogger(Logger):
return None
@property
@override
def name(self) -> Optional[str]:
"""Get the experiment id.
@@ -301,6 +307,7 @@ class MLFlowLogger(Logger):
return self.experiment_id
@property
@override
def version(self) -> Optional[str]:
"""Get the run id.
@@ -310,6 +317,7 @@ class MLFlowLogger(Logger):
"""
return self.run_id
@override
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
# log checkpoints as artifacts
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:


@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Uni
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import override
import lightning.pytorch as pl
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
@@ -387,6 +388,7 @@ class NeptuneLogger(Logger):
return self._run_instance
@override
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
r"""Log hyperparameters to the run.
@@ -436,6 +438,7 @@ class NeptuneLogger(Logger):
self.run[parameters_key] = stringify_unsupported(params)
@override
@rank_zero_only
def log_metrics( # type: ignore[override]
self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
@@ -457,6 +460,7 @@ class NeptuneLogger(Logger):
# Lightning does not always guarantee.
self.run[key].append(val)
@override
@rank_zero_only
def finalize(self, status: str) -> None:
if not self._run_instance:
@@ -469,6 +473,7 @@ class NeptuneLogger(Logger):
super().finalize(status)
@property
@override
def save_dir(self) -> Optional[str]:
"""Gets the save directory of the experiment which in this case is ``None`` because Neptune does not save
locally.
@@ -491,6 +496,7 @@ class NeptuneLogger(Logger):
content=model_str, extension="txt"
)
@override
@rank_zero_only
def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
@@ -578,11 +584,13 @@ class NeptuneLogger(Logger):
yield from cls._dict_paths(v, path)
@property
@override
def name(self) -> Optional[str]:
"""Return the experiment name or 'offline-name' when exp is run in offline mode."""
return self._run_name
@property
@override
def version(self) -> Optional[str]:
"""Return the experiment version.


@@ -22,6 +22,7 @@ from argparse import Namespace
from typing import Any, Dict, Optional, Union
from torch import Tensor
from typing_extensions import override
import lightning.pytorch as pl
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
@@ -112,6 +113,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
self.hparams: Union[Dict[str, Any], Namespace] = {}
@property
@override
def root_dir(self) -> str:
"""Parent directory for all tensorboard checkpoint subdirectories.
@@ -122,6 +124,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
return os.path.join(super().root_dir, self.name)
@property
@override
def log_dir(self) -> str:
"""The directory for this run's tensorboard checkpoint.
@@ -139,6 +142,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
return log_dir
@property
@override
def save_dir(self) -> str:
"""Gets the save directory where the TensorBoard experiments are saved.
@@ -148,6 +152,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
"""
return self._root_dir
@override
@rank_zero_only
def log_hyperparams( # type: ignore[override]
self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None
@@ -174,6 +179,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
return super().log_hyperparams(params=params, metrics=metrics)
@override
@rank_zero_only
def log_graph( # type: ignore[override]
self, model: "pl.LightningModule", input_array: Optional[Tensor] = None
@@ -200,6 +206,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
with pl.core.module._jit_is_scripting():
self.experiment.add_graph(model, input_array)
@override
@rank_zero_only
def save(self) -> None:
super().save()
@@ -212,6 +219,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
if _is_dir(self._fs, dir_path) and not self._fs.isfile(hparams_file):
save_hparams_to_yaml(hparams_file, self.hparams)
@override
@rank_zero_only
def finalize(self, status: str) -> None:
super().finalize(status)
@@ -219,6 +227,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
# saving hparams happens independent of experiment manager
self.save()
@override
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
"""Called after model checkpoint callback saves a new checkpoint.
@@ -228,6 +237,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
"""
pass
@override
def _get_next_version(self) -> int:
root_dir = self.root_dir


@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, U
import torch.nn as nn
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import override
from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
from lightning.fabric.utilities.types import _PATH
@@ -410,12 +411,14 @@ class WandbLogger(Logger):
def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None:
self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)
@override
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override]
params = _convert_params(params)
params = _sanitize_callable_params(params)
self.experiment.config.update(params, allow_val_change=True)
@override
@rank_zero_only
def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
@@ -483,6 +486,7 @@ class WandbLogger(Logger):
self.log_metrics(metrics, step) # type: ignore[arg-type]
@property
@override
def save_dir(self) -> Optional[str]:
"""Gets the save directory.
@@ -493,6 +497,7 @@ class WandbLogger(Logger):
return self._save_dir
@property
@override
def name(self) -> Optional[str]:
"""The project name of this experiment.
@@ -504,6 +509,7 @@ class WandbLogger(Logger):
return self._project
@property
@override
def version(self) -> Optional[str]:
"""Gets the id of the experiment.
@@ -514,6 +520,7 @@ class WandbLogger(Logger):
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id
@override
def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
# log checkpoints as artifacts
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
@@ -565,6 +572,7 @@ class WandbLogger(Logger):
"""
return self.experiment.use_artifact(artifact, type=artifact_type)
@override
@rank_zero_only
def finalize(self, status: str) -> None:
if status != "success":
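
As a closing illustration of what the added decorators buy, here is a hypothetical mistake the change makes detectable. BrokenLogger and the misspelled method are invented for this example and appear nowhere in the commit.

from typing import Any

from typing_extensions import override

from lightning.pytorch.loggers.logger import Logger


class BrokenLogger(Logger):
    @override
    def log_metricz(self, *args: Any, **kwargs: Any) -> None:
        # Misspelled name: Logger defines log_metrics, not log_metricz, so a
        # PEP 698-aware type checker (recent mypy or pyright) flags this method
        # as not overriding anything; without @override the typo would silently
        # introduce a new method instead of overriding the base one.
        pass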