Refactor codebase to use `trainer.loggers` over `trainer.logger` when needed (#11920)

This commit is contained in:
Akash Kwatra 2022-02-25 16:01:04 -08:00 committed by GitHub
parent 244f365fae
commit 7e2f9fbad5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 184 additions and 114 deletions

View File

@ -206,7 +206,8 @@ class GAN(LightningModule):
# log sampled images
sample_imgs = self(z)
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
for logger in self.loggers:
logger.experiment.add_image("generated_images", grid, self.current_epoch)
def main(args: Namespace) -> None:

View File

@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback):
"""
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")
def on_train_batch_start(
@ -55,7 +55,7 @@ class DeviceStatsMonitor(Callback):
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
if not trainer.logger_connector.should_update_logs:
@ -63,9 +63,10 @@ class DeviceStatsMonitor(Callback):
device = trainer.strategy.root_device
device_stats = trainer.accelerator.get_device_stats(device)
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
def on_train_batch_end(
self,
@ -76,7 +77,7 @@ class DeviceStatsMonitor(Callback):
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
if not trainer.logger_connector.should_update_logs:
@ -84,9 +85,10 @@ class DeviceStatsMonitor(Callback):
device = trainer.strategy.root_device
device_stats = trainer.accelerator.get_device_stats(device)
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:

View File

@ -123,7 +123,7 @@ class GPUStatsMonitor(Callback):
self._gpu_ids: List[str] = [] # will be assigned later in setup()
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")
if trainer.strategy.root_device.type != "cuda":
@ -161,8 +161,8 @@ class GPUStatsMonitor(Callback):
# First log at beginning of second step
logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000
assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)
@rank_zero_only
def on_train_batch_end(
@ -186,8 +186,8 @@ class GPUStatsMonitor(Callback):
if self._log_stats.intra_step_time and self._snap_intra_step_time:
logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000
assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)
@staticmethod
def _get_gpu_ids(device_ids: List[int]) -> List[str]:

View File

@ -104,7 +104,7 @@ class LearningRateMonitor(Callback):
MisconfigurationException:
If ``Trainer`` has no ``logger``.
"""
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException(
"Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
)
@ -149,7 +149,6 @@ class LearningRateMonitor(Callback):
self.last_momentum_values = {name + "-momentum": None for name in names_flatten}
def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
if not trainer.logger_connector.should_update_logs:
return
@ -158,16 +157,17 @@ class LearningRateMonitor(Callback):
latest_stat = self._extract_stats(trainer, interval)
if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)
def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
if self.logging_interval != "step":
interval = "epoch" if self.logging_interval is None else "any"
latest_stat = self._extract_stats(trainer, interval)
if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)
def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
latest_stat = {}

View File

@ -35,6 +35,7 @@ import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
@ -379,8 +380,9 @@ class ModelCheckpoint(Callback):
self._save_last_checkpoint(trainer, monitor_candidates)
# notify loggers
if trainer.is_global_zero and trainer.logger:
trainer.logger.after_save_checkpoint(proxy(self))
if trainer.is_global_zero:
for logger in trainer.loggers:
logger.after_save_checkpoint(proxy(self))
def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn
@ -572,20 +574,20 @@ class ModelCheckpoint(Callback):
"""
if self.dirpath is not None:
return # short circuit
if trainer.logger is not None:
if trainer.loggers:
if trainer.weights_save_path != trainer.default_root_dir:
# the user has changed weights_save_path, it overrides anything
save_dir = trainer.weights_save_path
else:
elif len(trainer.loggers) == 1:
save_dir = trainer.logger.save_dir or trainer.default_root_dir
else:
save_dir = trainer.default_root_dir
version = (
trainer.logger.version
if isinstance(trainer.logger.version, str)
else f"version_{trainer.logger.version}"
)
ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
name = _name(trainer.loggers)
version = _version(trainer.loggers)
version = version if isinstance(version, str) else f"version_{version}"
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

View File

@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Union
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.logger import _version
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@ -213,11 +214,12 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
if pl_module.truncated_bptt_steps > 0:
items_dict["split_idx"] = trainer.fit_loop.split_idx
if trainer.logger is not None and trainer.logger.version is not None:
version = trainer.logger.version
if isinstance(version, str):
# show last 4 places of long version strings
version = version[-4:]
items_dict["v_num"] = version
if trainer.loggers:
version = _version(trainer.loggers)
if version is not None:
if isinstance(version, str):
# show last 4 places of long version strings
version = version[-4:]
items_dict["v_num"] = version
return items_dict

View File

@ -70,7 +70,7 @@ class XLAStatsMonitor(Callback):
self._verbose = verbose
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
if isinstance(trainer.accelerator, TPUAccelerator):
@ -88,7 +88,7 @@ class XLAStatsMonitor(Callback):
self._start_time = time.time()
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
device = trainer.strategy.root_device
@ -102,10 +102,11 @@ class XLAStatsMonitor(Callback):
peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
epoch_time = trainer.strategy.reduce(epoch_time)
trainer.logger.log_metrics(
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
step=trainer.current_epoch,
)
for logger in trainer.loggers:
logger.log_metrics(
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
step=trainer.current_epoch,
)
if self._verbose:
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")

View File

@ -253,7 +253,7 @@ class LightningModule(
@property
def loggers(self) -> List[LightningLoggerBase]:
"""Reference to the loggers object in the Trainer."""
"""Reference to the list of loggers in the Trainer."""
return self.trainer.loggers if self.trainer else []
def _apply_batch_transfer_handler(

View File

@ -504,8 +504,10 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
"""Flushes loggers to disk."""
# when loggers should save to disk
should_flush_logs = self.trainer.logger_connector.should_flush_logs
if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
self.trainer.logger.save()
# TODO: is_global_zero check should be moved to logger.save() implementation
if should_flush_logs and self.trainer.is_global_zero:
for logger in self.trainer.loggers:
logger.save()
def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
if self._dataloader_state_dict:

View File

@ -155,7 +155,11 @@ class PrecisionPlugin(CheckpointHooks):
def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:
return
kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}
kwargs = {}
if len(trainer.loggers) == 1:
kwargs["group_separator"] = trainer.loggers[0].group_separator
grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
if grad_norm_dict:
prev_fx = trainer.lightning_module._current_fx_name

View File

@ -17,7 +17,7 @@ import torch
import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.trainer.states import RunningStage
@ -90,15 +90,15 @@ class LoggerConnector:
def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
if isinstance(logger, bool):
# default logger
self.trainer.logger = (
TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())
self.trainer.loggers = (
[TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())]
if logger
else None
else []
)
elif isinstance(logger, Iterable):
self.trainer.logger = LoggerCollection(logger)
self.trainer.loggers = list(logger)
else:
self.trainer.logger = logger
self.trainer.loggers = [logger]
def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
"""Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses
@ -109,7 +109,7 @@ class LoggerConnector:
step: Step for which metrics should be logged. Default value is `self.global_step` during training or
the total validation / test log step count during validation and testing.
"""
if self.trainer.logger is None or not metrics:
if not self.trainer.loggers or not metrics:
return
self._logged_metrics.update(metrics)
@ -126,11 +126,12 @@ class LoggerConnector:
step = self.trainer.global_step
# log actual metrics
if self._override_agg_and_log_metrics:
self.trainer.logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
else:
self.trainer.logger.log_metrics(metrics=scalar_metrics, step=step)
self.trainer.logger.save()
for logger in self.trainer.loggers:
if self._override_agg_and_log_metrics:
logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
else:
logger.log_metrics(metrics=scalar_metrics, step=step)
logger.save()
"""
Evaluation metric updates

View File

@ -66,8 +66,8 @@ class SignalConnector:
rank_zero_info("handling SIGUSR1")
# save logger to make sure we get all the metrics
if self.trainer.logger:
self.trainer.logger.finalize("finished")
for logger in self.trainer.loggers:
logger.finalize("finished")
hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.weights_save_path)
self.trainer.save_checkpoint(hpc_save_path)

View File

@ -657,7 +657,7 @@ class Trainer(
self.fit_loop.max_epochs = 1
val_check_interval = 1.0
self.check_val_every_n_epoch = 1
self.logger = DummyLogger() if self.logger is not None else None
self.loggers = [DummyLogger()] if self.loggers else []
rank_zero_info(
"Running in fast_dev_run mode: will run a full train,"
@ -1246,41 +1246,43 @@ class Trainer(
return results
def _log_hyperparams(self) -> None:
if not self.loggers:
return
# log hyper-parameters
hparams_initial = None
if self.logger is not None:
# save exp to get started (this is where the first experiment logs are written)
datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
# save exp to get started (this is where the first experiment logs are written)
datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
datamodule_hparams = self.datamodule.hparams_initial
lightning_hparams = self.lightning_module.hparams_initial
inconsistent_keys = []
for key in lightning_hparams.keys() & datamodule_hparams.keys():
lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
if type(lm_val) != type(dm_val):
inconsistent_keys.append(key)
elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
inconsistent_keys.append(key)
elif lm_val != dm_val:
inconsistent_keys.append(key)
if inconsistent_keys:
raise MisconfigurationException(
f"Error while merging hparams: the keys {inconsistent_keys} are present "
"in both the LightningModule's and LightningDataModule's hparams "
"but have different values."
)
hparams_initial = {**lightning_hparams, **datamodule_hparams}
elif self.lightning_module._log_hyperparams:
hparams_initial = self.lightning_module.hparams_initial
elif datamodule_log_hyperparams:
hparams_initial = self.datamodule.hparams_initial
if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
datamodule_hparams = self.datamodule.hparams_initial
lightning_hparams = self.lightning_module.hparams_initial
inconsistent_keys = []
for key in lightning_hparams.keys() & datamodule_hparams.keys():
lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
if type(lm_val) != type(dm_val):
inconsistent_keys.append(key)
elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
inconsistent_keys.append(key)
elif lm_val != dm_val:
inconsistent_keys.append(key)
if inconsistent_keys:
raise MisconfigurationException(
f"Error while merging hparams: the keys {inconsistent_keys} are present "
"in both the LightningModule's and LightningDataModule's hparams "
"but have different values."
)
hparams_initial = {**lightning_hparams, **datamodule_hparams}
elif self.lightning_module._log_hyperparams:
hparams_initial = self.lightning_module.hparams_initial
elif datamodule_log_hyperparams:
hparams_initial = self.datamodule.hparams_initial
for logger in self.loggers:
if hparams_initial is not None:
self.logger.log_hyperparams(hparams_initial)
self.logger.log_graph(self.lightning_module)
self.logger.save()
logger.log_hyperparams(hparams_initial)
logger.log_graph(self.lightning_module)
logger.save()
def _teardown(self):
"""This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
@ -1519,8 +1521,8 @@ class Trainer(
# todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
# It might be related to xla tensors blocked when moving the cpu kill loggers.
if self.logger is not None:
self.logger.finalize("success")
for logger in self.loggers:
logger.finalize("success")
# summarize profile results
self.profiler.describe()
@ -1890,7 +1892,7 @@ class Trainer(
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)
if self.logger and self.num_training_batches < self.log_every_n_steps:
if self.loggers and self.num_training_batches < self.log_every_n_steps:
rank_zero_warn(
f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
@ -2137,14 +2139,13 @@ class Trainer(
@property
def log_dir(self) -> Optional[str]:
if self.logger is None:
dirpath = self.default_root_dir
elif isinstance(self.logger, TensorBoardLogger):
dirpath = self.logger.log_dir
elif isinstance(self.logger, LoggerCollection):
dirpath = self.default_root_dir
if len(self.loggers) == 1:
if isinstance(self.logger, TensorBoardLogger):
dirpath = self.logger.log_dir
else:
dirpath = self.logger.save_dir
else:
dirpath = self.logger.save_dir
dirpath = self.default_root_dir
dirpath = self.strategy.broadcast(dirpath)
return dirpath

View File

@ -109,7 +109,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N
trainer.auto_scale_batch_size = None # prevent recursion
trainer.auto_lr_find = False # avoid lr find being called multiple times
trainer.fit_loop.max_steps = steps_per_trial # take few steps
trainer.logger = DummyLogger() if trainer.logger is not None else None
trainer.loggers = [DummyLogger()] if trainer.loggers else []
trainer.callbacks = [] # not needed before full run
trainer.limit_train_batches = 1.0

View File

@ -267,7 +267,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
# Use special lr logger callback
trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
# No logging
trainer.logger = DummyLogger() if trainer.logger is not None else None
trainer.loggers = [DummyLogger()] if trainer.loggers else []
# Max step set to number of iterations
trainer.fit_loop.max_steps = num_training

View File

@ -146,3 +146,19 @@ def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[
metrics = {f"{prefix}{separator}{k}": v for k, v in metrics.items()}
return metrics
def _name(loggers: List[Any], separator: str = "_") -> str:
if len(loggers) == 1:
return loggers[0].name
else:
# Concatenate names together, removing duplicates and preserving order
return separator.join(dict.fromkeys(str(logger.name) for logger in loggers))
def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]:
if len(loggers) == 1:
return loggers[0].version
else:
# Concatenate versions together, removing duplicates and preserving order
return separator.join(dict.fromkeys(str(logger.version) for logger in loggers))

View File

@ -1273,7 +1273,7 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir):
def test_last_global_step_saved():
# this should not save anything
model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
trainer = Mock()
trainer = MagicMock()
trainer.callback_metrics = {"foo": 123}
model_checkpoint.save_checkpoint(trainer)
assert model_checkpoint._last_global_step_saved == -1

View File

@ -186,10 +186,11 @@ def test_multiple_loggers_pickle(tmpdir):
trainer = Trainer(logger=[logger1, logger2])
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0}, 0)
for logger in trainer2.loggers:
logger.log_metrics({"acc": 1.0}, 0)
assert trainer2.logger[0].metrics_logged == {"acc": 1.0}
assert trainer2.logger[1].metrics_logged == {"acc": 1.0}
for logger in trainer2.loggers:
assert logger.metrics_logged == {"acc": 1.0}
def test_adding_step_key(tmpdir):

View File

@ -24,8 +24,7 @@ import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -467,7 +466,7 @@ def test_pytorch_profiler_logger_collection(tmpdir):
model = BoringModel()
# Wrap the logger in a list so it becomes a LoggerCollection
logger = [TensorBoardLogger(save_dir=tmpdir), DummyLogger()]
logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
assert isinstance(trainer.logger, LoggerCollection)
trainer.fit(model)

View File

@ -15,8 +15,7 @@ import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
from tests.helpers.boring_model import BoringModel
@ -118,7 +117,7 @@ def test_logdir_logger_collection(tmpdir):
trainer = Trainer(
default_root_dir=default_root_dir,
max_steps=2,
logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), DummyLogger()],
logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)],
)
assert isinstance(trainer.logger, LoggerCollection)
assert trainer.log_dir == default_root_dir

View File

@ -30,18 +30,20 @@ def test_trainer_loggers_property():
# trainer.loggers should create a list of size 1
trainer = Trainer(logger=logger1)
assert trainer.logger == logger1
assert trainer.loggers == [logger1]
# trainer.loggers should be an empty list
trainer = Trainer(logger=False)
assert trainer.logger is None
assert trainer.loggers == []
# trainer.loggers should be a list of size 1 holding the default logger
trainer = Trainer(logger=True)
assert trainer.loggers == [trainer.logger]
assert type(trainer.loggers[0]) == TensorBoardLogger
assert isinstance(trainer.logger, TensorBoardLogger)
def test_trainer_loggers_setters():

View File

@ -17,12 +17,15 @@ import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.logger import (
_add_prefix,
_convert_params,
_flatten_dict,
_name,
_sanitize_callable_params,
_sanitize_params,
_version,
)
@ -172,3 +175,37 @@ def test_add_prefix():
assert "prefix-metric2" not in metrics
assert metrics["prefix2_prefix-metric1"] == 1
assert metrics["prefix2_prefix-metric2"] == 2
def test_name(tmpdir):
"""Verify names of loggers are concatenated properly."""
logger1 = CSVLogger(tmpdir, name="foo")
logger2 = CSVLogger(tmpdir, name="bar")
logger3 = CSVLogger(tmpdir, name="foo")
logger4 = CSVLogger(tmpdir, name="baz")
loggers = [logger1, logger2, logger3, logger4]
name = _name([])
assert name == ""
name = _name([logger3])
assert name == "foo"
name = _name(loggers)
assert name == "foo_bar_baz"
name = _name(loggers, "-")
assert name == "foo-bar-baz"
def test_version(tmpdir):
"""Verify versions of loggers are concatenated properly."""
logger1 = CSVLogger(tmpdir, version=0)
logger2 = CSVLogger(tmpdir, version=2)
logger3 = CSVLogger(tmpdir, version=1)
logger4 = CSVLogger(tmpdir, version=0)
loggers = [logger1, logger2, logger3, logger4]
version = _version([])
assert version == ""
version = _version([logger3])
assert version == 1
version = _version(loggers)
assert version == "0_2_1"
version = _version(loggers, "-")
assert version == "0-2-1"