From 0a6dc5239a67f1fd9176a762b027e0ed82418d3f Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 12 Jul 2022 07:11:31 -0400 Subject: [PATCH] Fix mypy errors attributed to `pytorch_lightning.loggers.csv_logs.py` (#13538) Co-authored-by: Akihiro Nitta --- pyproject.toml | 1 - src/pytorch_lightning/loggers/csv_logs.py | 22 +++++++++++----------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 770f0983c3..ba18f63aba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,6 @@ module = [ "pytorch_lightning.distributed.dist", "pytorch_lightning.loggers.base", "pytorch_lightning.loggers.comet", - "pytorch_lightning.loggers.csv_logs", "pytorch_lightning.loggers.mlflow", "pytorch_lightning.loggers.neptune", "pytorch_lightning.loggers.tensorboard", diff --git a/src/pytorch_lightning/loggers/csv_logs.py b/src/pytorch_lightning/loggers/csv_logs.py index 3316a5e86e..72d21ae2c4 100644 --- a/src/pytorch_lightning/loggers/csv_logs.py +++ b/src/pytorch_lightning/loggers/csv_logs.py @@ -22,7 +22,7 @@ import csv import logging import os from argparse import Namespace -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union from torch import Tensor @@ -49,8 +49,8 @@ class ExperimentWriter: NAME_METRICS_FILE = "metrics.csv" def __init__(self, log_dir: str) -> None: - self.hparams = {} - self.metrics = [] + self.hparams: Dict[str, Any] = {} + self.metrics: List[Dict[str, float]] = [] self.log_dir = log_dir if os.path.exists(self.log_dir) and os.listdir(self.log_dir): @@ -69,7 +69,7 @@ class ExperimentWriter: def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None: """Record metrics.""" - def _handle_value(value): + def _handle_value(value: Union[Tensor, Any]) -> Any: if isinstance(value, Tensor): return value.item() return value @@ -126,7 +126,7 @@ class CSVLogger(Logger): def __init__( self, save_dir: str, - name: Optional[str] = "lightning_logs", + name: str = "lightning_logs", version: Optional[Union[int, str]] = None, prefix: str = "", flush_logs_every_n_steps: int = 100, @@ -136,7 +136,7 @@ class CSVLogger(Logger): self._name = name or "" self._version = version self._prefix = prefix - self._experiment = None + self._experiment: Optional[ExperimentWriter] = None self._flush_logs_every_n_steps = flush_logs_every_n_steps @property @@ -161,7 +161,7 @@ class CSVLogger(Logger): return log_dir @property - def save_dir(self) -> Optional[str]: + def save_dir(self) -> str: """The current directory where logs are saved. Returns: @@ -169,7 +169,7 @@ class CSVLogger(Logger): """ return self._save_dir - @property + @property # type: ignore[misc] @rank_zero_experiment def experiment(self) -> ExperimentWriter: r""" @@ -182,7 +182,7 @@ class CSVLogger(Logger): self.logger.experiment.some_experiment_writer_function() """ - if self._experiment: + if self._experiment is not None: return self._experiment os.makedirs(self.root_dir, exist_ok=True) @@ -220,7 +220,7 @@ class CSVLogger(Logger): return self._name @property - def version(self) -> int: + def version(self) -> Union[int, str]: """Gets the version of the experiment. Returns: @@ -230,7 +230,7 @@ class CSVLogger(Logger): self._version = self._get_next_version() return self._version - def _get_next_version(self): + def _get_next_version(self) -> int: root_dir = self.root_dir if not os.path.isdir(root_dir):