Save the loop progress state by default (#10784)
This commit is contained in:
parent
fa6d17c96f
commit
7e10f6d41f
|
@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
|
||||
|
||||
|
||||
- Save the `Loop`'s state by default in the checkpoint ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784))
|
||||
|
||||
|
||||
- Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324))
|
||||
|
||||
|
||||
|
|
|
@ -142,6 +142,13 @@ class Timer(Callback):
|
|||
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._end_time[RunningStage.TESTING] = time.monotonic()
|
||||
|
||||
def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
|
||||
# this checks the time after the state is reloaded, regardless of the interval.
|
||||
# this is necessary in case we load a state whose timer is already depleted
|
||||
if self._duration is None:
|
||||
return
|
||||
self._check_time_remaining(trainer)
|
||||
|
||||
def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
|
||||
if self._interval != Interval.step or self._duration is None:
|
||||
return
|
||||
|
|
|
@ -303,7 +303,7 @@ class LightningModule(
|
|||
add_dataloader_idx: bool = True,
|
||||
batch_size: Optional[int] = None,
|
||||
metric_attribute: Optional[str] = None,
|
||||
rank_zero_only: Optional[bool] = None,
|
||||
rank_zero_only: bool = False,
|
||||
) -> None:
|
||||
"""Log a key, value pair.
|
||||
|
||||
|
@ -441,7 +441,7 @@ class LightningModule(
|
|||
sync_dist_group: Optional[Any] = None,
|
||||
add_dataloader_idx: bool = True,
|
||||
batch_size: Optional[int] = None,
|
||||
rank_zero_only: Optional[bool] = None,
|
||||
rank_zero_only: bool = False,
|
||||
) -> None:
|
||||
"""Log a dictionary of values at once.
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ from torchmetrics import Metric
|
|||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.progress import BaseProgress
|
||||
from pytorch_lightning.utilities.enums import _FaultTolerantMode
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
T = TypeVar("T") # the output type of `run`
|
||||
|
@ -273,9 +274,11 @@ class Loop(ABC, Generic[T]):
|
|||
|
||||
destination[prefix + "state_dict"] = self.on_save_checkpoint()
|
||||
|
||||
# do not get the mode from `self.trainer` because it might not have been attached yet
|
||||
ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
|
||||
for k, v in self.__dict__.items():
|
||||
key = prefix + k
|
||||
if isinstance(v, BaseProgress):
|
||||
if ft_enabled and isinstance(v, BaseProgress):
|
||||
destination[key] = v.state_dict()
|
||||
elif isinstance(v, Loop):
|
||||
v.state_dict(destination, key + ".")
|
||||
|
@ -302,6 +305,10 @@ class Loop(ABC, Generic[T]):
|
|||
def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
|
||||
for k, v in self.__dict__.items():
|
||||
key = prefix + k
|
||||
if key not in state_dict:
|
||||
# no state for this object, maybe we are loading an old checkpoint
|
||||
continue
|
||||
|
||||
if isinstance(v, BaseProgress):
|
||||
v.load_state_dict(state_dict[key])
|
||||
elif (
|
||||
|
@ -330,4 +337,6 @@ class Loop(ABC, Generic[T]):
|
|||
v.reset(metrics=False)
|
||||
|
||||
self.on_load_checkpoint(state_dict[prefix + "state_dict"])
|
||||
self.restarting = True
|
||||
|
||||
if _FaultTolerantMode.detect_current_mode().is_enabled:
|
||||
self.restarting = True
|
||||
|
|
|
@ -373,9 +373,8 @@ class CheckpointConnector:
|
|||
"global_step": global_step,
|
||||
"pytorch-lightning_version": pl.__version__,
|
||||
"state_dict": self._get_lightning_module_state_dict(),
|
||||
"loops": self._get_loops_state_dict(),
|
||||
}
|
||||
if _fault_tolerant_training():
|
||||
checkpoint["loops"] = self._get_loops_state_dict()
|
||||
|
||||
if not weights_only:
|
||||
# dump callbacks
|
||||
|
|
|
@ -616,8 +616,8 @@ class ResultCollection(dict):
|
|||
|
||||
def sync(self) -> None:
|
||||
for result_metric in self.result_metrics:
|
||||
if result_metric.is_tensor:
|
||||
result_metric.sync()
|
||||
if result_metric.is_tensor and not result_metric._is_synced:
|
||||
result_metric.sync(should_sync=not result_metric.meta.sync.rank_zero_only)
|
||||
|
||||
def unsync(self) -> None:
|
||||
for result_metric in self.result_metrics:
|
||||
|
|
|
@ -168,15 +168,12 @@ def test_timer_resume_training(tmpdir):
|
|||
assert trainer.current_epoch < 99
|
||||
saved_global_step = trainer.global_step
|
||||
|
||||
# resume training (with depleted timer
|
||||
# resume training (with depleted timer)
|
||||
timer = Timer(duration=timedelta(milliseconds=200))
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
callbacks=[timer, checkpoint_callback],
|
||||
)
|
||||
trainer = Trainer(default_root_dir=tmpdir, callbacks=timer)
|
||||
trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path)
|
||||
assert timer._offset > 0
|
||||
assert trainer.global_step == saved_global_step + 1
|
||||
assert trainer.global_step == saved_global_step
|
||||
|
||||
|
||||
@RunIf(skip_windows=True)
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
@ -37,6 +39,7 @@ def test_loops_state_dict():
|
|||
assert fit_loop.state_dict() == new_fit_loop.state_dict()
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
|
||||
def test_loops_state_dict_structure():
|
||||
trainer = Trainer()
|
||||
trainer.train_dataloader = Mock()
|
||||
|
|
|
@ -213,6 +213,7 @@ def test_loop_restore():
|
|||
assert loop.outputs == list(range(10))
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
|
||||
def test_loop_hierarchy():
|
||||
@dataclass
|
||||
class SimpleProgress(BaseProgress):
|
||||
|
|
|
@ -497,6 +497,7 @@ def _run_trainer_model_hook_system_fit(kwargs, tmpdir, automatic_optimization):
|
|||
"optimizer_states": ANY,
|
||||
"pytorch-lightning_version": __version__,
|
||||
"state_dict": ANY,
|
||||
"loops": ANY,
|
||||
}
|
||||
if kwargs.get("amp_backend") == "native":
|
||||
saved_ckpt["native_amp_scaling_state"] = ANY
|
||||
|
@ -624,6 +625,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
|
|||
"optimizer_states": ANY,
|
||||
"pytorch-lightning_version": __version__,
|
||||
"state_dict": ANY,
|
||||
"loops": ANY,
|
||||
}
|
||||
# TODO: wrong saved epoch, should be 0
|
||||
saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 2}
|
||||
|
|
|
@ -660,7 +660,7 @@ def test_benchmark_option(tmpdir):
|
|||
@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
|
||||
@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2))
|
||||
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
|
||||
def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn):
|
||||
def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
|
||||
class TestModel(BoringModel):
|
||||
def validation_step(self, batch, batch_idx):
|
||||
self.log("foo", -batch_idx)
|
||||
|
|
Loading…
Reference in New Issue