Remove checkpoint tracking from internal debugger (#9326)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent ca679cd78f
commit 91ce0d0a99
@@ -492,9 +492,6 @@ class ModelCheckpoint(Callback):
         log.debug(f"Removed checkpoint: {filepath}")

     def _save_model(self, trainer: "pl.Trainer", filepath: str) -> None:
-        # in debugging, track when we save checkpoints
-        trainer.dev_debugger.track_checkpointing_history(filepath)
-
         # make paths
         if trainer.should_rank_save_checkpoint:
             self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
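With the dev-debugger hook gone from _save_model, the tests touched below stop reading trainer.dev_debugger.checkpoint_callback_history and instead observe saves with unittest.mock. A minimal sketch of the pattern the updated tests use, assuming a trainer configured with a ModelCheckpoint and a model to fit:

    from unittest.mock import patch

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Wrap the real method: checkpoints are still written, but every call is recorded.
    with patch.object(ModelCheckpoint, "_save_model", wraps=trainer.checkpoint_callback._save_model) as save_mock:
        trainer.fit(model)

    assert save_mock.call_count == trainer.max_epochs  # e.g. one save per epoch in these tests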
@@ -43,7 +43,6 @@ class InternalDebugger:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self.enabled = os.environ.get("PL_DEV_DEBUG", "0") == "1"
         self.trainer = trainer
-        self.checkpoint_callback_history: List[Dict[str, Any]] = []
         self.events: List[Dict[str, Any]] = []
         self.saved_lr_scheduler_updates: List[Dict[str, Union[int, float, str, torch.Tensor, None]]] = []
         self.train_dataloader_calls: List[Dict[str, Any]] = []
@@ -124,15 +123,3 @@ class InternalDebugger:
             "new_lr": new_lr,
         }
         self.saved_lr_scheduler_updates.append(loss_dict)
-
-    @enabled_only
-    def track_checkpointing_history(self, filepath: str) -> None:
-        cb = self.trainer.checkpoint_callback
-        debug_dict = {
-            "epoch": self.trainer.current_epoch,
-            "global_step": self.trainer.global_step,
-            "monitor": cb.monitor if cb is not None else None,
-            "rank": self.trainer.global_rank,
-            "filepath": filepath,
-        }
-        self.checkpoint_callback_history.append(debug_dict)
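For context, track_checkpointing_history was gated by @enabled_only, which only runs InternalDebugger methods when PL_DEV_DEBUG=1; that is why the old tests patched the environment variable, and why those mock.patch.dict decorators can be dropped below. Roughly how the gate works (paraphrased sketch, not part of this diff):

    from functools import wraps

    def enabled_only(fn):
        # Run the wrapped InternalDebugger method only when debugging is enabled,
        # i.e. when self.enabled was set from the PL_DEV_DEBUG environment variable.
        @wraps(fn)
        def wrapped_fn(self, *args, **kwargs):
            if self.enabled:
                fn(self, *args, **kwargs)
        return wrapped_fn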
@@ -17,34 +17,17 @@ from unittest import mock
 import pytest
 import torch

-from pytorch_lightning import callbacks, seed_everything, Trainer
+from pytorch_lightning import callbacks, Trainer
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf


-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def test_mc_called(tmpdir):
-    seed_everything(1234)
-
-    # -----------------
-    # TRAIN LOOP ONLY
-    # -----------------
-    train_step_only_model = BoringModel()
-    train_step_only_model.validation_step = None
-
+def test_checkpoint_callback_disabled(tmpdir):
     # no callback
     trainer = Trainer(max_epochs=3, checkpoint_callback=False)
-    trainer.fit(train_step_only_model)
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
-
-    # -----------------
-    # TRAIN + VAL LOOP ONLY
-    # -----------------
-    val_train_model = BoringModel()
-    # no callback
-    trainer = Trainer(max_epochs=3, checkpoint_callback=False)
-    trainer.fit(val_train_model)
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
+    assert not trainer.checkpoint_callbacks
+    trainer.fit(BoringModel())
+    assert not trainer.checkpoint_callbacks


 @mock.patch("torch.save")
@@ -23,7 +23,7 @@ from logging import INFO
 from pathlib import Path
 from typing import Union
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import call, MagicMock, Mock, patch

 import cloudpickle
 import pytest
@@ -752,7 +752,6 @@ def test_ckpt_metric_names(tmpdir):
     assert len(val) > 3


-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_default_checkpoint_behavior(tmpdir):
     seed_everything(1234)

@@ -761,14 +760,20 @@ def test_default_checkpoint_behavior(tmpdir):
         default_root_dir=tmpdir, max_epochs=3, progress_bar_refresh_rate=0, limit_train_batches=5, limit_val_batches=5
     )

-    trainer.fit(model)
-    results = trainer.test()
+    with patch.object(ModelCheckpoint, "_save_model", wraps=trainer.checkpoint_callback._save_model) as save_mock:
+        trainer.fit(model)
+        results = trainer.test()

     assert len(results) == 1
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3
-
-    # make sure the checkpoint we saved has the metric in the name
-    ckpts = os.listdir(os.path.join(tmpdir, "lightning_logs", "version_0", "checkpoints"))
+    save_dir = tmpdir / "lightning_logs" / "version_0" / "checkpoints"
+    save_mock.assert_has_calls(
+        [
+            call(trainer, save_dir / "epoch=0-step=4.ckpt"),
+            call(trainer, save_dir / "epoch=1-step=9.ckpt"),
+            call(trainer, save_dir / "epoch=2-step=14.ckpt"),
+        ]
+    )
+    ckpts = os.listdir(save_dir)
     assert len(ckpts) == 1
     assert ckpts[0] == "epoch=2-step=14.ckpt"

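The wraps= argument is what keeps this test end-to-end: the patched _save_model records its calls but still delegates to the real method, so the checkpoint files exist on disk for the os.listdir assertions that follow. A small standalone illustration of that behaviour (hypothetical function, not from this diff):

    from unittest.mock import MagicMock

    def save(path):
        return f"wrote {path}"

    m = MagicMock(wraps=save)
    assert m("epoch=0.ckpt") == "wrote epoch=0.ckpt"  # the wrapped callable still runs
    m.assert_called_once_with("epoch=0.ckpt")  # and the call was recorded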
@@ -834,9 +839,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     assert w0.eq(w1).all()


-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @pytest.mark.parametrize("mode", ["min", "max"])
-def test_checkpointing_with_nan_as_first(tmpdir, mode: int):
+def test_checkpointing_with_nan_as_first(tmpdir, mode):
     monitor = [float("nan")]
     monitor += [5, 7, 8] if mode == "max" else [8, 7, 5]

@@ -847,8 +851,11 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode: int):

     model = CurrentModel()

+    callback = ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir)
+    callback._save_model = MagicMock()
+
     trainer = Trainer(
-        callbacks=[ModelCheckpoint(monitor="abc", mode=mode, save_top_k=1, dirpath=tmpdir)],
+        callbacks=[callback],
         default_root_dir=tmpdir,
         val_check_interval=1.0,
         max_epochs=len(monitor),
@@ -856,7 +863,8 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode: int):
     trainer.fit(model)

     # check that last one is also the best one
-    assert trainer.dev_debugger.checkpoint_callback_history[-1]["epoch"] == len(monitor) - 1
+    assert callback._save_model.call_count == len(monitor)
+    assert mode == "min" and callback.best_model_score == 5 or mode == "max" and callback.best_model_score == 8


 def test_checkpoint_repeated_strategy(tmpdir):
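Unlike the wraps= usage above, callback._save_model = MagicMock() replaces the method outright, so this test only checks the call count and the tracked best_model_score rather than files on disk. The combined assertion relies on "and" binding tighter than "or" in Python; written with explicit parentheses it reads:

    assert (mode == "min" and callback.best_model_score == 5) or (mode == "max" and callback.best_model_score == 8)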
@@ -64,6 +64,7 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
             return super().test_step(batch, batch_idx)

     checkpoint_callback = ModelCheckpoint()
+    checkpoint_callback.save_checkpoint = Mock()
     early_stopping_callback = EarlyStopping()
     early_stopping_callback._evaluate_stopping_criteria = Mock()
     trainer_config = dict(
@@ -95,8 +96,8 @@ def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):

     # checkpoint callback should not have been called with fast_dev_run
     assert trainer.checkpoint_callback == checkpoint_callback
+    checkpoint_callback.save_checkpoint.assert_not_called()
     assert not os.path.exists(checkpoint_callback.dirpath)
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0

     # early stopping should not have been called with fast_dev_run
     assert trainer.early_stopping_callback == early_stopping_callback
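Here save_checkpoint is replaced with a bare Mock so the test can assert it was never invoked under fast_dev_run. The assertion pattern in isolation (hypothetical setup, not from this diff):

    from unittest.mock import Mock

    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint()
    checkpoint_callback.save_checkpoint = Mock()
    # ... run a Trainer(fast_dev_run=True, callbacks=[checkpoint_callback]) fit here ...
    checkpoint_callback.save_checkpoint.assert_not_called()  # raises AssertionError if it was ever called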