Fix mypy errors attributed to `pytorch_lightning.loggers.csv_logs.py` (#13538)
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
This commit is contained in:
parent
6f51932449
commit
0a6dc5239a
|
@ -61,7 +61,6 @@ module = [
|
|||
"pytorch_lightning.distributed.dist",
|
||||
"pytorch_lightning.loggers.base",
|
||||
"pytorch_lightning.loggers.comet",
|
||||
"pytorch_lightning.loggers.csv_logs",
|
||||
"pytorch_lightning.loggers.mlflow",
|
||||
"pytorch_lightning.loggers.neptune",
|
||||
"pytorch_lightning.loggers.tensorboard",
|
||||
|
|
|
@ -22,7 +22,7 @@ import csv
|
|||
import logging
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -49,8 +49,8 @@ class ExperimentWriter:
|
|||
NAME_METRICS_FILE = "metrics.csv"
|
||||
|
||||
def __init__(self, log_dir: str) -> None:
|
||||
self.hparams = {}
|
||||
self.metrics = []
|
||||
self.hparams: Dict[str, Any] = {}
|
||||
self.metrics: List[Dict[str, float]] = []
|
||||
|
||||
self.log_dir = log_dir
|
||||
if os.path.exists(self.log_dir) and os.listdir(self.log_dir):
|
||||
|
@ -69,7 +69,7 @@ class ExperimentWriter:
|
|||
def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None:
|
||||
"""Record metrics."""
|
||||
|
||||
def _handle_value(value):
|
||||
def _handle_value(value: Union[Tensor, Any]) -> Any:
|
||||
if isinstance(value, Tensor):
|
||||
return value.item()
|
||||
return value
|
||||
|
@ -126,7 +126,7 @@ class CSVLogger(Logger):
|
|||
def __init__(
|
||||
self,
|
||||
save_dir: str,
|
||||
name: Optional[str] = "lightning_logs",
|
||||
name: str = "lightning_logs",
|
||||
version: Optional[Union[int, str]] = None,
|
||||
prefix: str = "",
|
||||
flush_logs_every_n_steps: int = 100,
|
||||
|
@ -136,7 +136,7 @@ class CSVLogger(Logger):
|
|||
self._name = name or ""
|
||||
self._version = version
|
||||
self._prefix = prefix
|
||||
self._experiment = None
|
||||
self._experiment: Optional[ExperimentWriter] = None
|
||||
self._flush_logs_every_n_steps = flush_logs_every_n_steps
|
||||
|
||||
@property
|
||||
|
@ -161,7 +161,7 @@ class CSVLogger(Logger):
|
|||
return log_dir
|
||||
|
||||
@property
|
||||
def save_dir(self) -> Optional[str]:
|
||||
def save_dir(self) -> str:
|
||||
"""The current directory where logs are saved.
|
||||
|
||||
Returns:
|
||||
|
@ -169,7 +169,7 @@ class CSVLogger(Logger):
|
|||
"""
|
||||
return self._save_dir
|
||||
|
||||
@property
|
||||
@property # type: ignore[misc]
|
||||
@rank_zero_experiment
|
||||
def experiment(self) -> ExperimentWriter:
|
||||
r"""
|
||||
|
@ -182,7 +182,7 @@ class CSVLogger(Logger):
|
|||
self.logger.experiment.some_experiment_writer_function()
|
||||
|
||||
"""
|
||||
if self._experiment:
|
||||
if self._experiment is not None:
|
||||
return self._experiment
|
||||
|
||||
os.makedirs(self.root_dir, exist_ok=True)
|
||||
|
@ -220,7 +220,7 @@ class CSVLogger(Logger):
|
|||
return self._name
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
def version(self) -> Union[int, str]:
|
||||
"""Gets the version of the experiment.
|
||||
|
||||
Returns:
|
||||
|
@ -230,7 +230,7 @@ class CSVLogger(Logger):
|
|||
self._version = self._get_next_version()
|
||||
return self._version
|
||||
|
||||
def _get_next_version(self):
|
||||
def _get_next_version(self) -> int:
|
||||
root_dir = self.root_dir
|
||||
|
||||
if not os.path.isdir(root_dir):
|
||||
|
|
Loading…
Reference in New Issue