Remove checkpoint tracking from internal debugger (#9326)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Adrian Wälchli authored on 2021-09-08 02:42:31 +02:00, committed by GitHub
parent ca679cd78f
commit 91ce0d0a99
5 changed files with 27 additions and 51 deletions
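In short: the `dev_debugger.checkpoint_callback_history` bookkeeping is removed, and the tests now verify checkpoint saves with plain `unittest.mock` spies on the checkpoint callback. Below is a minimal, hypothetical sketch of that pattern outside the Lightning test suite. The `RandomDataset`/`TinyModel` classes, the `dirpath`, and the Trainer arguments are invented for illustration and are not part of this commit; only the spy on `ModelCheckpoint._save_model` mirrors what the updated tests do.

import torch
from torch.utils.data import DataLoader, Dataset
from unittest.mock import MagicMock

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    # tiny stand-in dataset, similar in spirit to tests.helpers.BoringModel's data
    def __init__(self, size: int = 32, length: int = 64):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class TinyModel(LightningModule):
    # minimal module: one linear layer, scalar train/val losses
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)

    def val_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


callback = ModelCheckpoint(dirpath="checkpoints")
callback._save_model = MagicMock()  # spy on saves instead of reading dev_debugger history

trainer = Trainer(callbacks=[callback], max_epochs=3, limit_train_batches=5, limit_val_batches=5)
trainer.fit(TinyModel())

# with these defaults a checkpoint is saved once per epoch, so the spy
# should record three calls (mirroring test_default_checkpoint_behavior)
assert callback._save_model.call_count == 3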

pytorch_lightning/callbacks/model_checkpoint.py

@@ -492,9 +492,6 @@ class ModelCheckpoint(Callback):
         log.debug(f"Removed checkpoint: {filepath}")
 
     def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
-        # in debugging, track when we save checkpoints
-        trainer.dev_debugger.track_checkpointing_history(filepath)
-
         # make paths
         if trainer.should_rank_save_checkpoint:
             self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

pytorch_lightning/utilities/debugging.py

@@ -43,7 +43,6 @@ class InternalDebugger:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
         self.trainer = trainer
-        self.checkpoint_callback_history: List[Dict[str, Any]] = []
         self.events: List[Dict[str, Any]] = []
         self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
         self.train_dataloader_calls: List[Dict[str, Any]] = []
@@ -124,15 +123,3 @@ class InternalDebugger:
             "new_lr": new_lr,
         }
         self.saved_lr_scheduler_updates.append(loss_dict)
-
-    @enabled_only
-    def track_checkpointing_history(self, filepath: str) -> None:
-        cb = self.trainer.checkpoint_callback
-        debug_dict = {
-            "epoch": self.trainer.current_epoch,
-            "global_step": self.trainer.global_step,
-            "monitor": cb.monitor if cb is not None else None,
-            "rank": self.trainer.global_rank,
-            "filepath": filepath,
-        }
-        self.checkpoint_callback_history.append(debug_dict)

tests/checkpointing/test_checkpoint_callback_frequency.py

@@ -17,34 +17,17 @@ from unittest import mock
 import pytest
 import torch
 
-from pytorch_lightning import callbacks, seed_everything, Trainer
+from pytorch_lightning import callbacks, Trainer
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def test_mc_called(tmpdir):
-    seed_everything(1234)
-
-    # -----------------
-    # TRAIN LOOP ONLY
-    # -----------------
-    train_step_only_model = BoringModel()
-    train_step_only_model.validation_step = None
-
+def test_checkpoint_callback_disabled(tmpdir):
     # no callback
     trainer = Trainer(max_epochs=3, checkpoint_callback=False)
-    trainer.fit(train_step_only_model)
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
-
-    # -----------------
-    # TRAIN + VAL LOOP ONLY
-    # -----------------
-    val_train_model = BoringModel()
-    # no callback
-    trainer = Trainer(max_epochs=3, checkpoint_callback=False)
-    trainer.fit(val_train_model)
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
+    assert not trainer.checkpoint_callbacks
+    trainer.fit(BoringModel())
+    assert not trainer.checkpoint_callbacks
 
 
 @mock.patch("torch.save")

tests/checkpointing/test_model_checkpoint.py

@@ -23,7 +23,7 @@ from logging import INFO
 from pathlib import Path
 from typing import Union
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import call, MagicMock, Mock, patch
 
 import cloudpickle
 import pytest
@@ -752,7 +752,6 @@ def test_ckpt_metric_names(tmpdir):
         assert len(val) > 3
 
 
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_default_checkpoint_behavior(tmpdir):
     seed_everything(1234)
 
@@ -761,14 +760,20 @@ def test_default_checkpoint_behavior(tmpdir):
         default_root_dir=tmpdir, max_epochs=3, progress_bar_refresh_rate=0, limit_train_batches=5, limit_val_batches=5
     )
 
-    trainer.fit(model)
-    results = trainer.test()
+    with patch.object(ModelCheckpoint, "_save_model", wraps=trainer.checkpoint_callback._save_model) as save_mock:
+        trainer.fit(model)
+        results = trainer.test()
 
     assert len(results) == 1
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3
 
     # make sure the checkpoint we saved has the metric in the name
-    ckpts = os.listdir(os.path.join(tmpdir, "lightning_logs", "version_0", "checkpoints"))
+    save_dir = tmpdir / "lightning_logs" / "version_0" / "checkpoints"
+    save_mock.assert_has_calls(
+        [
+            call(trainer, save_dir / "epoch=0-step=4.ckpt"),
+            call(trainer, save_dir / "epoch=1-step=9.ckpt"),
+            call(trainer, save_dir / "epoch=2-step=14.ckpt"),
+        ]
+    )
+    ckpts = os.listdir(save_dir)
     assert len(ckpts) == 1
     assert ckpts[0] == "epoch=2-step=14.ckpt"
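For readers unfamiliar with the `wraps=` spy used above: `patch.object(..., wraps=...)` replaces the method with a mock that still delegates to the real implementation, so checkpoints are written to disk exactly as before while every call and its arguments are recorded. A small, self-contained illustration with plain `unittest.mock` (the `Saver` class is made up for this sketch, not part of the commit):

from unittest.mock import call, patch


class Saver:
    def save(self, path):
        return f"wrote {path}"


saver = Saver()

# the wraps target is captured before the patch is applied, so the mock
# forwards each call to the real bound method while also recording it
with patch.object(Saver, "save", wraps=saver.save) as save_mock:
    saver.save("a.ckpt")
    saver.save("b.ckpt")

save_mock.assert_has_calls([call("a.ckpt"), call("b.ckpt")])
assert save_mock.call_count == 2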
@@ -834,9 +839,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     assert w0.eq(w1).all()
 
 
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @pytest.mark.parametrize("mode", ["min", "max"])
-def test_checkpointing_with_nan_as_first(tmpdir, mode: int):
+def test_checkpointing_with_nan_as_first(tmpdir, mode):
     monitor = [float("nan")]
     monitor += [5, 7, 8] if mode == "max" else [8, 7, 5]
@@ -847,8 +851,11 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode: int):
 
     model = CurrentModel()
 
+    callback = ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir)
+    callback._save_model = MagicMock()
+
     trainer = Trainer(
-        callbacks=[ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir)],
+        callbacks=[callback],
         default_root_dir=tmpdir,
         val_check_interval=1.0,
         max_epochs=len(monitor),
@@ -856,7 +863,8 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode: int):
     trainer.fit(model)
 
     # check that last one is also the best one
-    assert trainer.dev_debugger.checkpoint_callback_history[-1]["epoch"] == len(monitor) - 1
+    assert callback._save_model.call_count == len(monitor)
+    assert mode == "min" and callback.best_model_score == 5 or mode == "max" and callback.best_model_score == 8
 
 
 def test_checkpoint_repeated_strategy(tmpdir):

tests/trainer/flags/test_fast_dev_run.py

@@ -64,6 +64,7 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
             return super().test_step(batch, batch_idx)
 
     checkpoint_callback = ModelCheckpoint()
+    checkpoint_callback.save_checkpoint = Mock()
     early_stopping_callback = EarlyStopping()
     early_stopping_callback._evaluate_stopping_criteria = Mock()
     trainer_config = dict(
@@ -95,8 +96,8 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
         # checkpoint callback should not have been called with fast_dev_run
         assert trainer.checkpoint_callback == checkpoint_callback
+        checkpoint_callback.save_checkpoint.assert_not_called()
         assert not os.path.exists(checkpoint_callback.dirpath)
-        assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
 
         # early stopping should not have been called with fast_dev_run
         assert trainer.early_stopping_callback == early_stopping_callback