lightning/pytorch_lightning/loggers/base.py

139 lines
3.7 KiB
Python
Raw Normal View History

import argparse
from abc import ABC, abstractmethod
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List
def rank_zero_only(fn: Callable):
"""Decorate a logger method to run it only on the process with rank 0.
Args:
fn: Function to decorate
"""
@wraps(fn)
def wrapped_fn(self, *args, **kwargs):
if self.rank == 0:
fn(self, *args, **kwargs)
return wrapped_fn
class LightningLoggerBase(ABC):
"""Base class for experiment loggers."""
def __init__(self):
self._rank = 0
@property
@abstractmethod
def experiment(self) -> Any:
"""Return the experiment object associated with this logger"""
pass
@abstractmethod
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Record metrics.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
pass
@abstractmethod
def log_hyperparams(self, params: argparse.Namespace):
"""Record hyperparameters.
Args:
params: argparse.Namespace containing the hyperparameters
"""
pass
def save(self):
"""Save log data."""
pass
def finalize(self, status: str):
"""Do any processing that is necessary to finalize an experiment.
Args:
status: Status that the experiment finished with (e.g. success, failed, aborted)
"""
pass
def close(self):
"""Do any cleanup that is necessary to close an experiment."""
pass
@property
def rank(self) -> int:
"""Process rank. In general, metrics should only be logged by the process with rank 0."""
return self._rank
@rank.setter
def rank(self, value: int):
"""Set the process rank."""
self._rank = value
@property
@abstractmethod
def name(self) -> str:
"""Return the experiment name."""
pass
@property
@abstractmethod
def version(self) -> Union[int, str]:
"""Return the experiment version."""
pass
class LoggerCollection(LightningLoggerBase):
"""The `LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`.
Args:
logger_iterable: An iterable collection of loggers
"""
def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
super().__init__()
self._logger_iterable = logger_iterable
@property
def experiment(self) -> List[Any]:
return [logger.experiment() for logger in self._logger_iterable]
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
[logger.log_metrics(metrics, step) for logger in self._logger_iterable]
def log_hyperparams(self, params: argparse.Namespace):
[logger.log_hyperparams(params) for logger in self._logger_iterable]
def save(self):
[logger.save() for logger in self._logger_iterable]
def finalize(self, status: str):
[logger.finalize(status) for logger in self._logger_iterable]
def close(self):
[logger.close() for logger in self._logger_iterable]
@property
def rank(self) -> int:
return self._rank
@rank.setter
def rank(self, value: int):
self._rank = value
for logger in self._logger_iterable:
logger.rank = value
@property
def name(self) -> str:
return '_'.join([str(logger.name) for logger in self._logger_iterable])
@property
def version(self) -> str:
return '_'.join([str(logger.version) for logger in self._logger_iterable])