Fix mypy errors attributed to `pytorch_lightning.loggers.csv_logs.py` (#13538)

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
This commit is contained in:
Justin Goheen 2022-07-12 07:11:31 -04:00 committed by GitHub
parent 6f51932449
commit 0a6dc5239a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 12 deletions

View File

@ -61,7 +61,6 @@ module = [
"pytorch_lightning.distributed.dist",
"pytorch_lightning.loggers.base",
"pytorch_lightning.loggers.comet",
"pytorch_lightning.loggers.csv_logs",
"pytorch_lightning.loggers.mlflow",
"pytorch_lightning.loggers.neptune",
"pytorch_lightning.loggers.tensorboard",

View File

@ -22,7 +22,7 @@ import csv
import logging
import os
from argparse import Namespace
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
from torch import Tensor
@ -49,8 +49,8 @@ class ExperimentWriter:
NAME_METRICS_FILE = "metrics.csv"
def __init__(self, log_dir: str) -> None:
self.hparams = {}
self.metrics = []
self.hparams: Dict[str, Any] = {}
self.metrics: List[Dict[str, float]] = []
self.log_dir = log_dir
if os.path.exists(self.log_dir) and os.listdir(self.log_dir):
@ -69,7 +69,7 @@ class ExperimentWriter:
def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None:
"""Record metrics."""
def _handle_value(value):
def _handle_value(value: Union[Tensor, Any]) -> Any:
if isinstance(value, Tensor):
return value.item()
return value
@ -126,7 +126,7 @@ class CSVLogger(Logger):
def __init__(
self,
save_dir: str,
name: Optional[str] = "lightning_logs",
name: str = "lightning_logs",
version: Optional[Union[int, str]] = None,
prefix: str = "",
flush_logs_every_n_steps: int = 100,
@ -136,7 +136,7 @@ class CSVLogger(Logger):
self._name = name or ""
self._version = version
self._prefix = prefix
self._experiment = None
self._experiment: Optional[ExperimentWriter] = None
self._flush_logs_every_n_steps = flush_logs_every_n_steps
@property
@ -161,7 +161,7 @@ class CSVLogger(Logger):
return log_dir
@property
def save_dir(self) -> Optional[str]:
def save_dir(self) -> str:
"""The current directory where logs are saved.
Returns:
@ -169,7 +169,7 @@ class CSVLogger(Logger):
"""
return self._save_dir
@property
@property # type: ignore[misc]
@rank_zero_experiment
def experiment(self) -> ExperimentWriter:
r"""
@ -182,7 +182,7 @@ class CSVLogger(Logger):
self.logger.experiment.some_experiment_writer_function()
"""
if self._experiment:
if self._experiment is not None:
return self._experiment
os.makedirs(self.root_dir, exist_ok=True)
@ -220,7 +220,7 @@ class CSVLogger(Logger):
return self._name
@property
def version(self) -> int:
def version(self) -> Union[int, str]:
"""Gets the version of the experiment.
Returns:
@ -230,7 +230,7 @@ class CSVLogger(Logger):
self._version = self._get_next_version()
return self._version
def _get_next_version(self):
def _get_next_version(self) -> int:
root_dir = self.root_dir
if not os.path.isdir(root_dir):