Save the loop progress state by default (#10784)

This commit is contained in:
Carlos Mocholí 2021-12-17 17:00:27 +01:00 committed by GitHub
parent fa6d17c96f
commit 7e10f6d41f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 36 additions and 15 deletions

View File

@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
- Save the `Loop`'s state by default in the checkpoint ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784))
- Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324))

View File

@@ -142,6 +142,13 @@ class Timer(Callback):
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._end_time[RunningStage.TESTING] = time.monotonic()
def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
# this checks the time after the state is reloaded, regardless of the interval.
# this is necessary in case we load a state whose timer is already depleted
if self._duration is None:
return
self._check_time_remaining(trainer)
def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
if self._interval != Interval.step or self._duration is None:
return

View File

@@ -303,7 +303,7 @@ class LightningModule(
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
rank_zero_only: Optional[bool] = None,
rank_zero_only: bool = False,
) -> None:
"""Log a key, value pair.
@@ -441,7 +441,7 @@ class LightningModule(
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
rank_zero_only: Optional[bool] = None,
rank_zero_only: bool = False,
) -> None:
"""Log a dictionary of values at once.

View File

@@ -21,6 +21,7 @@ from torchmetrics import Metric
import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
T = TypeVar("T") # the output type of `run`
@@ -273,9 +274,11 @@ class Loop(ABC, Generic[T]):
destination[prefix + "state_dict"] = self.on_save_checkpoint()
# do not get the mode from `self.trainer` because it might not have been attached yet
ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
if ft_enabled and isinstance(v, BaseProgress):
destination[key] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, key + ".")
@@ -302,6 +305,10 @@ class Loop(ABC, Generic[T]):
def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
for k, v in self.__dict__.items():
key = prefix + k
if key not in state_dict:
# no state for this object, maybe we are loading an old checkpoint
continue
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[key])
elif (
@@ -330,4 +337,6 @@ class Loop(ABC, Generic[T]):
v.reset(metrics=False)
self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True
if _FaultTolerantMode.detect_current_mode().is_enabled:
self.restarting = True

View File

@@ -373,9 +373,8 @@ class CheckpointConnector:
"global_step": global_step,
"pytorch-lightning_version": pl.__version__,
"state_dict": self._get_lightning_module_state_dict(),
"loops": self._get_loops_state_dict(),
}
if _fault_tolerant_training():
checkpoint["loops"] = self._get_loops_state_dict()
if not weights_only:
# dump callbacks

View File

@@ -616,8 +616,8 @@ class ResultCollection(dict):
def sync(self) -> None:
for result_metric in self.result_metrics:
if result_metric.is_tensor:
result_metric.sync()
if result_metric.is_tensor and not result_metric._is_synced:
result_metric.sync(should_sync=not result_metric.meta.sync.rank_zero_only)
def unsync(self) -> None:
for result_metric in self.result_metrics:

View File

@@ -168,15 +168,12 @@ def test_timer_resume_training(tmpdir):
assert trainer.current_epoch < 99
saved_global_step = trainer.global_step
# resume training (with depleted timer
# resume training (with depleted timer)
timer = Timer(duration=timedelta(milliseconds=200))
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[timer, checkpoint_callback],
)
trainer = Trainer(default_root_dir=tmpdir, callbacks=timer)
trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path)
assert timer._offset > 0
assert trainer.global_step == saved_global_step + 1
assert trainer.global_step == saved_global_step
@RunIf(skip_windows=True)

View File

@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock
import pytest
@@ -37,6 +39,7 @@ def test_loops_state_dict():
assert fit_loop.state_dict() == new_fit_loop.state_dict()
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_loops_state_dict_structure():
trainer = Trainer()
trainer.train_dataloader = Mock()

View File

@@ -213,6 +213,7 @@ def test_loop_restore():
assert loop.outputs == list(range(10))
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_loop_hierarchy():
@dataclass
class SimpleProgress(BaseProgress):

View File

@@ -497,6 +497,7 @@ def _run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization):
"optimizer_states": ANY,
"pytorch-lightning_version": __version__,
"state_dict": ANY,
"loops": ANY,
}
if kwargs.get("amp_backend") == "native":
saved_ckpt["native_amp_scaling_state"] = ANY
@@ -624,6 +625,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
"optimizer_states": ANY,
"pytorch-lightning_version": __version__,
"state_dict": ANY,
"loops": ANY,
}
# TODO: wrong saved epoch, should be 0
saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 2}

View File

@@ -660,7 +660,7 @@ def test_benchmark_option(tmpdir):
@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2))
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn):
def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
class TestModel(BoringModel):
def validation_step(self, batch, batch_idx):
self.log("foo", -batch_idx)