Remove checkpoint tracking from internal debugger (#9326)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Author: Adrian Wälchli
Date: 2021-09-08 02:42:31 +02:00 (committed by GitHub)
Parent: ca679cd78f
Commit: 91ce0d0a99
5 changed files with 27 additions and 51 deletions

pytorch_lightning/callbacks/model_checkpoint.py

@@ -492,9 +492,6 @@ class ModelCheckpoint(Callback):
         log.debug(f"Removed checkpoint: {filepath}")
 
     def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
-        # in debugging, track when we save checkpoints
-        trainer.dev_debugger.track_checkpointing_history(filepath)
-
         # make paths
         if trainer.should_rank_save_checkpoint:
             self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

pytorch_lightning/utilities/debugging.py

@@ -43,7 +43,6 @@ class InternalDebugger:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
         self.trainer = trainer
-        self.checkpoint_callback_history: List[Dict[str, Any]] = []
         self.events: List[Dict[str, Any]] = []
         self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
         self.train_dataloader_calls: List[Dict[str, Any]] = []
@@ -124,15 +123,3 @@ class InternalDebugger:
             "new_lr": new_lr,
         }
         self.saved_lr_scheduler_updates.append(loss_dict)
-
-    @enabled_only
-    def track_checkpointing_history(self, filepath: str) -> None:
-        cb = self.trainer.checkpoint_callback
-        debug_dict = {
-            "epoch": self.trainer.current_epoch,
-            "global_step": self.trainer.global_step,
-            "monitor": cb.monitor if cb is not None else None,
-            "rank": self.trainer.global_rank,
-            "filepath": filepath,
-        }
-        self.checkpoint_callback_history.append(debug_dict)
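
For context on the removed method: the @enabled_only decorator it used turns the tracking calls into no-ops unless the PL_DEV_DEBUG environment variable is set (see self.enabled in __init__ above). A minimal sketch of that guard, written here as an assumption rather than copied from this commit:

from functools import wraps
from typing import Any, Callable


def enabled_only(fn: Callable) -> Callable:
    """Skip the wrapped tracking call unless the debugger is enabled."""

    @wraps(fn)
    def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> None:
        if self.enabled:  # derived from PL_DEV_DEBUG in InternalDebugger.__init__
            fn(self, *args, **kwargs)

    return wrapped_fn

With the tracking gone, the tests below observe checkpoint saves directly with unittest.mock instead of reading checkpoint_callback_history.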

tests/checkpointing/test_checkpoint_callback_frequency.py

@@ -17,34 +17,17 @@ from unittest import mock
 import pytest
 import torch
 
-from pytorch_lightning import callbacks, seed_everything, Trainer
+from pytorch_lightning import callbacks, Trainer
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def test_mc_called(tmpdir):
-    seed_everything(1234)
-
-    # -----------------
-    # TRAIN LOOP ONLY
-    # -----------------
-    train_step_only_model = BoringModel()
-    train_step_only_model.validation_step = None
-
+def test_checkpoint_callback_disabled(tmpdir):
     # no callback
     trainer = Trainer(max_epochs=3, checkpoint_callback=False)
-    trainer.fit(train_step_only_model)
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
-
-    # -----------------
-    # TRAIN + VAL LOOP ONLY
-    # -----------------
-    val_train_model = BoringModel()
-    # no callback
-    trainer = Trainer(max_epochs=3, checkpoint_callback=False)
-    trainer.fit(val_train_model)
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
+    assert not trainer.checkpoint_callbacks
+
+    trainer.fit(BoringModel())
+    assert not trainer.checkpoint_callbacks
 
 
 @mock.patch("torch.save")
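
The rewritten test asserts behaviour directly instead of reading the debugger history: with checkpoint_callback=False no ModelCheckpoint is configured, so trainer.checkpoint_callbacks stays empty before and after fitting. For contrast, a short sketch against the same 1.4-era API; the default-trainer half is an assumption and is not part of this commit:

from pytorch_lightning import Trainer

# Disabled: no ModelCheckpoint is configured, before or after fit().
trainer = Trainer(max_epochs=1, checkpoint_callback=False)
assert not trainer.checkpoint_callbacks

# Default: Lightning attaches a single ModelCheckpoint automatically.
default_trainer = Trainer(max_epochs=1)
assert len(default_trainer.checkpoint_callbacks) == 1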

tests/checkpointing/test_model_checkpoint.py

@@ -23,7 +23,7 @@ from logging import INFO
 from pathlib import Path
 from typing import Union
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import call, MagicMock, Mock, patch
 
 import cloudpickle
 import pytest
@@ -752,7 +752,6 @@ def test_ckpt_metric_names(tmpdir):
     assert len(val) > 3
 
 
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_default_checkpoint_behavior(tmpdir):
     seed_everything(1234)
@@ -761,14 +760,20 @@
         default_root_dir=tmpdir, max_epochs=3, progress_bar_refresh_rate=0, limit_train_batches=5, limit_val_batches=5
     )
 
-    trainer.fit(model)
-    results = trainer.test()
+    with patch.object(ModelCheckpoint, "_save_model", wraps=trainer.checkpoint_callback._save_model) as save_mock:
+        trainer.fit(model)
+        results = trainer.test()
+
     assert len(results) == 1
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3
 
-    # make sure the checkpoint we saved has the metric in the name
-    ckpts = os.listdir(os.path.join(tmpdir, "lightning_logs", "version_0", "checkpoints"))
+    save_dir = tmpdir / "lightning_logs" / "version_0" / "checkpoints"
+    save_mock.assert_has_calls(
+        [
+            call(trainer, save_dir / "epoch=0-step=4.ckpt"),
+            call(trainer, save_dir / "epoch=1-step=9.ckpt"),
+            call(trainer, save_dir / "epoch=2-step=14.ckpt"),
+        ]
+    )
+
+    ckpts = os.listdir(save_dir)
     assert len(ckpts) == 1
     assert ckpts[0] == "epoch=2-step=14.ckpt"
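
Instead of counting entries in checkpoint_callback_history, the test now wraps _save_model: patch.object(..., wraps=...) installs a MagicMock that records every call for assert_has_calls while still forwarding to the real method, so the checkpoint files are actually written. A small self-contained sketch of that mock idiom; the Saver class and paths are illustrative, not from this commit:

from unittest.mock import call, patch


class Saver:
    def save(self, path: str) -> str:
        return f"wrote {path}"  # stands in for the real file write


saver = Saver()

# Replace the class attribute with a MagicMock that forwards to the original
# bound method, so the real behaviour still runs and every call is recorded.
with patch.object(Saver, "save", wraps=saver.save) as save_mock:
    saver.save("epoch=0.ckpt")
    saver.save("epoch=1.ckpt")

save_mock.assert_has_calls([call("epoch=0.ckpt"), call("epoch=1.ckpt")])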
@@ -834,9 +839,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     assert w0.eq(w1).all()
 
 
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @pytest.mark.parametrize("mode", ["min", "max"])
-def test_checkpointing_with_nan_as_first(tmpdir, mode: int):
+def test_checkpointing_with_nan_as_first(tmpdir, mode):
     monitor = [float("nan")]
     monitor += [5, 7, 8] if mode == "max" else [8, 7, 5]
@@ -847,8 +851,11 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode: int):
     model = CurrentModel()
 
+    callback = ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir)
+    callback._save_model = MagicMock()
     trainer = Trainer(
-        callbacks=[ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir)],
+        callbacks=[callback],
         default_root_dir=tmpdir,
         val_check_interval=1.0,
         max_epochs=len(monitor),
@@ -856,7 +863,8 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode: int):
     trainer.fit(model)
 
     # check that last one is also the best one
-    assert trainer.dev_debugger.checkpoint_callback_history[-1]["epoch"] == len(monitor) - 1
+    assert callback._save_model.call_count == len(monitor)
+    assert mode == "min" and callback.best_model_score == 5 or mode == "max" and callback.best_model_score == 8
 
 
 def test_checkpoint_repeated_strategy(tmpdir):
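
Here the mock fully replaces _save_model rather than wrapping it: assigning callback._save_model = MagicMock() keeps ModelCheckpoint's top-k bookkeeping (best_model_score) running while skipping all file I/O, and call_count counts the attempted saves. The same idiom in isolation; the Writer class is illustrative, not from this commit:

from unittest.mock import MagicMock


class Writer:
    def write(self, path: str) -> None:
        raise OSError("the real method would touch the filesystem")


writer = Writer()
writer.write = MagicMock()  # instance attribute shadows the real method

writer.write("epoch=0.ckpt")
writer.write("epoch=1.ckpt")

assert writer.write.call_count == 2  # saves were attempted, nothing was written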

tests/trainer/flags/test_fast_dev_run.py

@@ -64,6 +64,7 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
             return super().test_step(batch, batch_idx)
 
     checkpoint_callback = ModelCheckpoint()
+    checkpoint_callback.save_checkpoint = Mock()
     early_stopping_callback = EarlyStopping()
     early_stopping_callback._evaluate_stopping_criteria = Mock()
     trainer_config = dict(
@@ -95,8 +96,8 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
     # checkpoint callback should not have been called with fast_dev_run
     assert trainer.checkpoint_callback == checkpoint_callback
+    checkpoint_callback.save_checkpoint.assert_not_called()
     assert not os.path.exists(checkpoint_callback.dirpath)
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
 
     # early stopping should not have been called with fast_dev_run
     assert trainer.early_stopping_callback == early_stopping_callback
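
The fast_dev_run check now relies on the same mock mechanics: since save_checkpoint was replaced with a Mock() before fitting, any attempt to checkpoint during the run is caught by assert_not_called(), and because the mock never executes the real save, checkpoint_callback.dirpath is never created either. A generic sketch of the pattern; the names are illustrative, not from this commit:

from unittest.mock import Mock


class Checkpointer:
    def save_checkpoint(self, trainer) -> None:
        ...  # the real method would write a checkpoint file


cb = Checkpointer()
cb.save_checkpoint = Mock()

# ... run the code under test (trainer.fit(...) in the Lightning test) ...

cb.save_checkpoint.assert_not_called()  # fails loudly if anything tried to save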