Use `default_root_dir` as the `log_dir` with `LoggerCollection`s (#8187)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Gustaf Ahdritz 2021-07-19 18:12:12 -04:00 committed by GitHub
parent a6fd32a708
commit 6604fc1344
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 3 deletions

View File

@ -28,6 +28,7 @@ from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loops import PredictionLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
@ -226,8 +227,12 @@ class TrainerProperties(ABC):
def log_dir(self) -> Optional[str]:
if self.logger is None:
dirpath = self.default_root_dir
elif isinstance(self.logger, TensorBoardLogger):
dirpath = self.logger.log_dir
elif isinstance(self.logger, LoggerCollection):
dirpath = self.default_root_dir
else:
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')
dirpath = self.logger.save_dir
dirpath = self.accelerator.broadcast(dirpath)
return dirpath

View File

@ -233,7 +233,10 @@ class Trainer(
limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches)
logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
the default ``TensorBoardLogger``. ``False`` will disable logging.
the default ``TensorBoardLogger``. ``False`` will disable logging. If multiple loggers are
provided and the `save_dir` property of that logger is not set, local files (checkpoints,
profiler traces, etc.) are saved in ``default_root_dir`` rather than in the ``log_dir`` of any
of the individual loggers.
log_gpu_memory: None, 'min_max', 'all'. Might slow performance

View File

@ -23,6 +23,8 @@ import torch
from packaging.version import Version
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -430,6 +432,36 @@ def test_pytorch_profiler_nested(tmpdir):
assert events_name == expected, (events_name, torch.__version__, platform.system())
def test_pytorch_profiler_logger_collection(tmpdir):
"""
Tests whether the PyTorch profiler is able to write its trace locally when
the Trainer's logger is an instance of LoggerCollection. See issue #8157.
"""
def look_for_trace(trace_dir):
""" Determines if a directory contains a PyTorch trace """
return any("trace.json" in filename for filename in os.listdir(trace_dir))
# Sanity check
assert not look_for_trace(tmpdir)
model = BoringModel()
# Wrap the logger in a list so it becomes a LoggerCollection
logger = [TensorBoardLogger(save_dir=tmpdir)]
trainer = Trainer(
default_root_dir=tmpdir,
profiler="pytorch",
logger=logger,
limit_train_batches=5,
max_epochs=1,
)
assert isinstance(trainer.logger, LoggerCollection)
trainer.fit(model)
assert look_for_trace(tmpdir)
@RunIf(min_gpus=1, special=True)
def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
"""

View File

@ -15,7 +15,7 @@ import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from tests.helpers.boring_model import BoringModel
@ -140,3 +140,20 @@ def test_logdir_custom_logger(tmpdir):
assert trainer.log_dir == expected
trainer.fit(model)
assert trainer.log_dir == expected
def test_logdir_logger_collection(tmpdir):
"""Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection"""
default_root_dir = tmpdir / "default_root_dir"
save_dir = tmpdir / "save_dir"
model = TestModel(default_root_dir)
trainer = Trainer(
default_root_dir=default_root_dir,
max_steps=2,
logger=[TensorBoardLogger(save_dir=save_dir, name='custom_logs')]
)
assert isinstance(trainer.logger, LoggerCollection)
assert trainer.log_dir == default_root_dir
trainer.fit(model)
assert trainer.log_dir == default_root_dir