Refactor codebase to use `trainer.loggers` over `trainer.logger` when needed (#11920)
commit 7e2f9fbad5 (parent 244f365fae)
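The pattern applied throughout this diff: call sites that previously went through the single (possibly None) `trainer.logger` reference now iterate over the `trainer.loggers` list, which is simply empty when logging is disabled. A minimal sketch of the before/after shape, using a hypothetical callback that is not itself part of this PR:

from typing import Any

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class MyLoggingCallback(Callback):  # hypothetical example, not part of this diff
    def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
        metrics = {"my_metric": 1.0}

        # Old pattern: guard against the logger being None, then log once.
        # if trainer.logger is not None:
        #     trainer.logger.log_metrics(metrics, step=trainer.global_step)

        # New pattern: iterate the list; Trainer(logger=False) yields an empty
        # list, so no explicit None check is needed.
        for logger in trainer.loggers:
            logger.log_metrics(metrics, step=trainer.global_step)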
@@ -206,7 +206,8 @@ class GAN(LightningModule):
         # log sampled images
         sample_imgs = self(z)
         grid = torchvision.utils.make_grid(sample_imgs)
-        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
+        for logger in self.loggers:
+            logger.experiment.add_image("generated_images", grid, self.current_epoch)
 
 
 def main(args: Namespace) -> None:
@@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback):
     """
 
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")
 
     def on_train_batch_start(
@@ -55,7 +55,7 @@ class DeviceStatsMonitor(Callback):
         batch_idx: int,
         unused: Optional[int] = 0,
     ) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
 
         if not trainer.logger_connector.should_update_logs:
@@ -63,9 +63,10 @@ class DeviceStatsMonitor(Callback):
 
         device = trainer.strategy.root_device
         device_stats = trainer.accelerator.get_device_stats(device)
-        separator = trainer.logger.group_separator
-        prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
-        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
 
     def on_train_batch_end(
         self,
@@ -76,7 +77,7 @@ class DeviceStatsMonitor(Callback):
         batch_idx: int,
         unused: Optional[int] = 0,
     ) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")
 
         if not trainer.logger_connector.should_update_logs:
@@ -84,9 +85,10 @@ class DeviceStatsMonitor(Callback):
 
         device = trainer.strategy.root_device
         device_stats = trainer.accelerator.get_device_stats(device)
-        separator = trainer.logger.group_separator
-        prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
-        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
 
 
 def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
@@ -123,7 +123,7 @@ class GPUStatsMonitor(Callback):
         self._gpu_ids: List[str] = []  # will be assigned later in setup()
 
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")
 
         if trainer.strategy.root_device.type != "cuda":
@@ -161,8 +161,8 @@ class GPUStatsMonitor(Callback):
             # First log at beginning of second step
             logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000
 
-        assert trainer.logger is not None
-        trainer.logger.log_metrics(logs, step=trainer.global_step)
+        for logger in trainer.loggers:
+            logger.log_metrics(logs, step=trainer.global_step)
 
     @rank_zero_only
     def on_train_batch_end(
@@ -186,8 +186,8 @@ class GPUStatsMonitor(Callback):
         if self._log_stats.intra_step_time and self._snap_intra_step_time:
             logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000
 
-        assert trainer.logger is not None
-        trainer.logger.log_metrics(logs, step=trainer.global_step)
+        for logger in trainer.loggers:
+            logger.log_metrics(logs, step=trainer.global_step)
 
     @staticmethod
     def _get_gpu_ids(device_ids: List[int]) -> List[str]:
@@ -104,7 +104,7 @@ class LearningRateMonitor(Callback):
             MisconfigurationException:
                 If ``Trainer`` has no ``logger``.
         """
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException(
                 "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
             )
@@ -149,7 +149,6 @@ class LearningRateMonitor(Callback):
         self.last_momentum_values = {name + "-momentum": None for name in names_flatten}
 
     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        assert trainer.logger is not None
         if not trainer.logger_connector.should_update_logs:
             return
 
@@ -158,16 +157,17 @@ class LearningRateMonitor(Callback):
             latest_stat = self._extract_stats(trainer, interval)
 
             if latest_stat:
-                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+                for logger in trainer.loggers:
+                    logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        assert trainer.logger is not None
         if self.logging_interval != "step":
             interval = "epoch" if self.logging_interval is None else "any"
             latest_stat = self._extract_stats(trainer, interval)
 
             if latest_stat:
-                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+                for logger in trainer.loggers:
+                    logger.log_metrics(latest_stat, step=trainer.global_step)
 
     def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
         latest_stat = {}
@@ -35,6 +35,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.logger import _name, _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -379,8 +380,9 @@ class ModelCheckpoint(Callback):
         self._save_last_checkpoint(trainer, monitor_candidates)
 
         # notify loggers
-        if trainer.is_global_zero and trainer.logger:
-            trainer.logger.after_save_checkpoint(proxy(self))
+        if trainer.is_global_zero:
+            for logger in trainer.loggers:
+                logger.after_save_checkpoint(proxy(self))
 
     def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
@@ -572,20 +574,20 @@ class ModelCheckpoint(Callback):
         """
         if self.dirpath is not None:
            return  # short circuit
 
-        if trainer.logger is not None:
+        if trainer.loggers:
             if trainer.weights_save_path != trainer.default_root_dir:
                 # the user has changed weights_save_path, it overrides anything
                 save_dir = trainer.weights_save_path
-            else:
+            elif len(trainer.loggers) == 1:
                 save_dir = trainer.logger.save_dir or trainer.default_root_dir
+            else:
+                save_dir = trainer.default_root_dir
 
-            version = (
-                trainer.logger.version
-                if isinstance(trainer.logger.version, str)
-                else f"version_{trainer.logger.version}"
-            )
-            ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
+            name = _name(trainer.loggers)
+            version = _version(trainer.loggers)
+            version = version if isinstance(version, str) else f"version_{version}"
+
+            ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
         else:
             ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
 
@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Union
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.logger import _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
 
@@ -213,11 +214,12 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
     if pl_module.truncated_bptt_steps > 0:
         items_dict["split_idx"] = trainer.fit_loop.split_idx
 
-    if trainer.logger is not None and trainer.logger.version is not None:
-        version = trainer.logger.version
-        if isinstance(version, str):
-            # show last 4 places of long version strings
-            version = version[-4:]
-        items_dict["v_num"] = version
+    if trainer.loggers:
+        version = _version(trainer.loggers)
+        if version is not None:
+            if isinstance(version, str):
+                # show last 4 places of long version strings
+                version = version[-4:]
+            items_dict["v_num"] = version
 
     return items_dict
@@ -70,7 +70,7 @@ class XLAStatsMonitor(Callback):
         self._verbose = verbose
 
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
 
         if isinstance(trainer.accelerator, TPUAccelerator):
@@ -88,7 +88,7 @@ class XLAStatsMonitor(Callback):
         self._start_time = time.time()
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")
 
         device = trainer.strategy.root_device
@@ -102,10 +102,11 @@ class XLAStatsMonitor(Callback):
         peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
         epoch_time = trainer.strategy.reduce(epoch_time)
 
-        trainer.logger.log_metrics(
-            {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
-            step=trainer.current_epoch,
-        )
+        for logger in trainer.loggers:
+            logger.log_metrics(
+                {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
+                step=trainer.current_epoch,
+            )
 
         if self._verbose:
             rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
@@ -253,7 +253,7 @@ class LightningModule(
 
     @property
     def loggers(self) -> List[LightningLoggerBase]:
-        """Reference to the loggers object in the Trainer."""
+        """Reference to the list of loggers in the Trainer."""
         return self.trainer.loggers if self.trainer else []
 
     def _apply_batch_transfer_handler(
@@ -504,8 +504,10 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         """Flushes loggers to disk."""
         # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
-        if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
-            self.trainer.logger.save()
+        # TODO: is_global_zero check should be moved to logger.save() implementation
+        if should_flush_logs and self.trainer.is_global_zero:
+            for logger in self.trainer.loggers:
+                logger.save()
 
     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
         if self._dataloader_state_dict:
@@ -155,7 +155,11 @@ class PrecisionPlugin(CheckpointHooks):
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
         if trainer.track_grad_norm == -1:
             return
-        kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}
+
+        kwargs = {}
+        if len(trainer.loggers) == 1:
+            kwargs["group_separator"] = trainer.loggers[0].group_separator
+
         grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
         if grad_norm_dict:
             prev_fx = trainer.lightning_module._current_fx_name
@@ -17,7 +17,7 @@ import torch
 
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators import GPUAccelerator
-from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
+from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
 from pytorch_lightning.trainer.states import RunningStage
@@ -90,15 +90,15 @@ class LoggerConnector:
     def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
         if isinstance(logger, bool):
             # default logger
-            self.trainer.logger = (
-                TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())
+            self.trainer.loggers = (
+                [TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())]
                 if logger
-                else None
+                else []
             )
         elif isinstance(logger, Iterable):
-            self.trainer.logger = LoggerCollection(logger)
+            self.trainer.loggers = list(logger)
         else:
-            self.trainer.logger = logger
+            self.trainer.loggers = [logger]
 
     def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
         """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses
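Note on the `configure_logger` change above: the `logger` argument to `Trainer` is now normalised into a plain list stored on `trainer.loggers`, with `trainer.logger` still available for backward compatibility. A small illustrative sketch of the resulting mapping, consistent with the tests further down in this diff (the logger constructor arguments are examples only):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

tb = TensorBoardLogger("logs")
csv = CSVLogger("logs")

assert Trainer(logger=False).loggers == []             # logging disabled -> empty list
assert Trainer(logger=tb).loggers == [tb]              # single logger wrapped in a list
assert Trainer(logger=[tb, csv]).loggers == [tb, csv]  # iterable stored as a plain list
# Trainer(logger=True) builds a default TensorBoardLogger, so the list has length 1.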
@@ -109,7 +109,7 @@ class LoggerConnector:
             step: Step for which metrics should be logged. Default value is `self.global_step` during training or
                 the total validation / test log step count during validation and testing.
         """
-        if self.trainer.logger is None or not metrics:
+        if not self.trainer.loggers or not metrics:
             return
 
         self._logged_metrics.update(metrics)
@@ -126,11 +126,12 @@ class LoggerConnector:
             step = self.trainer.global_step
 
         # log actual metrics
-        if self._override_agg_and_log_metrics:
-            self.trainer.logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
-        else:
-            self.trainer.logger.log_metrics(metrics=scalar_metrics, step=step)
-        self.trainer.logger.save()
+        for logger in self.trainer.loggers:
+            if self._override_agg_and_log_metrics:
+                logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
+            else:
+                logger.log_metrics(metrics=scalar_metrics, step=step)
+            logger.save()
 
     """
     Evaluation metric updates
@@ -66,8 +66,8 @@ class SignalConnector:
             rank_zero_info("handling SIGUSR1")
 
             # save logger to make sure we get all the metrics
-            if self.trainer.logger:
-                self.trainer.logger.finalize("finished")
+            for logger in self.trainer.loggers:
+                logger.finalize("finished")
             hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.weights_save_path)
             self.trainer.save_checkpoint(hpc_save_path)
 
@@ -657,7 +657,7 @@ class Trainer(
             self.fit_loop.max_epochs = 1
             val_check_interval = 1.0
             self.check_val_every_n_epoch = 1
-            self.logger = DummyLogger() if self.logger is not None else None
+            self.loggers = [DummyLogger()] if self.loggers else []
 
             rank_zero_info(
                 "Running in fast_dev_run mode: will run a full train,"
@@ -1246,41 +1246,43 @@ class Trainer(
         return results
 
     def _log_hyperparams(self) -> None:
+        if not self.loggers:
+            return
         # log hyper-parameters
         hparams_initial = None
 
-        if self.logger is not None:
-            # save exp to get started (this is where the first experiment logs are written)
-            datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
+        # save exp to get started (this is where the first experiment logs are written)
+        datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
 
-            if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
-                datamodule_hparams = self.datamodule.hparams_initial
-                lightning_hparams = self.lightning_module.hparams_initial
-                inconsistent_keys = []
-                for key in lightning_hparams.keys() & datamodule_hparams.keys():
-                    lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
-                    if type(lm_val) != type(dm_val):
-                        inconsistent_keys.append(key)
-                    elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
-                        inconsistent_keys.append(key)
-                    elif lm_val != dm_val:
-                        inconsistent_keys.append(key)
-                if inconsistent_keys:
-                    raise MisconfigurationException(
-                        f"Error while merging hparams: the keys {inconsistent_keys} are present "
-                        "in both the LightningModule's and LightningDataModule's hparams "
-                        "but have different values."
-                    )
-                hparams_initial = {**lightning_hparams, **datamodule_hparams}
-            elif self.lightning_module._log_hyperparams:
-                hparams_initial = self.lightning_module.hparams_initial
-            elif datamodule_log_hyperparams:
-                hparams_initial = self.datamodule.hparams_initial
+        if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
+            datamodule_hparams = self.datamodule.hparams_initial
+            lightning_hparams = self.lightning_module.hparams_initial
+            inconsistent_keys = []
+            for key in lightning_hparams.keys() & datamodule_hparams.keys():
+                lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
+                if type(lm_val) != type(dm_val):
+                    inconsistent_keys.append(key)
+                elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
+                    inconsistent_keys.append(key)
+                elif lm_val != dm_val:
+                    inconsistent_keys.append(key)
+            if inconsistent_keys:
+                raise MisconfigurationException(
+                    f"Error while merging hparams: the keys {inconsistent_keys} are present "
+                    "in both the LightningModule's and LightningDataModule's hparams "
+                    "but have different values."
+                )
+            hparams_initial = {**lightning_hparams, **datamodule_hparams}
+        elif self.lightning_module._log_hyperparams:
+            hparams_initial = self.lightning_module.hparams_initial
+        elif datamodule_log_hyperparams:
+            hparams_initial = self.datamodule.hparams_initial
 
+        for logger in self.loggers:
             if hparams_initial is not None:
-                self.logger.log_hyperparams(hparams_initial)
-            self.logger.log_graph(self.lightning_module)
-            self.logger.save()
+                logger.log_hyperparams(hparams_initial)
+            logger.log_graph(self.lightning_module)
+            logger.save()
 
     def _teardown(self):
         """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
@@ -1519,8 +1521,8 @@ class Trainer(
 
         # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
         # It might be related to xla tensors blocked when moving the cpu kill loggers.
-        if self.logger is not None:
-            self.logger.finalize("success")
+        for logger in self.loggers:
+            logger.finalize("success")
 
         # summarize profile results
         self.profiler.describe()
@@ -1890,7 +1892,7 @@ class Trainer(
             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)
 
-        if self.logger and self.num_training_batches < self.log_every_n_steps:
+        if self.loggers and self.num_training_batches < self.log_every_n_steps:
             rank_zero_warn(
                 f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
                 f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
@@ -2137,14 +2139,13 @@ class Trainer(
 
     @property
     def log_dir(self) -> Optional[str]:
-        if self.logger is None:
-            dirpath = self.default_root_dir
-        elif isinstance(self.logger, TensorBoardLogger):
-            dirpath = self.logger.log_dir
-        elif isinstance(self.logger, LoggerCollection):
-            dirpath = self.default_root_dir
+        if len(self.loggers) == 1:
+            if isinstance(self.logger, TensorBoardLogger):
+                dirpath = self.logger.log_dir
+            else:
+                dirpath = self.logger.save_dir
         else:
-            dirpath = self.logger.save_dir
+            dirpath = self.default_root_dir
 
         dirpath = self.strategy.broadcast(dirpath)
         return dirpath
@@ -109,7 +109,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
     trainer.fit_loop.max_steps = steps_per_trial  # take few steps
-    trainer.logger = DummyLogger() if trainer.logger is not None else None
+    trainer.loggers = [DummyLogger()] if trainer.loggers else []
    trainer.callbacks = []  # not needed before full run
    trainer.limit_train_batches = 1.0
 
@@ -267,7 +267,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
     # Use special lr logger callback
     trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
     # No logging
-    trainer.logger = DummyLogger() if trainer.logger is not None else None
+    trainer.loggers = [DummyLogger()] if trainer.loggers else []
     # Max step set to number of iterations
     trainer.fit_loop.max_steps = num_training
 
@@ -146,3 +146,19 @@ def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[
     metrics = {f"{prefix}{separator}{k}": v for k, v in metrics.items()}
 
     return metrics
+
+
+def _name(loggers: List[Any], separator: str = "_") -> str:
+    if len(loggers) == 1:
+        return loggers[0].name
+    else:
+        # Concatenate names together, removing duplicates and preserving order
+        return separator.join(dict.fromkeys(str(logger.name) for logger in loggers))
+
+
+def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]:
+    if len(loggers) == 1:
+        return loggers[0].version
+    else:
+        # Concatenate versions together, removing duplicates and preserving order
+        return separator.join(dict.fromkeys(str(logger.version) for logger in loggers))
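Note on the new `_name` / `_version` helpers: with a single logger the value is returned unchanged (keeping its original type), while with several loggers the stringified values are de-duplicated in order and joined with the separator. A short illustrative sketch mirroring the tests added at the bottom of this diff (the CSVLogger arguments are examples only):

from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.logger import _name, _version

loggers = [CSVLogger("logs", name="foo", version=0), CSVLogger("logs", name="bar", version=2)]

assert _name(loggers) == "foo_bar"   # names joined with "_", duplicates dropped
assert _version(loggers) == "0_2"    # versions stringified and joined
assert _version([loggers[0]]) == 0   # a single logger keeps its original (int) version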
@@ -1273,7 +1273,7 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir):
 def test_last_global_step_saved():
     # this should not save anything
     model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
-    trainer = Mock()
+    trainer = MagicMock()
     trainer.callback_metrics = {"foo": 123}
     model_checkpoint.save_checkpoint(trainer)
     assert model_checkpoint._last_global_step_saved == -1
@@ -186,10 +186,11 @@ def test_multiple_loggers_pickle(tmpdir):
     trainer = Trainer(logger=[logger1, logger2])
     pkl_bytes = pickle.dumps(trainer)
     trainer2 = pickle.loads(pkl_bytes)
-    trainer2.logger.log_metrics({"acc": 1.0}, 0)
+    for logger in trainer2.loggers:
+        logger.log_metrics({"acc": 1.0}, 0)
 
-    assert trainer2.logger[0].metrics_logged == {"acc": 1.0}
-    assert trainer2.logger[1].metrics_logged == {"acc": 1.0}
+    for logger in trainer2.loggers:
+        assert logger.metrics_logged == {"acc": 1.0}
 
 
 def test_adding_step_key(tmpdir):
@@ -24,8 +24,7 @@ import torch
 
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
-from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
-from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -467,7 +466,7 @@ def test_pytorch_profiler_logger_collection(tmpdir):
 
     model = BoringModel()
     # Wrap the logger in a list so it becomes a LoggerCollection
-    logger = [TensorBoardLogger(save_dir=tmpdir), DummyLogger()]
+    logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
     trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
     assert isinstance(trainer.logger, LoggerCollection)
     trainer.fit(model)
@@ -15,8 +15,7 @@ import os
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
-from pytorch_lightning.loggers.base import DummyLogger
+from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
 from tests.helpers.boring_model import BoringModel
 
 
@@ -118,7 +117,7 @@ def test_logdir_logger_collection(tmpdir):
     trainer = Trainer(
         default_root_dir=default_root_dir,
         max_steps=2,
-        logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), DummyLogger()],
+        logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)],
     )
     assert isinstance(trainer.logger, LoggerCollection)
     assert trainer.log_dir == default_root_dir
@@ -30,18 +30,20 @@ def test_trainer_loggers_property():
     # trainer.loggers should create a list of size 1
     trainer = Trainer(logger=logger1)
 
+    assert trainer.logger == logger1
     assert trainer.loggers == [logger1]
 
     # trainer.loggers should be an empty list
     trainer = Trainer(logger=False)
 
+    assert trainer.logger is None
     assert trainer.loggers == []
 
     # trainer.loggers should be a list of size 1 holding the default logger
     trainer = Trainer(logger=True)
 
     assert trainer.loggers == [trainer.logger]
-    assert type(trainer.loggers[0]) == TensorBoardLogger
+    assert isinstance(trainer.logger, TensorBoardLogger)
 
 
 def test_trainer_loggers_setters():
@@ -17,12 +17,15 @@ import numpy as np
 import torch
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import CSVLogger
 from pytorch_lightning.utilities.logger import (
     _add_prefix,
     _convert_params,
     _flatten_dict,
+    _name,
     _sanitize_callable_params,
     _sanitize_params,
+    _version,
 )
 
 
@@ -172,3 +175,37 @@ def test_add_prefix():
     assert "prefix-metric2" not in metrics
     assert metrics["prefix2_prefix-metric1"] == 1
     assert metrics["prefix2_prefix-metric2"] == 2
+
+
+def test_name(tmpdir):
+    """Verify names of loggers are concatenated properly."""
+    logger1 = CSVLogger(tmpdir, name="foo")
+    logger2 = CSVLogger(tmpdir, name="bar")
+    logger3 = CSVLogger(tmpdir, name="foo")
+    logger4 = CSVLogger(tmpdir, name="baz")
+    loggers = [logger1, logger2, logger3, logger4]
+    name = _name([])
+    assert name == ""
+    name = _name([logger3])
+    assert name == "foo"
+    name = _name(loggers)
+    assert name == "foo_bar_baz"
+    name = _name(loggers, "-")
+    assert name == "foo-bar-baz"
+
+
+def test_version(tmpdir):
+    """Verify versions of loggers are concatenated properly."""
+    logger1 = CSVLogger(tmpdir, version=0)
+    logger2 = CSVLogger(tmpdir, version=2)
+    logger3 = CSVLogger(tmpdir, version=1)
+    logger4 = CSVLogger(tmpdir, version=0)
+    loggers = [logger1, logger2, logger3, logger4]
+    version = _version([])
+    assert version == ""
+    version = _version([logger3])
+    assert version == 1
+    version = _version(loggers)
+    assert version == "0_2_1"
+    version = _version(loggers, "-")
+    assert version == "0-2-1"