diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 62569843d2..cb4ef37b76 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -492,9 +492,6 @@ class ModelCheckpoint(Callback): log.debug(f"Removed checkpoint: {filepath}") def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None: - # in debugging, track when we save checkpoints - trainer.dev_debugger.track_checkpointing_history(filepath) - # make paths if trainer.should_rank_save_checkpoint: self._fs.makedirs(os.path.dirname(filepath), exist_ok=True) diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index e8942b730d..3860ff0ac0 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -43,7 +43,6 @@ class InternalDebugger: def __init__(self, trainer: "pl.Trainer") -> None: self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1" self.trainer = trainer - self.checkpoint_callback_history: List[Dict[str, Any]] = [] self.events: List[Dict[str, Any]] = [] self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = [] self.train_dataloader_calls: List[Dict[str, Any]] = [] @@ -124,15 +123,3 @@ class InternalDebugger: "new_lr": new_lr, } self.saved_lr_scheduler_updates.append(loss_dict) - - @enabled_only - def track_checkpointing_history(self, filepath: str) -> None: - cb = self.trainer.checkpoint_callback - debug_dict = { - "epoch": self.trainer.current_epoch, - "global_step": self.trainer.global_step, - "monitor": cb.monitor if cb is not None else None, - "rank": self.trainer.global_rank, - "filepath": filepath, - } - self.checkpoint_callback_history.append(debug_dict) diff --git a/tests/checkpointing/test_checkpoint_callback_frequency.py b/tests/checkpointing/test_checkpoint_callback_frequency.py index 2e7fcfb8c2..12ec14712f 100644 --- a/tests/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/checkpointing/test_checkpoint_callback_frequency.py @@ -17,34 +17,17 @@ from unittest import mock import pytest import torch -from pytorch_lightning import callbacks, seed_everything, Trainer +from pytorch_lightning import callbacks, Trainer from tests.helpers import BoringModel from tests.helpers.runif import RunIf -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_mc_called(tmpdir): - seed_everything(1234) - - # ----------------- - # TRAIN LOOP ONLY - # ----------------- - train_step_only_model = BoringModel() - train_step_only_model.validation_step = None - +def test_checkpoint_callback_disabled(tmpdir): # no callback trainer = Trainer(max_epochs=3, checkpoint_callback=False) - trainer.fit(train_step_only_model) - assert len(trainer.dev_debugger.checkpoint_callback_history) == 0 - - # ----------------- - # TRAIN + VAL LOOP ONLY - # ----------------- - val_train_model = BoringModel() - # no callback - trainer = Trainer(max_epochs=3, checkpoint_callback=False) - trainer.fit(val_train_model) - assert len(trainer.dev_debugger.checkpoint_callback_history) == 0 + assert not trainer.checkpoint_callbacks + trainer.fit(BoringModel()) + assert not trainer.checkpoint_callbacks @mock.patch("torch.save") diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index c7f4bd0e80..3c70bd6dae 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -23,7 +23,7 @@ from logging import INFO from pathlib import Path from typing import Union from unittest import mock -from unittest.mock import Mock +from unittest.mock import call, MagicMock, Mock, patch import cloudpickle import pytest @@ -752,7 +752,6 @@ def test_ckpt_metric_names(tmpdir): assert len(val) > 3 -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) def test_default_checkpoint_behavior(tmpdir): seed_everything(1234) @@ -761,14 +760,20 @@ def test_default_checkpoint_behavior(tmpdir): default_root_dir=tmpdir, max_epochs=3, progress_bar_refresh_rate=0, limit_train_batches=5, limit_val_batches=5 ) - trainer.fit(model) - results = trainer.test() + with patch.object(ModelCheckpoint, "_save_model", wraps=trainer.checkpoint_callback._save_model) as save_mock: + trainer.fit(model) + results = trainer.test() assert len(results) == 1 - assert len(trainer.dev_debugger.checkpoint_callback_history) == 3 - - # make sure the checkpoint we saved has the metric in the name - ckpts = os.listdir(os.path.join(tmpdir, "lightning_logs", "version_0", "checkpoints")) + save_dir = tmpdir / "lightning_logs" / "version_0" / "checkpoints" + save_mock.assert_has_calls( + [ + call(trainer, save_dir / "epoch=0-step=4.ckpt"), + call(trainer, save_dir / "epoch=1-step=9.ckpt"), + call(trainer, save_dir / "epoch=2-step=14.ckpt"), + ] + ) + ckpts = os.listdir(save_dir) assert len(ckpts) == 1 assert ckpts[0] == "epoch=2-step=14.ckpt" @@ -834,9 +839,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert w0.eq(w1).all() -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize("mode", ["min", "max"]) -def test_checkpointing_with_nan_as_first(tmpdir, mode: int): +def test_checkpointing_with_nan_as_first(tmpdir, mode): monitor = [float("nan")] monitor += [5, 7, 8] if mode == "max" else [8, 7, 5] @@ -847,8 +851,11 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode: int): model = CurrentModel() + callback = ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir) + callback._save_model = MagicMock() + trainer = Trainer( - callbacks=[ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir)], + callbacks=[callback], default_root_dir=tmpdir, val_check_interval=1.0, max_epochs=len(monitor), @@ -856,7 +863,8 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode: int): trainer.fit(model) # check that last one is also the best one - assert trainer.dev_debugger.checkpoint_callback_history[-1]["epoch"] == len(monitor) - 1 + assert callback._save_model.call_count == len(monitor) + assert mode == "min" and callback.best_model_score == 5 or mode == "max" and callback.best_model_score == 8 def test_checkpoint_repeated_strategy(tmpdir): diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index f6c9ea0198..2816fe92bc 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -64,6 +64,7 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run): return super().test_step(batch, batch_idx) checkpoint_callback = ModelCheckpoint() + checkpoint_callback.save_checkpoint = Mock() early_stopping_callback = EarlyStopping() early_stopping_callback._evaluate_stopping_criteria = Mock() trainer_config = dict( @@ -95,8 +96,8 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run): # checkpoint callback should not have been called with fast_dev_run assert trainer.checkpoint_callback == checkpoint_callback + checkpoint_callback.save_checkpoint.assert_not_called() assert not os.path.exists(checkpoint_callback.dirpath) - assert len(trainer.dev_debugger.checkpoint_callback_history) == 0 # early stopping should not have been called with fast_dev_run assert trainer.early_stopping_callback == early_stopping_callback