Fix log_dir property (#5537)
* fix and update tests * update with ModelCheckpoint * chlog * wip wandb fix * all fixed Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
parent
a3161267d9
commit
2abf4693bc
|
@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed support custom DataLoader with DDP if they can be re-instantiated ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745))
|
||||
|
||||
|
||||
- Fixed `log_dir` property ([#5537](https://github.com/PyTorchLightning/pytorch-lightning/pull/5537))
|
||||
|
||||
|
||||
- Fixed a race condition in `ModelCheckpoint` when checking if a checkpoint file exists ([#5144](https://github.com/PyTorchLightning/pytorch-lightning/pull/5144))
|
||||
|
||||
- Remove unnecessary intermediate layers in Dockerfiles ([#5697](https://github.com/PyTorchLightning/pytorch-lightning/pull/5697))
|
||||
|
|
|
@ -195,7 +195,7 @@ class ModelCheckpoint(Callback):
|
|||
"""
|
||||
When pretrain routine starts we build the ckpt dir on the fly
|
||||
"""
|
||||
self.__resolve_ckpt_dir(trainer, pl_module)
|
||||
self.__resolve_ckpt_dir(trainer)
|
||||
self.save_function = trainer.save_checkpoint
|
||||
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
|
@ -427,7 +427,7 @@ class ModelCheckpoint(Callback):
|
|||
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
|
||||
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
|
||||
|
||||
def __resolve_ckpt_dir(self, trainer, pl_module):
|
||||
def __resolve_ckpt_dir(self, trainer):
|
||||
"""
|
||||
Determines model checkpoint save directory at runtime. References attributes from the
|
||||
trainer's logger to determine where to save checkpoints.
|
||||
|
|
|
@ -24,6 +24,7 @@ import torch.nn as nn
|
|||
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import _module_available, rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.warning_utils import WarningCache
|
||||
|
||||
_WANDB_AVAILABLE = _module_available("wandb")
|
||||
|
@ -100,6 +101,14 @@ class WandbLogger(LightningLoggerBase):
|
|||
if wandb is None:
|
||||
raise ImportError('You want to use `wandb` logger which is not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install wandb`.')
|
||||
|
||||
if offline and log_model:
|
||||
raise MisconfigurationException(
|
||||
f'Providing log_model={log_model} and offline={offline} is an invalid configuration'
|
||||
' since model checkpoints cannot be uploaded in offline mode.\n'
|
||||
'Hint: Set `offline=False` to log your model.'
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
self._name = name
|
||||
self._save_dir = save_dir
|
||||
|
@ -144,10 +153,12 @@ class WandbLogger(LightningLoggerBase):
|
|||
self._experiment = wandb.init(
|
||||
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
|
||||
id=self._id, resume='allow', **self._kwargs) if wandb.run is None else wandb.run
|
||||
|
||||
# offset logging step when resuming a run
|
||||
self._step_offset = self._experiment.step
|
||||
|
||||
# save checkpoints in wandb dir to upload on W&B servers
|
||||
if self._log_model:
|
||||
if self._save_dir is None:
|
||||
self._save_dir = self._experiment.dir
|
||||
return self._experiment
|
||||
|
||||
|
|
|
@ -14,15 +14,11 @@
|
|||
import inspect
|
||||
import os
|
||||
from abc import ABC
|
||||
from argparse import ArgumentParser
|
||||
from argparse import Namespace
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import cast, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks import ProgressBarBase
|
||||
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
|
@ -30,15 +26,13 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import Checkpoint
|
|||
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
|
||||
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
|
||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
||||
from pytorch_lightning.utilities import DeviceType
|
||||
from pytorch_lightning.utilities import DistributedType
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.argparse import add_argparse_args
|
||||
from pytorch_lightning.utilities.argparse import from_argparse_args
|
||||
from pytorch_lightning.utilities.argparse import parse_argparser
|
||||
from pytorch_lightning.utilities.argparse import parse_env_variables
|
||||
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, DeviceType, DistributedType, rank_zero_warn
|
||||
from pytorch_lightning.utilities.argparse import (
|
||||
add_argparse_args,
|
||||
from_argparse_args,
|
||||
parse_argparser,
|
||||
parse_env_variables,
|
||||
)
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
|
||||
|
@ -75,16 +69,10 @@ class TrainerProperties(ABC):
|
|||
|
||||
@property
|
||||
def log_dir(self):
|
||||
if self.checkpoint_callback is not None:
|
||||
dirpath = self.checkpoint_callback.dirpath
|
||||
dirpath = os.path.split(dirpath)[0]
|
||||
elif self.logger is not None:
|
||||
if isinstance(self.logger, TensorBoardLogger):
|
||||
dirpath = self.logger.log_dir
|
||||
else:
|
||||
dirpath = self.logger.save_dir
|
||||
if self.logger is None:
|
||||
dirpath = self.default_root_dir
|
||||
else:
|
||||
dirpath = self._default_root_dir
|
||||
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')
|
||||
|
||||
if self.accelerator_backend is not None:
|
||||
dirpath = self.accelerator_backend.broadcast(dirpath)
|
||||
|
|
|
@ -19,7 +19,7 @@ import pytest
|
|||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import CometLogger
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import BoringModel
|
||||
from tests.base import BoringModel, EvalModelTemplate
|
||||
|
||||
|
||||
def _patch_comet_atexit(monkeypatch):
|
||||
|
@ -74,7 +74,7 @@ def test_comet_logger_online(comet):
|
|||
@patch('pytorch_lightning.loggers.comet.comet_ml')
|
||||
def test_comet_logger_no_api_key_given(comet):
|
||||
""" Test that CometLogger fails to initialize if both api key and save_dir are missing. """
|
||||
with pytest.raises(MisconfigurationException):
|
||||
with pytest.raises(MisconfigurationException, match='requires either api_key or save_dir'):
|
||||
comet.config.get_api_key.return_value = None
|
||||
CometLogger(workspace='dummy-test', project_name='general')
|
||||
|
||||
|
@ -89,13 +89,10 @@ def test_comet_logger_experiment_name(comet):
|
|||
# Test api_key given
|
||||
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet_experiment:
|
||||
logger = CometLogger(api_key=api_key, experiment_name=experiment_name,)
|
||||
|
||||
assert logger._experiment is None
|
||||
|
||||
_ = logger.experiment
|
||||
|
||||
comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)
|
||||
|
||||
comet_experiment().set_name.assert_called_once_with(experiment_name)
|
||||
|
||||
|
||||
|
@ -118,13 +115,10 @@ def test_comet_logger_manual_experiment_key(comet):
|
|||
with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}):
|
||||
with patch('pytorch_lightning.loggers.comet.CometExperiment', side_effect=save_os_environ) as comet_experiment:
|
||||
logger = CometLogger(api_key=api_key)
|
||||
|
||||
assert logger.version == experiment_key
|
||||
|
||||
assert logger._experiment is None
|
||||
|
||||
_ = logger.experiment
|
||||
|
||||
comet_experiment.assert_called_once_with(api_key=api_key, project_name=None)
|
||||
|
||||
assert instantation_environ["COMET_EXPERIMENT_KEY"] == experiment_key
|
||||
|
@ -154,19 +148,14 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch
|
|||
logger.experiment.id = '1'
|
||||
logger.experiment.project_name = 'test'
|
||||
|
||||
limit_batches = 5
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
logger=logger,
|
||||
max_epochs=1,
|
||||
limit_train_batches=limit_batches,
|
||||
limit_val_batches=limit_batches,
|
||||
)
|
||||
model = EvalModelTemplate()
|
||||
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
trainer.fit(model)
|
||||
|
||||
assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints')
|
||||
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f'epoch=0-step={limit_batches - 1}.ckpt']
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
|
||||
|
||||
@patch('pytorch_lightning.loggers.comet.comet_ml')
|
||||
|
@ -177,11 +166,8 @@ def test_comet_name_default(comet):
|
|||
|
||||
with patch('pytorch_lightning.loggers.comet.CometExperiment'):
|
||||
logger = CometLogger(api_key=api_key)
|
||||
|
||||
assert logger._experiment is None
|
||||
|
||||
assert logger.name == "comet-default"
|
||||
|
||||
assert logger._experiment is None
|
||||
|
||||
|
||||
|
@ -194,11 +180,8 @@ def test_comet_name_project_name(comet):
|
|||
|
||||
with patch('pytorch_lightning.loggers.comet.CometExperiment'):
|
||||
logger = CometLogger(api_key=api_key, project_name=project_name)
|
||||
|
||||
assert logger._experiment is None
|
||||
|
||||
assert logger.name == project_name
|
||||
|
||||
assert logger._experiment is None
|
||||
|
||||
|
||||
|
@ -212,14 +195,11 @@ def test_comet_version_without_experiment(comet):
|
|||
|
||||
with patch('pytorch_lightning.loggers.comet.CometExperiment'):
|
||||
logger = CometLogger(api_key=api_key, experiment_name=experiment_name)
|
||||
|
||||
assert logger._experiment is None
|
||||
|
||||
first_version = logger.version
|
||||
assert first_version is not None
|
||||
|
||||
assert logger.version == first_version
|
||||
|
||||
assert logger._experiment is None
|
||||
|
||||
_ = logger.experiment
|
||||
|
|
|
@ -111,9 +111,11 @@ def test_mlflow_log_dir(client, mlflow, tmpdir):
|
|||
limit_train_batches=1,
|
||||
limit_val_batches=3,
|
||||
)
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
trainer.fit(model)
|
||||
assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints')
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'}
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
|
||||
|
||||
def test_mlflow_logger_dirs_creation(tmpdir):
|
||||
|
|
|
@ -114,7 +114,9 @@ def test_neptune_leave_open_experiment_after_fit(neptune, tmpdir):
|
|||
limit_train_batches=0.05,
|
||||
logger=logger,
|
||||
)
|
||||
assert trainer.log_dir is None
|
||||
trainer.fit(model)
|
||||
assert trainer.log_dir is None
|
||||
return logger
|
||||
|
||||
logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
|
||||
|
|
|
@ -24,7 +24,7 @@ from tensorboard.backend.event_processing.event_accumulator import EventAccumula
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from tests.base import BoringModel
|
||||
from tests.base import BoringModel, EvalModelTemplate
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
@ -32,16 +32,14 @@ from tests.base import BoringModel
|
|||
reason="Minimal PT version is set to 1.5",
|
||||
)
|
||||
def test_tensorboard_hparams_reload(tmpdir):
|
||||
class CustomModel(BoringModel):
|
||||
def __init__(self, b1=0.5, b2=0.999):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
model = EvalModelTemplate()
|
||||
|
||||
model = CustomModel()
|
||||
trainer = Trainer(max_steps=1, default_root_dir=tmpdir)
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
assert trainer.log_dir == trainer.logger.log_dir
|
||||
trainer.fit(model)
|
||||
|
||||
folder_path = trainer.logger.log_dir
|
||||
assert trainer.log_dir == trainer.logger.log_dir
|
||||
folder_path = trainer.log_dir
|
||||
|
||||
# make sure yaml is there
|
||||
with open(os.path.join(folder_path, "hparams.yaml")) as file:
|
||||
|
@ -49,8 +47,7 @@ def test_tensorboard_hparams_reload(tmpdir):
|
|||
# scalar values to Python the dictionary format
|
||||
yaml_params = yaml.safe_load(file)
|
||||
assert yaml_params["b1"] == 0.5
|
||||
assert yaml_params["b2"] == 0.999
|
||||
assert len(yaml_params.keys()) == 2
|
||||
assert len(yaml_params.keys()) == 10
|
||||
|
||||
# verify artifacts
|
||||
assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1
|
||||
|
@ -59,8 +56,14 @@ def test_tensorboard_hparams_reload(tmpdir):
|
|||
event_acc = EventAccumulator(folder_path)
|
||||
event_acc.Reload()
|
||||
|
||||
data_pt_1_5 = b'\x12\x1b"\x04\n\x02b1"\x04\n\x02b2*\r\n\x0b\x12\thp_metric'
|
||||
data_pt_1_6 = b'\x12\x1f"\x06\n\x02b1 \x03"\x06\n\x02b2 \x03*\r\n\x0b\x12\thp_metric'
|
||||
data_pt_1_5 = b'\x12\x93\x01"\x0b\n\tdrop_prob"\x0c\n\nbatch_size"\r\n\x0bin_features"\x0f\n\rlearning_rate"' \
|
||||
b'\x10\n\x0eoptimizer_name"\x0b\n\tdata_root"\x0e\n\x0cout_features"\x0c\n\nhidden_dim"' \
|
||||
b'\x04\n\x02b1"\x04\n\x02b2*\r\n\x0b\x12\thp_metric'
|
||||
data_pt_1_6 = b'\x12\xa7\x01"\r\n\tdrop_prob \x03"\x0e\n\nbatch_size \x03"\x0f\n\x0bin_features \x03"' \
|
||||
b'\x11\n\rlearning_rate \x03"\x12\n\x0eoptimizer_name \x01"\r\n\tdata_root \x01"' \
|
||||
b'\x10\n\x0cout_features \x03"\x0e\n\nhidden_dim \x03"\x06\n\x02b1 \x03"' \
|
||||
b'\x06\n\x02b2 \x03*\r\n\x0b\x12\thp_metric'
|
||||
|
||||
hparams_data = data_pt_1_6 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0") else data_pt_1_5
|
||||
|
||||
assert event_acc.summary_metadata['_hparams_/experiment'].plugin_data.plugin_name == 'hparams'
|
||||
|
|
|
@ -13,13 +13,11 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import pickle
|
||||
import types
|
||||
from argparse import ArgumentParser
|
||||
from unittest import mock
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from tests.base import BoringModel
|
||||
from tests.base import BoringModel, EvalModelTemplate
|
||||
|
||||
|
||||
def get_warnings(recwarn):
|
||||
|
@ -106,6 +104,7 @@ def test_wandb_pickle(wandb, tmpdir):
|
|||
""" """
|
||||
id = 'the_id'
|
||||
step = 0
|
||||
dir = 'wandb'
|
||||
|
||||
def project_name(self):
|
||||
return 'the_project_name'
|
||||
|
@ -121,6 +120,7 @@ def test_wandb_pickle(wandb, tmpdir):
|
|||
)
|
||||
# Access the experiment to ensure it's created
|
||||
assert trainer.logger.experiment, 'missing experiment'
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
|
||||
|
@ -158,19 +158,14 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
|
|||
assert not os.listdir(tmpdir)
|
||||
|
||||
version = logger.version
|
||||
model = BoringModel()
|
||||
limit_batches = 5
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
logger=logger,
|
||||
max_epochs=1,
|
||||
limit_train_batches=limit_batches,
|
||||
limit_val_batches=limit_batches,
|
||||
)
|
||||
model = EvalModelTemplate()
|
||||
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3, log_every_n_steps=1)
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
trainer.fit(model)
|
||||
|
||||
assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
|
||||
assert os.listdir(trainer.checkpoint_callback.dirpath) == [f'epoch=0-step={limit_batches - 1}.ckpt']
|
||||
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'}
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
|
||||
|
||||
def test_wandb_sanitize_callable_params(tmpdir):
|
||||
|
@ -201,3 +196,10 @@ def test_wandb_sanitize_callable_params(tmpdir):
|
|||
assert params["something"] == "something"
|
||||
assert params["wrapper_something"] == "wrapper_something"
|
||||
assert params["wrapper_something_wo_name"] == "<lambda>"
|
||||
|
||||
|
||||
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
|
||||
def test_wandb_logger_offline_log_model(wandb, tmpdir):
|
||||
""" Test that log_model=True raises an error in offline mode """
|
||||
with pytest.raises(MisconfigurationException, match='checkpoints cannot be uploaded in offline mode'):
|
||||
logger = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True)
|
||||
|
|
|
@ -13,110 +13,140 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
|
||||
<<<<<<< HEAD
|
||||
from pytorch_lightning import Trainer
|
||||
from tests.base.boring_model import BoringModel
|
||||
|
||||
=======
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.utilities import APEX_AVAILABLE
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base.boring_model import BoringModel, RandomDataset
|
||||
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self, expected_log_dir):
|
||||
super().__init__()
|
||||
self.expected_log_dir = expected_log_dir
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
assert self.trainer.log_dir == self.expected_log_dir
|
||||
return super().training_step(*args, **kwargs)
|
||||
>>>>>>> 793fe736 (Fix log_dir property (#5537))
|
||||
|
||||
|
||||
def test_logdir(tmpdir):
|
||||
"""
|
||||
Tests that the path is correct when checkpoint and loggers are used
|
||||
"""
|
||||
class TestModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
output = self.layer(batch)
|
||||
loss = self.loss(batch, output)
|
||||
expected = os.path.join(tmpdir, 'lightning_logs', 'version_0')
|
||||
|
||||
expected = os.path.join(self.trainer.default_root_dir, 'lightning_logs', 'version_0')
|
||||
assert self.trainer.log_dir == expected
|
||||
return {"loss": loss}
|
||||
model = TestModel(expected)
|
||||
|
||||
model = TestModel()
|
||||
|
||||
limit_train_batches = 2
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=limit_train_batches,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
max_steps=2,
|
||||
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
|
||||
)
|
||||
|
||||
assert trainer.log_dir == expected
|
||||
trainer.fit(model)
|
||||
assert trainer.log_dir == expected
|
||||
|
||||
|
||||
def test_logdir_no_checkpoint_cb(tmpdir):
|
||||
"""
|
||||
Tests that the path is correct with no checkpoint
|
||||
"""
|
||||
class TestModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
output = self.layer(batch)
|
||||
loss = self.loss(batch, output)
|
||||
expected = os.path.join(self.trainer.default_root_dir, 'lightning_logs', 'version_0')
|
||||
assert self.trainer.log_dir == expected
|
||||
return {"loss": loss}
|
||||
expected = os.path.join(tmpdir, 'lightning_logs', 'version_0')
|
||||
model = TestModel(expected)
|
||||
|
||||
model = TestModel()
|
||||
|
||||
limit_train_batches = 2
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=limit_train_batches,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
max_steps=2,
|
||||
checkpoint_callback=False
|
||||
)
|
||||
|
||||
assert trainer.log_dir == expected
|
||||
trainer.fit(model)
|
||||
assert trainer.log_dir == expected
|
||||
|
||||
|
||||
def test_logdir_no_logger(tmpdir):
|
||||
"""
|
||||
Tests that the path is correct even when there is no logger
|
||||
"""
|
||||
class TestModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
output = self.layer(batch)
|
||||
loss = self.loss(batch, output)
|
||||
expected = os.path.join(self.trainer.default_root_dir)
|
||||
assert self.trainer.log_dir == expected
|
||||
return {"loss": loss}
|
||||
expected = os.path.join(tmpdir)
|
||||
model = TestModel(expected)
|
||||
|
||||
model = TestModel()
|
||||
|
||||
limit_train_batches = 2
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=limit_train_batches,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
max_steps=2,
|
||||
logger=False,
|
||||
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
|
||||
)
|
||||
|
||||
assert trainer.log_dir == expected
|
||||
trainer.fit(model)
|
||||
assert trainer.log_dir == expected
|
||||
|
||||
|
||||
def test_logdir_no_logger_no_checkpoint(tmpdir):
|
||||
"""
|
||||
Tests that the path is correct even when there is no logger
|
||||
"""
|
||||
class TestModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
output = self.layer(batch)
|
||||
loss = self.loss(batch, output)
|
||||
expected = os.path.join(self.trainer.default_root_dir)
|
||||
assert self.trainer.log_dir == expected
|
||||
return {"loss": loss}
|
||||
expected = os.path.join(tmpdir)
|
||||
model = TestModel(expected)
|
||||
|
||||
model = TestModel()
|
||||
|
||||
limit_train_batches = 2
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=limit_train_batches,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
max_steps=2,
|
||||
logger=False,
|
||||
checkpoint_callback=False
|
||||
)
|
||||
|
||||
assert trainer.log_dir == expected
|
||||
trainer.fit(model)
|
||||
assert trainer.log_dir == expected
|
||||
|
||||
|
||||
def test_logdir_custom_callback(tmpdir):
|
||||
"""
|
||||
Tests that the path is correct even when there is a custom callback
|
||||
"""
|
||||
expected = os.path.join(tmpdir, 'lightning_logs', 'version_0')
|
||||
model = TestModel(expected)
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_steps=2,
|
||||
callbacks=[ModelCheckpoint(dirpath=os.path.join(tmpdir, 'ckpts'))],
|
||||
)
|
||||
|
||||
assert trainer.log_dir == expected
|
||||
trainer.fit(model)
|
||||
assert trainer.log_dir == expected
|
||||
|
||||
|
||||
def test_logdir_custom_logger(tmpdir):
|
||||
"""
|
||||
Tests that the path is correct even when there is a custom logger
|
||||
"""
|
||||
expected = os.path.join(tmpdir, 'custom_logs', 'version_0')
|
||||
model = TestModel(expected)
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_steps=2,
|
||||
callbacks=[ModelCheckpoint(dirpath=tmpdir)],
|
||||
logger=TensorBoardLogger(save_dir=tmpdir, name='custom_logs')
|
||||
)
|
||||
|
||||
assert trainer.log_dir == expected
|
||||
trainer.fit(model)
|
||||
assert trainer.log_dir == expected
|
||||
|
|
Loading…
Reference in New Issue