From 57f5e1858790e9dc2efd56711ad990ebc709dcbc Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 12 Jul 2022 07:45:21 -0400 Subject: [PATCH] Fix mypy errors attributed to `pytorch_lightning.loggers.base.py` (#13494) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rohit Gupta Co-authored-by: Carlos MocholĂ­ --- pyproject.toml | 1 - src/pytorch_lightning/loggers/base.py | 16 ++++++++++++---- src/pytorch_lightning/loggers/logger.py | 12 ++++++------ 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ba18f63aba..c6e3452784 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ module = [ "pytorch_lightning.demos.boring_classes", "pytorch_lightning.demos.mnist_datamodule", "pytorch_lightning.distributed.dist", - "pytorch_lightning.loggers.base", "pytorch_lightning.loggers.comet", "pytorch_lightning.loggers.mlflow", "pytorch_lightning.loggers.neptune", diff --git a/src/pytorch_lightning/loggers/base.py b/src/pytorch_lightning/loggers/base.py index 1da0749e46..43c572e395 100644 --- a/src/pytorch_lightning/loggers/base.py +++ b/src/pytorch_lightning/loggers/base.py @@ -12,16 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable, Dict, Mapping, Optional, Sequence + +import numpy as np + import pytorch_lightning.loggers.logger as logger from pytorch_lightning.utilities.warnings import rank_zero_deprecation -def rank_zero_experiment(*args, **kwargs) -> None: # type: ignore[no-untyped-def] +def rank_zero_experiment(fn: Callable) -> Callable: rank_zero_deprecation( "The `pytorch_lightning.loggers.base.rank_zero_experiment` is deprecated in v1.7" " and will be removed in v1.9. Please use `pytorch_lightning.loggers.logger.rank_zero_experiment` instead." ) - return logger.rank_zero_experiment(*args, **kwargs) + return logger.rank_zero_experiment(fn) class LightningLoggerBase(logger.Logger): @@ -77,9 +81,13 @@ class DummyLogger(logger.DummyLogger): super().__init__(*args, **kwargs) -def merge_dicts(*args, **kwargs) -> None: # type: ignore[no-untyped-def] +def merge_dicts( + dicts: Sequence[Mapping], + agg_key_funcs: Optional[Mapping] = None, + default_func: Callable[[Sequence[float]], float] = np.mean, +) -> Dict: rank_zero_deprecation( "The `pytorch_lightning.loggers.base.merge_dicts` is deprecated in v1.7" " and will be removed in v1.9. Please use `pytorch_lightning.loggers.logger.merge_dicts` instead." ) - return logger.merge_dicts(*args, **kwargs) + return logger.merge_dicts(dicts=dicts, agg_key_funcs=agg_key_funcs, default_func=default_func) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 4113b61627..03d934aa58 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -38,13 +38,13 @@ def rank_zero_experiment(fn: Callable) -> Callable: def experiment(self) -> Union[Any, DummyExperiment]: # type: ignore[no-untyped-def] """ Note: - `self` is a custom logger instance. The loggers typical wrap an `experiment` method - with a @rank_zero_experiment decorator. An exception being `loggers.neptune` wraps - `experiment` and `run` with rank_zero_experiment. + ``self`` is a custom logger instance. The loggers typically wrap an ``experiment`` method + with a ``@rank_zero_experiment`` decorator. An exception is that ``loggers.neptune`` wraps + ``experiment`` and ``run`` with rank_zero_experiment. - Union[Any, DummyExperiment] is used because the wrapped hooks have several returns - types that are specific to the custom logger. The return type can be considered as - Union[return type of logger.experiment, DummyExperiment] + ``Union[Any, DummyExperiment]`` is used because the wrapped hooks have several return + types that are specific to the custom logger. The return type here can be considered as + ``Union[return type of logger.experiment, DummyExperiment]``. """ @rank_zero_only