From 7e10f6d41fee5c1ac4dceab8abc0177955239094 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 17 Dec 2021 17:00:27 +0100
Subject: [PATCH] Save the loop progress state by default (#10784)

---
 CHANGELOG.md                                       |  3 +++
 pytorch_lightning/callbacks/timer.py               |  7 +++++++
 pytorch_lightning/core/lightning.py                |  4 ++--
 pytorch_lightning/loops/base.py                    | 13 +++++++++++--
 .../trainer/connectors/checkpoint_connector.py     |  3 +--
 .../trainer/connectors/logger_connector/result.py  |  4 ++--
 tests/callbacks/test_timer.py                      |  9 +++------
 tests/loops/test_loop_state_dict.py                |  3 +++
 tests/loops/test_loops.py                          |  1 +
 tests/models/test_hooks.py                         |  2 ++
 tests/trainer/test_trainer.py                      |  2 +-
 11 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 482525fd35..91714ced49 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
 
 
+- Save the `Loop`'s state by default in the checkpoint ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784))
+
+
 - Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324))
 
 
diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py
index 810439b15b..86c84d61e0 100644
--- a/pytorch_lightning/callbacks/timer.py
+++ b/pytorch_lightning/callbacks/timer.py
@@ -142,6 +142,13 @@ class Timer(Callback):
     def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._end_time[RunningStage.TESTING] = time.monotonic()
 
+    def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
+        # this checks the time after the state is reloaded, regardless of the interval.
+        # this is necessary in case we load a state whose timer is already depleted
+        if self._duration is None:
+            return
+        self._check_time_remaining(trainer)
+
     def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         if self._interval != Interval.step or self._duration is None:
             return
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index fd285f7139..c7b6d1ced3 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -303,7 +303,7 @@ class LightningModule(
         add_dataloader_idx: bool = True,
         batch_size: Optional[int] = None,
         metric_attribute: Optional[str] = None,
-        rank_zero_only: Optional[bool] = None,
+        rank_zero_only: bool = False,
     ) -> None:
         """Log a key, value pair.
 
@@ -441,7 +441,7 @@ class LightningModule(
         sync_dist_group: Optional[Any] = None,
         add_dataloader_idx: bool = True,
         batch_size: Optional[int] = None,
-        rank_zero_only: Optional[bool] = None,
+        rank_zero_only: bool = False,
     ) -> None:
         """Log a dictionary of values at once.
 
diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index 1b3d332a78..d1bc1d5574 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -21,6 +21,7 @@ from torchmetrics import Metric
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import BaseProgress
+from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 T = TypeVar("T")  # the output type of `run`
@@ -273,9 +274,11 @@ class Loop(ABC, Generic[T]):
 
         destination[prefix + "state_dict"] = self.on_save_checkpoint()
 
+        # do not get the mode from `self.trainer` because it might not have been attached yet
+        ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
         for k, v in self.__dict__.items():
             key = prefix + k
-            if isinstance(v, BaseProgress):
+            if ft_enabled and isinstance(v, BaseProgress):
                 destination[key] = v.state_dict()
             elif isinstance(v, Loop):
                 v.state_dict(destination, key + ".")
@@ -302,6 +305,10 @@ class Loop(ABC, Generic[T]):
     def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
         for k, v in self.__dict__.items():
             key = prefix + k
+            if key not in state_dict:
+                # no state for this object, maybe we are loading an old checkpoint
+                continue
+
             if isinstance(v, BaseProgress):
                 v.load_state_dict(state_dict[key])
             elif (
@@ -330,4 +337,6 @@ class Loop(ABC, Generic[T]):
                     v.reset(metrics=False)
 
         self.on_load_checkpoint(state_dict[prefix + "state_dict"])
-        self.restarting = True
+
+        if _FaultTolerantMode.detect_current_mode().is_enabled:
+            self.restarting = True
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index a37cffb320..5ef468462d 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -373,9 +373,8 @@ class CheckpointConnector:
             "global_step": global_step,
             "pytorch-lightning_version": pl.__version__,
             "state_dict": self._get_lightning_module_state_dict(),
+            "loops": self._get_loops_state_dict(),
         }
-        if _fault_tolerant_training():
-            checkpoint["loops"] = self._get_loops_state_dict()
 
         if not weights_only:
             # dump callbacks
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 9780bb3fd8..7dfc4622ce 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -616,8 +616,8 @@ class ResultCollection(dict):
 
     def sync(self) -> None:
         for result_metric in self.result_metrics:
-            if result_metric.is_tensor:
-                result_metric.sync()
+            if result_metric.is_tensor and not result_metric._is_synced:
+                result_metric.sync(should_sync=not result_metric.meta.sync.rank_zero_only)
 
     def unsync(self) -> None:
         for result_metric in self.result_metrics:
diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py
index a1a8af0642..c7c4e0458e 100644
--- a/tests/callbacks/test_timer.py
+++ b/tests/callbacks/test_timer.py
@@ -168,15 +168,12 @@ def test_timer_resume_training(tmpdir):
     assert trainer.current_epoch < 99
     saved_global_step = trainer.global_step
 
-    # resume training (with depleted timer
+    # resume training (with depleted timer)
     timer = Timer(duration=timedelta(milliseconds=200))
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        callbacks=[timer, checkpoint_callback],
-    )
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=timer)
     trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path)
     assert timer._offset > 0
-    assert trainer.global_step == saved_global_step + 1
+    assert trainer.global_step == saved_global_step
 
 
 @RunIf(skip_windows=True)
diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py
index ed4f5169cb..ced392f8c8 100644
--- a/tests/loops/test_loop_state_dict.py
+++ b/tests/loops/test_loop_state_dict.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+from unittest import mock
 from unittest.mock import Mock
 
 import pytest
@@ -37,6 +39,7 @@ def test_loops_state_dict():
     assert fit_loop.state_dict() == new_fit_loop.state_dict()
 
 
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
 def test_loops_state_dict_structure():
     trainer = Trainer()
     trainer.train_dataloader = Mock()
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 08ef7153e6..be9989d062 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -213,6 +213,7 @@ def test_loop_restore():
     assert loop.outputs == list(range(10))
 
 
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
 def test_loop_hierarchy():
     @dataclass
     class SimpleProgress(BaseProgress):
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 11cf24ae70..09afc5b169 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -497,6 +497,7 @@ def _run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization):
         "optimizer_states": ANY,
         "pytorch-lightning_version": __version__,
         "state_dict": ANY,
+        "loops": ANY,
     }
     if kwargs.get("amp_backend") == "native":
         saved_ckpt["native_amp_scaling_state"] = ANY
@@ -624,6 +625,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         "optimizer_states": ANY,
         "pytorch-lightning_version": __version__,
         "state_dict": ANY,
+        "loops": ANY,
     }
     # TODO: wrong saved epoch, should be 0
     saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 2}
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 762a42bebe..cad75fb234 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -660,7 +660,7 @@ def test_benchmark_option(tmpdir):
 @pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
 @pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2))
 @pytest.mark.parametrize("fn", ("validate", "test", "predict"))
-def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn):
+def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
     class TestModel(BoringModel):
         def validation_step(self, batch, batch_idx):
             self.log("foo", -batch_idx)
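
Net effect of the patch, as a small sketch for reviewers: the loops' state is now written to every checkpoint under a "loops" key, while progress counters are only serialized and loops only marked as restarting when fault-tolerant training is enabled via the PL_FAULT_TOLERANT_TRAINING environment variable (which is why the updated tests patch that variable). The checkpoint filename below is illustrative, not taken from the patch.

# sketch only; assumes a checkpoint produced after this change
import torch

ckpt = torch.load("example.ckpt")  # PL checkpoints are plain torch-serialized dicts
assert "loops" in ckpt  # present by default now, even without fault-tolerant training
assert "fit_loop" in ckpt["loops"]  # per-loop state dicts, keyed by loop name

# Restoring progress from this state still requires opting in, e.g.:
#   PL_FAULT_TOLERANT_TRAINING=1 python train.py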