From 6604fc1344e1b8a459c45a5a2157aa7fc60d950d Mon Sep 17 00:00:00 2001 From: Gustaf Ahdritz Date: Mon, 19 Jul 2021 18:12:12 -0400 Subject: [PATCH] Use `default_root_dir` as the `log_dir` with `LoggerCollection`s (#8187) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí --- pytorch_lightning/trainer/properties.py | 7 +++++- pytorch_lightning/trainer/trainer.py | 5 +++- tests/profiler/test_profiler.py | 32 +++++++++++++++++++++++++ tests/trainer/properties/log_dir.py | 19 ++++++++++++++- 4 files changed, 60 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 685ad979ee..aa9b5ec43b 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -28,6 +28,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.loggers.base import LoggerCollection from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from pytorch_lightning.loops import PredictionLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop @@ -226,8 +227,12 @@ class TrainerProperties(ABC): def log_dir(self) -> Optional[str]: if self.logger is None: dirpath = self.default_root_dir + elif isinstance(self.logger, TensorBoardLogger): + dirpath = self.logger.log_dir + elif isinstance(self.logger, LoggerCollection): + dirpath = self.default_root_dir else: - dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir') + dirpath = self.logger.save_dir dirpath = self.accelerator.broadcast(dirpath) return dirpath diff --git 
a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2a716ebc3a..bae2d4ef1d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -233,7 +233,10 @@ class Trainer( limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches) logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses - the default ``TensorBoardLogger``. ``False`` will disable logging. + the default ``TensorBoardLogger``. ``False`` will disable logging. If multiple loggers are + provided, local files (checkpoints, profiler traces, etc.) are saved in + ``default_root_dir`` rather than in the ``log_dir`` of any of the individual + loggers. log_gpu_memory: None, 'min_max', 'all'. Might slow performance diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index c04a1d849a..0e26ba39ee 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -23,6 +23,8 @@ import torch from packaging.version import Version from pytorch_lightning import Callback, Trainer +from pytorch_lightning.loggers.base import LoggerCollection +from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler from pytorch_lightning.profiler.pytorch import RegisterRecordFunction from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -430,6 +432,36 @@ def test_pytorch_profiler_nested(tmpdir): assert events_name == expected, (events_name, torch.__version__, platform.system()) +def test_pytorch_profiler_logger_collection(tmpdir): + """ + Tests whether the PyTorch profiler is able to write its trace locally when + the Trainer's logger is an instance of LoggerCollection. See issue #8157. 
+ """ + + def look_for_trace(trace_dir): + """ Determines if a directory contains a PyTorch trace """ + return any("trace.json" in filename for filename in os.listdir(trace_dir)) + + # Sanity check + assert not look_for_trace(tmpdir) + + model = BoringModel() + + # Wrap the logger in a list so it becomes a LoggerCollection + logger = [TensorBoardLogger(save_dir=tmpdir)] + trainer = Trainer( + default_root_dir=tmpdir, + profiler="pytorch", + logger=logger, + limit_train_batches=5, + max_epochs=1, + ) + + assert isinstance(trainer.logger, LoggerCollection) + trainer.fit(model) + assert look_for_trace(tmpdir) + + @RunIf(min_gpus=1, special=True) def test_pytorch_profiler_nested_emit_nvtx(tmpdir): """ diff --git a/tests/trainer/properties/log_dir.py b/tests/trainer/properties/log_dir.py index 730e2a1512..b4565b963e 100644 --- a/tests/trainer/properties/log_dir.py +++ b/tests/trainer/properties/log_dir.py @@ -15,7 +15,7 @@ import os from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from tests.helpers.boring_model import BoringModel @@ -140,3 +140,20 @@ def test_logdir_custom_logger(tmpdir): assert trainer.log_dir == expected trainer.fit(model) assert trainer.log_dir == expected + + +def test_logdir_logger_collection(tmpdir): + """Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection""" + default_root_dir = tmpdir / "default_root_dir" + save_dir = tmpdir / "save_dir" + model = TestModel(default_root_dir) + trainer = Trainer( + default_root_dir=default_root_dir, + max_steps=2, + logger=[TensorBoardLogger(save_dir=save_dir, name='custom_logs')] + ) + assert isinstance(trainer.logger, LoggerCollection) + assert trainer.log_dir == default_root_dir + + trainer.fit(model) + assert trainer.log_dir == default_root_dir