Use `default_root_dir` as the `log_dir` with `LoggerCollection`s (#8187)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent a6fd32a708
commit 6604fc1344
@@ -28,6 +28,7 @@ from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.loggers.base import LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops import PredictionLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
@@ -226,8 +227,12 @@ class TrainerProperties(ABC):
     def log_dir(self) -> Optional[str]:
         if self.logger is None:
             dirpath = self.default_root_dir
+        elif isinstance(self.logger, TensorBoardLogger):
+            dirpath = self.logger.log_dir
+        elif isinstance(self.logger, LoggerCollection):
+            dirpath = self.default_root_dir
         else:
-            dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')
+            dirpath = self.logger.save_dir

         dirpath = self.accelerator.broadcast(dirpath)
         return dirpath
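The resolution order above reads: no logger, use ``default_root_dir``; a single TensorBoardLogger, use its own versioned ``log_dir``; a LoggerCollection, fall back to ``default_root_dir`` (the new case); any other single logger, use its ``save_dir``. The result is then broadcast so every accelerator process agrees on the path. As a minimal illustrative sketch (not code from this commit), the same dispatch written as a standalone function:

    from typing import Optional

    from pytorch_lightning.loggers import LightningLoggerBase
    from pytorch_lightning.loggers.base import LoggerCollection
    from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

    def resolve_log_dir(logger: Optional[LightningLoggerBase], default_root_dir: str) -> str:
        """Illustrative mirror of TrainerProperties.log_dir, minus the broadcast step."""
        if logger is None:
            # No logger at all: fall back to the Trainer's default_root_dir.
            return default_root_dir
        if isinstance(logger, TensorBoardLogger):
            # A lone TensorBoardLogger keeps its own versioned log_dir.
            return logger.log_dir
        if isinstance(logger, LoggerCollection):
            # New in this commit: a collection has no single log_dir, so
            # local artifacts go to default_root_dir instead.
            return default_root_dir
        # Any other single logger contributes its save_dir.
        return logger.save_dir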
@@ -233,7 +233,10 @@ class Trainer(
         limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches)

         logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
-            the default ``TensorBoardLogger``. ``False`` will disable logging.
+            the default ``TensorBoardLogger``. ``False`` will disable logging. If multiple loggers
+            are provided, local files (checkpoints, profiler traces, etc.) are saved in
+            ``default_root_dir`` rather than in the ``log_dir`` of any of the individual
+            loggers.

         log_gpu_memory: None, 'min_max', 'all'. Might slow performance
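To make the documented behavior concrete, a hedged usage sketch (the directory names and the choice of CSVLogger as the second logger are illustrative assumptions, not part of this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

    # Passing a list of loggers makes Trainer.logger a LoggerCollection, so
    # checkpoints and profiler traces land under default_root_dir rather than
    # under either logger's individual log_dir.
    trainer = Trainer(
        default_root_dir="runs",  # hypothetical path
        logger=[
            TensorBoardLogger(save_dir="tb_logs"),  # hypothetical path
            CSVLogger(save_dir="csv_logs"),  # hypothetical path
        ],
    )
    assert trainer.log_dir == trainer.default_root_dir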
@@ -23,6 +23,8 @@ import torch
 from packaging.version import Version

 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.loggers.base import LoggerCollection
+from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
 from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -430,6 +432,36 @@ def test_pytorch_profiler_nested(tmpdir):
     assert events_name == expected, (events_name, torch.__version__, platform.system())


+def test_pytorch_profiler_logger_collection(tmpdir):
+    """
+    Tests whether the PyTorch profiler is able to write its trace locally when
+    the Trainer's logger is an instance of LoggerCollection. See issue #8157.
+    """
+
+    def look_for_trace(trace_dir):
+        """Determines if a directory contains a PyTorch trace."""
+        return any("trace.json" in filename for filename in os.listdir(trace_dir))
+
+    # Sanity check
+    assert not look_for_trace(tmpdir)
+
+    model = BoringModel()
+
+    # Wrap the logger in a list so it becomes a LoggerCollection
+    logger = [TensorBoardLogger(save_dir=tmpdir)]
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        profiler="pytorch",
+        logger=logger,
+        limit_train_batches=5,
+        max_epochs=1,
+    )
+
+    assert isinstance(trainer.logger, LoggerCollection)
+    trainer.fit(model)
+    assert look_for_trace(tmpdir)
+
+
 @RunIf(min_gpus=1, special=True)
 def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
     """
@@ -15,7 +15,7 @@ import os

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from tests.helpers.boring_model import BoringModel


@@ -140,3 +140,20 @@ def test_logdir_custom_logger(tmpdir):
     assert trainer.log_dir == expected
     trainer.fit(model)
     assert trainer.log_dir == expected
+
+
+def test_logdir_logger_collection(tmpdir):
+    """Tests that the log_dir equals the default_root_dir when the logger is a LoggerCollection."""
+    default_root_dir = tmpdir / "default_root_dir"
+    save_dir = tmpdir / "save_dir"
+    model = TestModel(default_root_dir)
+    trainer = Trainer(
+        default_root_dir=default_root_dir,
+        max_steps=2,
+        logger=[TensorBoardLogger(save_dir=save_dir, name='custom_logs')],
+    )
+    assert isinstance(trainer.logger, LoggerCollection)
+    assert trainer.log_dir == default_root_dir
+
+    trainer.fit(model)
+    assert trainer.log_dir == default_root_dir