From b30f1a995d58c2a1b5b75e98132276aa0514e850 Mon Sep 17 00:00:00 2001 From: Victor Prins Date: Mon, 6 Nov 2023 16:18:39 +0100 Subject: [PATCH] Add `@override` for subclasses of PyTorch `Logger` (#18948) --- src/lightning/pytorch/loggers/comet.py | 8 ++++++++ src/lightning/pytorch/loggers/csv_logs.py | 8 ++++++++ src/lightning/pytorch/loggers/logger.py | 5 +++++ src/lightning/pytorch/loggers/mlflow.py | 8 ++++++++ src/lightning/pytorch/loggers/neptune.py | 8 ++++++++ src/lightning/pytorch/loggers/tensorboard.py | 10 ++++++++++ src/lightning/pytorch/loggers/wandb.py | 8 ++++++++ 7 files changed, 55 insertions(+) diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 85d17300a2..ae19f579e2 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.nn import Module +from typing_extensions import override from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment @@ -302,12 +303,14 @@ class CometLogger(Logger): return self._experiment + @override @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override] params = _convert_params(params) params = _flatten_dict(params) self.experiment.log_parameters(params) + @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" @@ -324,6 +327,7 @@ class CometLogger(Logger): def reset_experiment(self) -> None: self._experiment = None + @override @rank_zero_only def finalize(self, status: str) -> None: r"""When calling ``self.experiment.end()``, that experiment won't log any more data to Comet. That's why, if you @@ -342,6 +346,7 @@ class CometLogger(Logger): self.reset_experiment() @property + @override def save_dir(self) -> Optional[str]: """Gets the save directory. @@ -352,6 +357,7 @@ class CometLogger(Logger): return self._save_dir @property + @override def name(self) -> str: """Gets the project name. @@ -369,6 +375,7 @@ class CometLogger(Logger): return "comet-default" @property + @override def version(self) -> str: """Gets the version. @@ -417,6 +424,7 @@ class CometLogger(Logger): state["_experiment"] = None return state + @override def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: if self._experiment is not None: self._experiment.set_model_graph(model) diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index 9d5008eb09..16aeb3da2e 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -23,6 +23,8 @@ import os from argparse import Namespace from typing import Any, Dict, Optional, Union +from typing_extensions import override + from lightning.fabric.loggers.csv_logs import CSVLogger as FabricCSVLogger from lightning.fabric.loggers.csv_logs import _ExperimentWriter as _FabricExperimentWriter from lightning.fabric.loggers.logger import rank_zero_experiment @@ -58,6 +60,7 @@ class ExperimentWriter(_FabricExperimentWriter): """Record hparams.""" self.hparams.update(params) + @override def save(self) -> None: """Save recorded hparams and metrics into files.""" hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) @@ -106,6 +109,7 @@ class CSVLogger(Logger, FabricCSVLogger): self._save_dir = os.fspath(save_dir) @property + @override def root_dir(self) -> str: """Parent directory for all checkpoint subdirectories. @@ -116,6 +120,7 @@ class CSVLogger(Logger, FabricCSVLogger): return os.path.join(self.save_dir, self.name) @property + @override def log_dir(self) -> str: """The log directory for this run. @@ -128,6 +133,7 @@ class CSVLogger(Logger, FabricCSVLogger): return os.path.join(self.root_dir, version) @property + @override def save_dir(self) -> str: """The current directory where logs are saved. @@ -137,12 +143,14 @@ class CSVLogger(Logger, FabricCSVLogger): """ return self._save_dir + @override @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override] params = _convert_params(params) self.experiment.log_hparams(params) @property + @override @rank_zero_experiment def experiment(self) -> _FabricExperimentWriter: r"""Actual _ExperimentWriter object. To use _ExperimentWriter features in your diff --git a/src/lightning/pytorch/loggers/logger.py b/src/lightning/pytorch/loggers/logger.py index 59ff16ac99..c7dc5cf41d 100644 --- a/src/lightning/pytorch/loggers/logger.py +++ b/src/lightning/pytorch/loggers/logger.py @@ -21,6 +21,7 @@ from collections import defaultdict from typing import Any, Callable, Dict, Mapping, Optional, Sequence import numpy as np +from typing_extensions import override from lightning.fabric.loggers import Logger as FabricLogger from lightning.fabric.loggers.logger import _DummyExperiment as DummyExperiment # for backward compatibility @@ -63,18 +64,22 @@ class DummyLogger(Logger): """Return the experiment object associated with this logger.""" return self._experiment + @override def log_metrics(self, *args: Any, **kwargs: Any) -> None: pass + @override def log_hyperparams(self, *args: Any, **kwargs: Any) -> None: pass @property + @override def name(self) -> str: """Return the experiment name.""" return "" @property + @override def version(self) -> str: """Return the experiment version.""" return "" diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index 917bfeeba0..9303781d1a 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping, O import yaml from lightning_utilities.core.imports import RequirementCache from torch import Tensor +from typing_extensions import override from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint @@ -217,6 +218,7 @@ class MLFlowLogger(Logger): _ = self.experiment return self._experiment_id + @override @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override] params = _convert_params(params) @@ -232,6 +234,7 @@ class MLFlowLogger(Logger): for idx in range(0, len(params_list), 100): self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100]) + @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" @@ -259,6 +262,7 @@ class MLFlowLogger(Logger): self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list) + @override @rank_zero_only def finalize(self, status: str = "success") -> None: if not self._initialized: @@ -278,6 +282,7 @@ class MLFlowLogger(Logger): self.experiment.set_terminated(self.run_id, status) @property + @override def save_dir(self) -> Optional[str]: """The root file directory in which MLflow experiments are saved. @@ -291,6 +296,7 @@ class MLFlowLogger(Logger): return None @property + @override def name(self) -> Optional[str]: """Get the experiment id. @@ -301,6 +307,7 @@ class MLFlowLogger(Logger): return self.experiment_id @property + @override def version(self) -> Optional[str]: """Get the run id. @@ -310,6 +317,7 @@ class MLFlowLogger(Logger): """ return self.run_id + @override def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: # log checkpoints as artifacts if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index 9ce09d7010..d8f8b36251 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Set, Uni from lightning_utilities.core.imports import RequirementCache from torch import Tensor +from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params @@ -387,6 +388,7 @@ class NeptuneLogger(Logger): return self._run_instance + @override @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override] r"""Log hyperparameters to the run. @@ -436,6 +438,7 @@ class NeptuneLogger(Logger): self.run[parameters_key] = stringify_unsupported(params) + @override @rank_zero_only def log_metrics( # type: ignore[override] self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None @@ -457,6 +460,7 @@ class NeptuneLogger(Logger): # Lightning does not always guarantee. self.run[key].append(val) + @override @rank_zero_only def finalize(self, status: str) -> None: if not self._run_instance: @@ -469,6 +473,7 @@ class NeptuneLogger(Logger): super().finalize(status) @property + @override def save_dir(self) -> Optional[str]: """Gets the save directory of the experiment which in this case is ``None`` because Neptune does not save locally. @@ -491,6 +496,7 @@ class NeptuneLogger(Logger): content=model_str, extension="txt" ) + @override @rank_zero_only def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint. @@ -578,11 +584,13 @@ class NeptuneLogger(Logger): yield from cls._dict_paths(v, path) @property + @override def name(self) -> Optional[str]: """Return the experiment name or 'offline-name' when exp is run in offline mode.""" return self._run_name @property + @override def version(self) -> Optional[str]: """Return the experiment version. diff --git a/src/lightning/pytorch/loggers/tensorboard.py b/src/lightning/pytorch/loggers/tensorboard.py index 3cda6a50bd..9755074132 100644 --- a/src/lightning/pytorch/loggers/tensorboard.py +++ b/src/lightning/pytorch/loggers/tensorboard.py @@ -22,6 +22,7 @@ from argparse import Namespace from typing import Any, Dict, Optional, Union from torch import Tensor +from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE @@ -112,6 +113,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): self.hparams: Union[Dict[str, Any], Namespace] = {} @property + @override def root_dir(self) -> str: """Parent directory for all tensorboard checkpoint subdirectories. @@ -122,6 +124,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): return os.path.join(super().root_dir, self.name) @property + @override def log_dir(self) -> str: """The directory for this run's tensorboard checkpoint. @@ -139,6 +142,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): return log_dir @property + @override def save_dir(self) -> str: """Gets the save directory where the TensorBoard experiments are saved. @@ -148,6 +152,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): """ return self._root_dir + @override @rank_zero_only def log_hyperparams( # type: ignore[override] self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None @@ -174,6 +179,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): return super().log_hyperparams(params=params, metrics=metrics) + @override @rank_zero_only def log_graph( # type: ignore[override] self, model: "pl.LightningModule", input_array: Optional[Tensor] = None @@ -200,6 +206,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): with pl.core.module._jit_is_scripting(): self.experiment.add_graph(model, input_array) + @override @rank_zero_only def save(self) -> None: super().save() @@ -212,6 +219,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): if _is_dir(self._fs, dir_path) and not self._fs.isfile(hparams_file): save_hparams_to_yaml(hparams_file, self.hparams) + @override @rank_zero_only def finalize(self, status: str) -> None: super().finalize(status) @@ -219,6 +227,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): # saving hparams happens independent of experiment manager self.save() + @override def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: """Called after model checkpoint callback saves a new checkpoint. @@ -228,6 +237,7 @@ class TensorBoardLogger(Logger, FabricTensorBoardLogger): """ pass + @override def _get_next_version(self) -> int: root_dir = self.root_dir diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 1d0bfd5926..8cf5772b74 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, U import torch.nn as nn from lightning_utilities.core.imports import RequirementCache from torch import Tensor +from typing_extensions import override from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params from lightning.fabric.utilities.types import _PATH @@ -410,12 +411,14 @@ class WandbLogger(Logger): def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None: self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph) + @override @rank_zero_only def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override] params = _convert_params(params) params = _sanitize_callable_params(params) self.experiment.config.update(params, allow_val_change=True) + @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" @@ -483,6 +486,7 @@ class WandbLogger(Logger): self.log_metrics(metrics, step) # type: ignore[arg-type] @property + @override def save_dir(self) -> Optional[str]: """Gets the save directory. @@ -493,6 +497,7 @@ class WandbLogger(Logger): return self._save_dir @property + @override def name(self) -> Optional[str]: """The project name of this experiment. @@ -504,6 +509,7 @@ class WandbLogger(Logger): return self._project @property + @override def version(self) -> Optional[str]: """Gets the id of the experiment. @@ -514,6 +520,7 @@ class WandbLogger(Logger): # don't create an experiment if we don't have one return self._experiment.id if self._experiment else self._id + @override def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: # log checkpoints as artifacts if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: @@ -565,6 +572,7 @@ class WandbLogger(Logger): """ return self.experiment.use_artifact(artifact, type=artifact_type) + @override @rank_zero_only def finalize(self, status: str) -> None: if status != "success":