Add `@override` for subclasses of PyTorch `Logger` (#18948)

Victor Prins authored on 2023-11-06 16:18:39 +01:00, committed by GitHub
parent 3b05d833cc · commit b30f1a995d
7 changed files with 55 additions and 0 deletions
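
For context: `typing_extensions.override` (PEP 698; available as `typing.override` on Python 3.12+) marks a method as an intentional override of a base-class method. At runtime it only sets an `__override__` attribute on the decorated function, but a type checker will flag any decorated method that no longer matches a base-class name, e.g. after a rename in `Logger`. A minimal sketch of the behavior this commit relies on (the `BaseLogger`/`MyLogger` names are illustrative, not from the Lightning codebase):

```python
from typing_extensions import override


class BaseLogger:
    def log_metrics(self, metrics: dict) -> None:
        ...


class MyLogger(BaseLogger):
    @override  # OK: BaseLogger defines log_metrics
    def log_metrics(self, metrics: dict) -> None:
        print(metrics)

    @override  # a type checker flags this: BaseLogger has no "log_metric"
    def log_metric(self, name: str, value: float) -> None:
        print(name, value)
```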

src/lightning/pytorch/loggers/comet.py

@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from torch.nn import Module
+from typing_extensions import override

 from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
 from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
@@ -302,12 +303,14 @@ class CometLogger(Logger):
         return self._experiment

+    @override
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
         params = _convert_params(params)
         params = _flatten_dict(params)
         self.experiment.log_parameters(params)

+    @override
     @rank_zero_only
     def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
@@ -324,6 +327,7 @@ class CometLogger(Logger):
     def reset_experiment(self) -> None:
         self._experiment = None

+    @override
     @rank_zero_only
     def finalize(self, status: str) -> None:
         r"""When calling ``self.experiment.end()``, that experiment won't log any more data to Comet. That's why, if you
@@ -342,6 +346,7 @@ class CometLogger(Logger):
         self.reset_experiment()

     @property
+    @override
     def save_dir(self) -> Optional[str]:
         """Gets the save directory.
@@ -352,6 +357,7 @@ class CometLogger(Logger):
         return self._save_dir

     @property
+    @override
     def name(self) -> str:
         """Gets the project name.
@@ -369,6 +375,7 @@ class CometLogger(Logger):
         return "comet-default"

     @property
+    @override
     def version(self) -> str:
         """Gets the version.
@@ -417,6 +424,7 @@ class CometLogger(Logger):
         state["_experiment"] = None
         return state

+    @override
     def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
         if self._experiment is not None:
             self._experiment.set_model_graph(model)

src/lightning/pytorch/loggers/csv_logs.py

@@ -23,6 +23,8 @@ import os
 from argparse import Namespace
 from typing import Any, Dict, Optional, Union

+from typing_extensions import override
+
 from lightning.fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger
 from lightning.fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter
 from lightning.fabric.loggers.logger import rank_zero_experiment
@@ -58,6 +60,7 @@ class ExperimentWriter(_FabricExperimentWriter):
         """Record hparams."""
         self.hparams.update(params)

+    @override
     def save(self) -> None:
         """Save recorded hparams and metrics into files."""
         hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
@@ -106,6 +109,7 @@ class CSVLogger(Logger, FabricCSVLogger):
         self._save_dir = os.fspath(save_dir)

     @property
+    @override
     def root_dir(self) -> str:
         """Parent directory for all checkpoint subdirectories.
@@ -116,6 +120,7 @@ class CSVLogger(Logger, FabricCSVLogger):
         return os.path.join(self.save_dir, self.name)

     @property
+    @override
     def log_dir(self) -> str:
         """The log directory for this run.
@@ -128,6 +133,7 @@ class CSVLogger(Logger, FabricCSVLogger):
         return os.path.join(self.root_dir, version)

     @property
+    @override
     def save_dir(self) -> str:
         """The current directory where logs are saved.
@@ -137,12 +143,14 @@ class CSVLogger(Logger, FabricCSVLogger):
         """
         return self._save_dir

+    @override
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
         params = _convert_params(params)
         self.experiment.log_hparams(params)

     @property
+    @override
     @rank_zero_experiment
     def experiment(self) -> _FabricExperimentWriter:
         r"""Actual _ExperimentWriter object. To use _ExperimentWriter features in your

src/lightning/pytorch/loggers/logger.py

@@ -21,6 +21,7 @@ from collections import defaultdict
 from typing import Any, Callable, Dict, Mapping, Optional, Sequence

 import numpy as np
+from typing_extensions import override

 from lightning.fabric.loggers import Logger as FabricLogger
 from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment  # for backward compatibility
@@ -63,18 +64,22 @@ class DummyLogger(Logger):
         """Return the experiment object associated with this logger."""
         return self._experiment

+    @override
     def log_metrics(self, *args: Any, **kwargs: Any) -> None:
         pass

+    @override
     def log_hyperparams(self, *args: Any, **kwargs: Any) -> None:
         pass

     @property
+    @override
     def name(self) -> str:
         """Return the experiment name."""
         return ""

     @property
+    @override
     def version(self) -> str:
         """Return the experiment version."""
         return ""

src/lightning/pytorch/loggers/mlflow.py

@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, O
 import yaml
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
+from typing_extensions import override

 from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
@@ -217,6 +218,7 @@ class MLFlowLogger(Logger):
         _ = self.experiment
         return self._experiment_id

+    @override
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
         params = _convert_params(params)
@@ -232,6 +234,7 @@ class MLFlowLogger(Logger):
         for idx in range(0, len(params_list), 100):
             self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])

+    @override
     @rank_zero_only
     def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
@@ -259,6 +262,7 @@ class MLFlowLogger(Logger):
         self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)

+    @override
     @rank_zero_only
     def finalize(self, status: str = "success") -> None:
         if not self._initialized:
@@ -278,6 +282,7 @@ class MLFlowLogger(Logger):
         self.experiment.set_terminated(self.run_id, status)

     @property
+    @override
     def save_dir(self) -> Optional[str]:
         """The root file directory in which MLflow experiments are saved.
@@ -291,6 +296,7 @@ class MLFlowLogger(Logger):
         return None

     @property
+    @override
     def name(self) -> Optional[str]:
         """Get the experiment id.
@@ -301,6 +307,7 @@ class MLFlowLogger(Logger):
         return self.experiment_id

     @property
+    @override
     def version(self) -> Optional[str]:
         """Get the run id.
@@ -310,6 +317,7 @@ class MLFlowLogger(Logger):
         """
         return self.run_id

+    @override
     def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         # log checkpoints as artifacts
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:

src/lightning/pytorch/loggers/neptune.py

@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Uni
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
+from typing_extensions import override

 import lightning.pytorch as pl
 from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
@@ -387,6 +388,7 @@ class NeptuneLogger(Logger):
         return self._run_instance

+    @override
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
         r"""Log hyperparameters to the run.
@@ -436,6 +438,7 @@ class NeptuneLogger(Logger):
         self.run[parameters_key] = stringify_unsupported(params)

+    @override
     @rank_zero_only
     def log_metrics(  # type: ignore[override]
         self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None
@@ -457,6 +460,7 @@ class NeptuneLogger(Logger):
             # Lightning does not always guarantee.
             self.run[key].append(val)

+    @override
     @rank_zero_only
     def finalize(self, status: str) -> None:
         if not self._run_instance:
@@ -469,6 +473,7 @@ class NeptuneLogger(Logger):
         super().finalize(status)

     @property
+    @override
     def save_dir(self) -> Optional[str]:
         """Gets the save directory of the experiment which in this case is ``None`` because Neptune does not save
         locally.
@@ -491,6 +496,7 @@ class NeptuneLogger(Logger):
                 content=model_str, extension="txt"
             )

+    @override
     @rank_zero_only
     def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
         """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
@@ -578,11 +584,13 @@ class NeptuneLogger(Logger):
                 yield from cls._dict_paths(v, path)

     @property
+    @override
     def name(self) -> Optional[str]:
         """Return the experiment name or 'offline-name' when exp is run in offline mode."""
         return self._run_name

     @property
+    @override
     def version(self) -> Optional[str]:
         """Return the experiment version.

src/lightning/pytorch/loggers/tensorboard.py

@@ -22,6 +22,7 @@ from argparse import Namespace
 from typing import Any, Dict, Optional, Union

 from torch import Tensor
+from typing_extensions import override

 import lightning.pytorch as pl
 from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
@@ -112,6 +113,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
         self.hparams: Union[Dict[str, Any], Namespace] = {}

     @property
+    @override
     def root_dir(self) -> str:
         """Parent directory for all tensorboard checkpoint subdirectories.
@@ -122,6 +124,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
         return os.path.join(super().root_dir, self.name)

     @property
+    @override
     def log_dir(self) -> str:
         """The directory for this run's tensorboard checkpoint.
@@ -139,6 +142,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
         return log_dir

     @property
+    @override
     def save_dir(self) -> str:
         """Gets the save directory where the TensorBoard experiments are saved.
@@ -148,6 +152,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
         """
         return self._root_dir

+    @override
     @rank_zero_only
     def log_hyperparams(  # type: ignore[override]
         self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None
@@ -174,6 +179,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
         return super().log_hyperparams(params=params, metrics=metrics)

+    @override
     @rank_zero_only
     def log_graph(  # type: ignore[override]
         self, model: "pl.LightningModule", input_array: Optional[Tensor] = None
@@ -200,6 +206,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
             with pl.core.module._jit_is_scripting():
                 self.experiment.add_graph(model, input_array)

+    @override
     @rank_zero_only
     def save(self) -> None:
         super().save()
@@ -212,6 +219,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
         if _is_dir(self._fs, dir_path) and not self._fs.isfile(hparams_file):
             save_hparams_to_yaml(hparams_file, self.hparams)

+    @override
     @rank_zero_only
     def finalize(self, status: str) -> None:
         super().finalize(status)
@@ -219,6 +227,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
         # saving hparams happens independent of experiment manager
         self.save()

+    @override
     def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         """Called after model checkpoint callback saves a new checkpoint.
@@ -228,6 +237,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger):
         """
         pass

+    @override
     def _get_next_version(self) -> int:
         root_dir = self.root_dir

src/lightning/pytorch/loggers/wandb.py

@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, U
 import torch.nn as nn
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
+from typing_extensions import override

 from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
 from lightning.fabric.utilities.types import _PATH
@@ -410,12 +411,14 @@ class WandbLogger(Logger):
     def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None:
         self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)

+    @override
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
         params = _convert_params(params)
         params = _sanitize_callable_params(params)
         self.experiment.config.update(params, allow_val_change=True)

+    @override
     @rank_zero_only
     def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
@@ -483,6 +486,7 @@ class WandbLogger(Logger):
         self.log_metrics(metrics, step)  # type: ignore[arg-type]

     @property
+    @override
     def save_dir(self) -> Optional[str]:
         """Gets the save directory.
@@ -493,6 +497,7 @@ class WandbLogger(Logger):
         return self._save_dir

     @property
+    @override
     def name(self) -> Optional[str]:
         """The project name of this experiment.
@@ -504,6 +509,7 @@ class WandbLogger(Logger):
         return self._project

     @property
+    @override
     def version(self) -> Optional[str]:
         """Gets the id of the experiment.
@@ -514,6 +520,7 @@ class WandbLogger(Logger):
         # don't create an experiment if we don't have one
         return self._experiment.id if self._experiment else self._id

+    @override
     def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         # log checkpoints as artifacts
         if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
@@ -565,6 +572,7 @@ class WandbLogger(Logger):
         """
         return self.experiment.use_artifact(artifact, type=artifact_type)

+    @override
     @rank_zero_only
     def finalize(self, status: str) -> None:
         if status != "success":
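
A note on decorator ordering, visible throughout the diff: on properties, `@override` is placed below `@property` (closest to the `def`), while on `rank_zero_only`-wrapped methods it sits on top. Both orderings behave the same at runtime, since `override` merely tags its argument with an `__override__` attribute and returns it unchanged. A minimal sketch of the property case (the `Base`/`Child` names are illustrative, not from the Lightning codebase):

```python
from typing_extensions import override


class Base:
    @property
    def name(self) -> str:
        return "base"


class Child(Base):
    @property
    @override  # decorators apply bottom-up: override tags the raw function,
    def name(self) -> str:  # then property wraps the tagged function
        return "child"


assert Child().name == "child"
```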