diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py
index fd2cf69f14..cef2107550 100644
--- a/pl_examples/domain_templates/generative_adversarial_net.py
+++ b/pl_examples/domain_templates/generative_adversarial_net.py
@@ -206,7 +206,8 @@ class GAN(LightningModule):
         # log sampled images
         sample_imgs = self(z)
         grid = torchvision.utils.make_grid(sample_imgs)
-        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
+        for logger in self.loggers:
+            logger.experiment.add_image("generated_images", grid, self.current_epoch)


 def main(args: Namespace) -> None:
diff --git a/pytorch_lightning/callbacks/device_stats_monitor.py b/pytorch_lightning/callbacks/device_stats_monitor.py
index eaf67db918..f9cb3cf623 100644
--- a/pytorch_lightning/callbacks/device_stats_monitor.py
+++ b/pytorch_lightning/callbacks/device_stats_monitor.py
@@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback):
     """

     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

     def on_train_batch_start(
@@ -55,7 +55,7 @@ class DeviceStatsMonitor(Callback):
         batch_idx: int,
         unused: Optional[int] = 0,
     ) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

         if not trainer.logger_connector.should_update_logs:
@@ -63,9 +63,10 @@ class DeviceStatsMonitor(Callback):

         device = trainer.strategy.root_device
         device_stats = trainer.accelerator.get_device_stats(device)
-        separator = trainer.logger.group_separator
-        prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
-        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

     def on_train_batch_end(
         self,
@@ -76,7 +77,7 @@ class DeviceStatsMonitor(Callback):
         batch_idx: int,
         unused: Optional[int] = 0,
     ) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

         if not trainer.logger_connector.should_update_logs:
@@ -84,9 +85,10 @@ class DeviceStatsMonitor(Callback):

         device = trainer.strategy.root_device
         device_stats = trainer.accelerator.get_device_stats(device)
-        separator = trainer.logger.group_separator
-        prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
-        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


 def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py
index a871bfa309..2c2949e53b 100644
--- a/pytorch_lightning/callbacks/gpu_stats_monitor.py
+++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -123,7 +123,7 @@ class GPUStatsMonitor(Callback):
         self._gpu_ids: List[str] = []  # will be assigned later in setup()

     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

         if trainer.strategy.root_device.type != "cuda":
@@ -161,8 +161,8 @@ class GPUStatsMonitor(Callback):
             # First log at beginning of second step
             logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

-        assert trainer.logger is not None
-        trainer.logger.log_metrics(logs, step=trainer.global_step)
+        for logger in trainer.loggers:
+            logger.log_metrics(logs, step=trainer.global_step)

     @rank_zero_only
     def on_train_batch_end(
@@ -186,8 +186,8 @@ class GPUStatsMonitor(Callback):
         if self._log_stats.intra_step_time and self._snap_intra_step_time:
             logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

-        assert trainer.logger is not None
-        trainer.logger.log_metrics(logs, step=trainer.global_step)
+        for logger in trainer.loggers:
+            logger.log_metrics(logs, step=trainer.global_step)

     @staticmethod
     def _get_gpu_ids(device_ids: List[int]) -> List[str]:
diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py
index 0f3519d8fe..00ff007af5 100644
--- a/pytorch_lightning/callbacks/lr_monitor.py
+++ b/pytorch_lightning/callbacks/lr_monitor.py
@@ -104,7 +104,7 @@ class LearningRateMonitor(Callback):
             MisconfigurationException:
                 If ``Trainer`` has no ``logger``.
         """
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException(
                 "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
             )
@@ -149,7 +149,6 @@ class LearningRateMonitor(Callback):
         self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        assert trainer.logger is not None
         if not trainer.logger_connector.should_update_logs:
             return

@@ -158,16 +157,17 @@ class LearningRateMonitor(Callback):
             latest_stat = self._extract_stats(trainer, interval)

             if latest_stat:
-                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+                for logger in trainer.loggers:
+                    logger.log_metrics(latest_stat, step=trainer.global_step)

     def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        assert trainer.logger is not None
         if self.logging_interval != "step":
             interval = "epoch" if self.logging_interval is None else "any"
             latest_stat = self._extract_stats(trainer, interval)

             if latest_stat:
-                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+                for logger in trainer.loggers:
+                    logger.log_metrics(latest_stat, step=trainer.global_step)

     def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
         latest_stat = {}
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 13791a70f9..1e231abab2 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -35,6 +35,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.logger import _name, _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -379,8 +380,9 @@ class ModelCheckpoint(Callback):
         self._save_last_checkpoint(trainer, monitor_candidates)

         # notify loggers
-        if trainer.is_global_zero and trainer.logger:
-            trainer.logger.after_save_checkpoint(proxy(self))
+        if trainer.is_global_zero:
+            for logger in trainer.loggers:
+                logger.after_save_checkpoint(proxy(self))

     def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
@@ -572,20 +574,20 @@ class ModelCheckpoint(Callback):
         """
         if self.dirpath is not None:
             return  # short circuit
-
-        if trainer.logger is not None:
+        if trainer.loggers:
             if trainer.weights_save_path != trainer.default_root_dir:
                 # the user has changed weights_save_path, it overrides anything
                 save_dir = trainer.weights_save_path
-            else:
+            elif len(trainer.loggers) == 1:
                 save_dir = trainer.logger.save_dir or trainer.default_root_dir
+            else:
+                save_dir = trainer.default_root_dir

-            version = (
-                trainer.logger.version
-                if isinstance(trainer.logger.version, str)
-                else f"version_{trainer.logger.version}"
-            )
-            ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
+            name = _name(trainer.loggers)
+            version = _version(trainer.loggers)
+            version = version if isinstance(version, str) else f"version_{version}"
+
+            ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
         else:
             ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index 291fb495a8..3ee1f83a54 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -15,6 +15,7 @@ from typing import Any, Dict, Optional, Union

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.logger import _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn


@@ -213,11 +214,12 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
     if pl_module.truncated_bptt_steps > 0:
         items_dict["split_idx"] = trainer.fit_loop.split_idx

-    if trainer.logger is not None and trainer.logger.version is not None:
-        version = trainer.logger.version
-        if isinstance(version, str):
-            # show last 4 places of long version strings
-            version = version[-4:]
-        items_dict["v_num"] = version
+    if trainer.loggers:
+        version = _version(trainer.loggers)
+        if version is not None:
+            if isinstance(version, str):
+                # show last 4 places of long version strings
+                version = version[-4:]
+            items_dict["v_num"] = version

     return items_dict
diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py
index cfbffaff41..ebc6ca9d72 100644
--- a/pytorch_lightning/callbacks/xla_stats_monitor.py
+++ b/pytorch_lightning/callbacks/xla_stats_monitor.py
@@ -70,7 +70,7 @@ class XLAStatsMonitor(Callback):
         self._verbose = verbose

     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

         if isinstance(trainer.accelerator, TPUAccelerator):
@@ -88,7 +88,7 @@ class XLAStatsMonitor(Callback):
         self._start_time = time.time()

     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

         device = trainer.strategy.root_device
@@ -102,10 +102,11 @@ class XLAStatsMonitor(Callback):
         peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
         epoch_time = trainer.strategy.reduce(epoch_time)

-        trainer.logger.log_metrics(
-            {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
-            step=trainer.current_epoch,
-        )
+        for logger in trainer.loggers:
+            logger.log_metrics(
+                {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
+                step=trainer.current_epoch,
+            )

         if self._verbose:
             rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 57278a8af6..65c0e6dc97 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -253,7 +253,7 @@ class LightningModule(

     @property
     def loggers(self) -> List[LightningLoggerBase]:
-        """Reference to the loggers object in the Trainer."""
+        """Reference to the list of loggers in the Trainer."""
         return self.trainer.loggers if self.trainer else []

     def _apply_batch_transfer_handler(
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index c8eefedd3c..8721694489 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -504,8 +504,10 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         """Flushes loggers to disk."""
         # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
-        if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
-            self.trainer.logger.save()
+        # TODO: is_global_zero check should be moved to logger.save() implementation
+        if should_flush_logs and self.trainer.is_global_zero:
+            for logger in self.trainer.loggers:
+                logger.save()

     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
         if self._dataloader_state_dict:
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index c876dd5c90..5a5606c5f2 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -155,7 +155,11 @@ class PrecisionPlugin(CheckpointHooks):
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
         if trainer.track_grad_norm == -1:
             return
-        kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}
+
+        kwargs = {}
+        if len(trainer.loggers) == 1:
+            kwargs["group_separator"] = trainer.loggers[0].group_separator
+
         grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
         if grad_norm_dict:
             prev_fx = trainer.lightning_module._current_fx_name
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 2e241e9845..428713ff33 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -17,7 +17,7 @@ import torch

 import pytorch_lightning as pl
 from pytorch_lightning.accelerators import GPUAccelerator
-from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
+from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger
 from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
 from pytorch_lightning.trainer.states import RunningStage
@@ -90,15 +90,15 @@ class LoggerConnector:
     def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
         if isinstance(logger, bool):
             # default logger
-            self.trainer.logger = (
-                TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())
+            self.trainer.loggers = (
+                [TensorBoardLogger(save_dir=self.trainer.default_root_dir, version=SLURMEnvironment.job_id())]
                 if logger
-                else None
+                else []
             )
         elif isinstance(logger, Iterable):
-            self.trainer.logger = LoggerCollection(logger)
+            self.trainer.loggers = list(logger)
         else:
-            self.trainer.logger = logger
+            self.trainer.loggers = [logger]

     def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
         """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses
@@ -109,7 +109,7 @@ class LoggerConnector:
             step: Step for which metrics should be logged. Default value is `self.global_step` during training or
                 the total validation / test log step count during validation and testing.
         """
-        if self.trainer.logger is None or not metrics:
+        if not self.trainer.loggers or not metrics:
             return

         self._logged_metrics.update(metrics)
@@ -126,11 +126,12 @@ class LoggerConnector:
             step = self.trainer.global_step

         # log actual metrics
-        if self._override_agg_and_log_metrics:
-            self.trainer.logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
-        else:
-            self.trainer.logger.log_metrics(metrics=scalar_metrics, step=step)
-        self.trainer.logger.save()
+        for logger in self.trainer.loggers:
+            if self._override_agg_and_log_metrics:
+                logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
+            else:
+                logger.log_metrics(metrics=scalar_metrics, step=step)
+            logger.save()

     """
     Evaluation metric updates
diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py
index ee32818691..8d8ac428fd 100644
--- a/pytorch_lightning/trainer/connectors/signal_connector.py
+++ b/pytorch_lightning/trainer/connectors/signal_connector.py
@@ -66,8 +66,8 @@ class SignalConnector:
             rank_zero_info("handling SIGUSR1")

             # save logger to make sure we get all the metrics
-            if self.trainer.logger:
-                self.trainer.logger.finalize("finished")
+            for logger in self.trainer.loggers:
+                logger.finalize("finished")
             hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.weights_save_path)
             self.trainer.save_checkpoint(hpc_save_path)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 64988e3a58..a26c2babc4 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -657,7 +657,7 @@ class Trainer(
             self.fit_loop.max_epochs = 1
             val_check_interval = 1.0
             self.check_val_every_n_epoch = 1
-            self.logger = DummyLogger() if self.logger is not None else None
+            self.loggers = [DummyLogger()] if self.loggers else []

             rank_zero_info(
                 "Running in fast_dev_run mode: will run a full train,"
@@ -1246,41 +1246,43 @@ class Trainer(
         return results

     def _log_hyperparams(self) -> None:
+        if not self.loggers:
+            return
+
         # log hyper-parameters
         hparams_initial = None

-        if self.logger is not None:
-            # save exp to get started (this is where the first experiment logs are written)
-            datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
+        # save exp to get started (this is where the first experiment logs are written)
+        datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False

-            if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
-                datamodule_hparams = self.datamodule.hparams_initial
-                lightning_hparams = self.lightning_module.hparams_initial
-                inconsistent_keys = []
-                for key in lightning_hparams.keys() & datamodule_hparams.keys():
-                    lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
-                    if type(lm_val) != type(dm_val):
-                        inconsistent_keys.append(key)
-                    elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
-                        inconsistent_keys.append(key)
-                    elif lm_val != dm_val:
-                        inconsistent_keys.append(key)
-                if inconsistent_keys:
-                    raise MisconfigurationException(
-                        f"Error while merging hparams: the keys {inconsistent_keys} are present "
-                        "in both the LightningModule's and LightningDataModule's hparams "
-                        "but have different values."
-                    )
-                hparams_initial = {**lightning_hparams, **datamodule_hparams}
-            elif self.lightning_module._log_hyperparams:
-                hparams_initial = self.lightning_module.hparams_initial
-            elif datamodule_log_hyperparams:
-                hparams_initial = self.datamodule.hparams_initial
+        if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
+            datamodule_hparams = self.datamodule.hparams_initial
+            lightning_hparams = self.lightning_module.hparams_initial
+            inconsistent_keys = []
+            for key in lightning_hparams.keys() & datamodule_hparams.keys():
+                lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
+                if type(lm_val) != type(dm_val):
+                    inconsistent_keys.append(key)
+                elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
+                    inconsistent_keys.append(key)
+                elif lm_val != dm_val:
+                    inconsistent_keys.append(key)
+            if inconsistent_keys:
+                raise MisconfigurationException(
+                    f"Error while merging hparams: the keys {inconsistent_keys} are present "
+                    "in both the LightningModule's and LightningDataModule's hparams "
+                    "but have different values."
+                )
+            hparams_initial = {**lightning_hparams, **datamodule_hparams}
+        elif self.lightning_module._log_hyperparams:
+            hparams_initial = self.lightning_module.hparams_initial
+        elif datamodule_log_hyperparams:
+            hparams_initial = self.datamodule.hparams_initial

+        for logger in self.loggers:
             if hparams_initial is not None:
-                self.logger.log_hyperparams(hparams_initial)
-            self.logger.log_graph(self.lightning_module)
-            self.logger.save()
+                logger.log_hyperparams(hparams_initial)
+            logger.log_graph(self.lightning_module)
+            logger.save()

     def _teardown(self):
         """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
@@ -1519,8 +1521,8 @@ class Trainer(
         # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
         # It might be related to xla tensors blocked when moving the cpu kill loggers.
-        if self.logger is not None:
-            self.logger.finalize("success")
+        for logger in self.loggers:
+            logger.finalize("success")

         # summarize profile results
         self.profiler.describe()
@@ -1890,7 +1892,7 @@ class Trainer(
             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)

-        if self.logger and self.num_training_batches < self.log_every_n_steps:
+        if self.loggers and self.num_training_batches < self.log_every_n_steps:
             rank_zero_warn(
                 f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
                 f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
@@ -2137,14 +2139,13 @@ class Trainer(

     @property
     def log_dir(self) -> Optional[str]:
-        if self.logger is None:
-            dirpath = self.default_root_dir
-        elif isinstance(self.logger, TensorBoardLogger):
-            dirpath = self.logger.log_dir
-        elif isinstance(self.logger, LoggerCollection):
-            dirpath = self.default_root_dir
+        if len(self.loggers) == 1:
+            if isinstance(self.logger, TensorBoardLogger):
+                dirpath = self.logger.log_dir
+            else:
+                dirpath = self.logger.save_dir
         else:
-            dirpath = self.logger.save_dir
+            dirpath = self.default_root_dir

         dirpath = self.strategy.broadcast(dirpath)
         return dirpath
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 1526e570da..3d5916e3f8 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -109,7 +109,7 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> N
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
     trainer.fit_loop.max_steps = steps_per_trial  # take few steps
-    trainer.logger = DummyLogger() if trainer.logger is not None else None
+    trainer.loggers = [DummyLogger()] if trainer.loggers else []
     trainer.callbacks = []  # not needed before full run
     trainer.limit_train_batches = 1.0
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 876ff7823b..d929bbe2f8 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -267,7 +267,7 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
     # Use special lr logger callback
     trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

     # No logging
-    trainer.logger = DummyLogger() if trainer.logger is not None else None
+    trainer.loggers = [DummyLogger()] if trainer.loggers else []

     # Max step set to number of iterations
     trainer.fit_loop.max_steps = num_training
diff --git a/pytorch_lightning/utilities/logger.py b/pytorch_lightning/utilities/logger.py
index a66582fd84..ef27761a2e 100644
--- a/pytorch_lightning/utilities/logger.py
+++ b/pytorch_lightning/utilities/logger.py
@@ -146,3 +146,19 @@ def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[
         metrics = {f"{prefix}{separator}{k}": v for k, v in metrics.items()}

     return metrics
+
+
+def _name(loggers: List[Any], separator: str = "_") -> str:
+    if len(loggers) == 1:
+        return loggers[0].name
+    else:
+        # Concatenate names together, removing duplicates and preserving order
+        return separator.join(dict.fromkeys(str(logger.name) for logger in loggers))
+
+
+def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]:
+    if len(loggers) == 1:
+        return loggers[0].version
+    else:
+        # Concatenate versions together, removing duplicates and preserving order
+        return separator.join(dict.fromkeys(str(logger.version) for logger in loggers))
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index a5c3aae5b1..2c65426534 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1273,7 +1273,7 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir):
 def test_last_global_step_saved():
     # this should not save anything
     model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
-    trainer = Mock()
+    trainer = MagicMock()
     trainer.callback_metrics = {"foo": 123}
     model_checkpoint.save_checkpoint(trainer)
     assert model_checkpoint._last_global_step_saved == -1
diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py
index afcb811a79..cd7eec14ee 100644
--- a/tests/loggers/test_base.py
+++ b/tests/loggers/test_base.py
@@ -186,10 +186,11 @@ def test_multiple_loggers_pickle(tmpdir):
     trainer = Trainer(logger=[logger1, logger2])
     pkl_bytes = pickle.dumps(trainer)
     trainer2 = pickle.loads(pkl_bytes)
-    trainer2.logger.log_metrics({"acc": 1.0}, 0)
+    for logger in trainer2.loggers:
+        logger.log_metrics({"acc": 1.0}, 0)

-    assert trainer2.logger[0].metrics_logged == {"acc": 1.0}
-    assert trainer2.logger[1].metrics_logged == {"acc": 1.0}
+    for logger in trainer2.loggers:
+        assert logger.metrics_logged == {"acc": 1.0}


 def test_adding_step_key(tmpdir):
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index f63a31f6d8..161c9e6e35 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -24,8 +24,7 @@ import torch

 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
-from pytorch_lightning.loggers.base import DummyLogger, LoggerCollection
-from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -467,7 +466,7 @@ def test_pytorch_profiler_logger_collection(tmpdir):
     model = BoringModel()

     # Wrap the logger in a list so it becomes a LoggerCollection
-    logger = [TensorBoardLogger(save_dir=tmpdir), DummyLogger()]
+    logger = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
     trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
     assert isinstance(trainer.logger, LoggerCollection)
     trainer.fit(model)
diff --git a/tests/trainer/properties/test_log_dir.py b/tests/trainer/properties/test_log_dir.py
index 71920a6b07..6777ec8183 100644
--- a/tests/trainer/properties/test_log_dir.py
+++ b/tests/trainer/properties/test_log_dir.py
@@ -15,8 +15,7 @@ import os

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
-from pytorch_lightning.loggers.base import DummyLogger
+from pytorch_lightning.loggers import CSVLogger, LoggerCollection, TensorBoardLogger
 from tests.helpers.boring_model import BoringModel


@@ -118,7 +117,7 @@ def test_logdir_logger_collection(tmpdir):
     trainer = Trainer(
         default_root_dir=default_root_dir,
         max_steps=2,
-        logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), DummyLogger()],
+        logger=[TensorBoardLogger(save_dir=save_dir, name="custom_logs"), CSVLogger(tmpdir)],
     )
     assert isinstance(trainer.logger, LoggerCollection)
     assert trainer.log_dir == default_root_dir
diff --git a/tests/trainer/properties/test_loggers.py b/tests/trainer/properties/test_loggers.py
index 606c7b641a..d3db78986f 100644
--- a/tests/trainer/properties/test_loggers.py
+++ b/tests/trainer/properties/test_loggers.py
@@ -30,18 +30,20 @@ def test_trainer_loggers_property():
     # trainer.loggers should create a list of size 1
     trainer = Trainer(logger=logger1)
+    assert trainer.logger == logger1
     assert trainer.loggers == [logger1]

     # trainer.loggers should be an empty list
     trainer = Trainer(logger=False)
+    assert trainer.logger is None
     assert trainer.loggers == []

     # trainer.loggers should be a list of size 1 holding the default logger
     trainer = Trainer(logger=True)
     assert trainer.loggers == [trainer.logger]
-    assert type(trainer.loggers[0]) == TensorBoardLogger
+    assert isinstance(trainer.logger, TensorBoardLogger)


 def test_trainer_loggers_setters():
diff --git a/tests/utilities/test_logger.py b/tests/utilities/test_logger.py
index 8d9b495fb9..6b67272289 100644
--- a/tests/utilities/test_logger.py
+++ b/tests/utilities/test_logger.py
@@ -17,12 +17,15 @@ import numpy as np
 import torch

 from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import CSVLogger
 from pytorch_lightning.utilities.logger import (
     _add_prefix,
     _convert_params,
     _flatten_dict,
+    _name,
     _sanitize_callable_params,
     _sanitize_params,
+    _version,
 )


@@ -172,3 +175,37 @@ def test_add_prefix():
     assert "prefix-metric2" not in metrics
     assert metrics["prefix2_prefix-metric1"] == 1
     assert metrics["prefix2_prefix-metric2"] == 2
+
+
+def test_name(tmpdir):
+    """Verify names of loggers are concatenated properly."""
+    logger1 = CSVLogger(tmpdir, name="foo")
+    logger2 = CSVLogger(tmpdir, name="bar")
+    logger3 = CSVLogger(tmpdir, name="foo")
+    logger4 = CSVLogger(tmpdir, name="baz")
+    loggers = [logger1, logger2, logger3, logger4]
+    name = _name([])
+    assert name == ""
+    name = _name([logger3])
+    assert name == "foo"
+    name = _name(loggers)
+    assert name == "foo_bar_baz"
+    name = _name(loggers, "-")
+    assert name == "foo-bar-baz"
+
+
+def test_version(tmpdir):
+    """Verify versions of loggers are concatenated properly."""
+    logger1 = CSVLogger(tmpdir, version=0)
+    logger2 = CSVLogger(tmpdir, version=2)
+    logger3 = CSVLogger(tmpdir, version=1)
+    logger4 = CSVLogger(tmpdir, version=0)
+    loggers = [logger1, logger2, logger3, logger4]
+    version = _version([])
+    assert version == ""
+    version = _version([logger3])
+    assert version == 1
+    version = _version(loggers)
+    assert version == "0_2_1"
+    version = _version(loggers, "-")
+    assert version == "0-2-1"
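
For readers skimming the patch, here is a minimal usage sketch (not part of the diff) of the list-based `Trainer.loggers` API the changes above converge on: `configure_logger()` now stores whatever is passed to `Trainer(logger=...)` as a plain Python list, and call sites iterate that list instead of going through a single `trainer.logger` or a `LoggerCollection`. The directory names and the metric value below are placeholders; only the `Trainer`, `TensorBoardLogger`, `CSVLogger`, `log_metrics`, and `save` calls exercised in the diff and its tests are assumed.

# Illustrative sketch only, under the assumptions stated above.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

loggers = [TensorBoardLogger(save_dir="lightning_logs"), CSVLogger("lightning_logs")]
trainer = Trainer(logger=loggers, max_epochs=1)

# configure_logger() keeps the iterable as a plain list on the Trainer
assert trainer.loggers == loggers

# callbacks and user code fan out over the list instead of a single trainer.logger
for logger in trainer.loggers:
    logger.log_metrics({"demo_metric": 0.0}, step=trainer.global_step)
    logger.save()

The same fan-out pattern is what the callbacks in this diff now do internally, with `_name()` and `_version()` providing a single concatenated name/version when a path such as a checkpoint directory still needs one value for several loggers.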