From a44881cd90fcd0f6a2a81920ff402fd3ef48f2d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 2 Feb 2022 20:57:08 +0100 Subject: [PATCH] Changes in preparation to #8578 (#11562) --- CHANGELOG.md | 3 + .../callbacks/model_checkpoint.py | 6 +- pytorch_lightning/core/lightning.py | 5 +- pytorch_lightning/loops/base.py | 1 + .../loops/epoch/training_epoch_loop.py | 6 +- pytorch_lightning/loops/fit_loop.py | 24 ++----- .../connectors/checkpoint_connector.py | 4 +- pytorch_lightning/trainer/trainer.py | 4 +- pytorch_lightning/tuner/batch_size_scaling.py | 5 +- pytorch_lightning/tuner/lr_finder.py | 4 +- tests/checkpointing/test_model_checkpoint.py | 9 +-- .../checkpointing/test_trainer_checkpoint.py | 4 -- tests/core/test_metric_result_integration.py | 5 -- tests/models/test_restore.py | 69 +++++++++++++++++++ tests/trainer/test_trainer.py | 2 +- 15 files changed, 95 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa7c4f9b05..f75dffd925 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -419,6 +419,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `Strategy.on_tpu` property ([#11536](https://github.com/PyTorchLightning/pytorch-lightning/pull/11536)) +- Removed `FitLoop.current_epoch` getter and setter ([#11562](https://github.com/PyTorchLightning/pytorch-lightning/pull/11562)) + + - Removed access to `_short_id` in `NeptuneLogger` ([#11517](https://github.com/PyTorchLightning/pytorch-lightning/pull/11517)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 5187ab3ef6..442bd7a692 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -366,16 +366,14 @@ class ModelCheckpoint(Callback): This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases. """ - epoch = trainer.current_epoch - global_step = trainer.global_step - self._validate_monitor_key(trainer) # track epoch when ckpt was last checked + global_step = trainer.global_step self._last_global_step_saved = global_step # what can be monitored - monitor_candidates = self._monitor_candidates(trainer, epoch=epoch, step=global_step) + monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=global_step) # callback supports multiple simultaneous modes # here we call each mode sequentially diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 05cc8d87ea..2bccb6b732 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -199,10 +199,7 @@ class LightningModule( @property def current_epoch(self) -> int: - """The current epoch in the Trainer. - - If no Trainer is attached, this propery is 0. 
- """ + """The current epoch in the ``Trainer``, or 0 if not attached.""" return self.trainer.current_epoch if self.trainer else 0 @property diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 7876e4d44e..8581fc0ae2 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -206,6 +206,7 @@ class Loop(ABC, Generic[T]): self._restarting = False except StopIteration: break + self._restarting = False output = self.on_run_end() return output diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index b23608e0ef..20b8f6ae47 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -98,11 +98,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): @property def done(self) -> bool: - """Returns whether the training should be stopped. - - The criteria are that the number of steps reached the max steps, the last batch is reached or the trainer - signals to stop (e.g. by early stopping). - """ + """Evaluates when to leave the loop.""" return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop def connect( # type: ignore[override] diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index bed80f5b96..4367e375e0 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -56,16 +56,6 @@ class FitLoop(Loop[None]): self._is_fresh_start_epoch: bool = True self._outputs: _EPOCH_OUTPUTS_TYPE = [] - @property - def current_epoch(self) -> int: - """Return the current epoch.""" - return self.epoch_progress.current.completed - - @current_epoch.setter - def current_epoch(self, value: int) -> None: - """Setter for the current epoch.""" - self.epoch_progress.current.completed = value - @property def global_step(self) -> int: """Returns the global step.""" @@ -149,19 +139,15 @@ class FitLoop(Loop[None]): @property def done(self) -> bool: - """Evaluates when to leave the loop. - - Returns True if trainer.should_stop was set (e.g. by early stopping) or if the maximum number of steps or epochs - is reached. 
- """ + """Evaluates when to leave the loop.""" # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop stop_steps = _is_max_limit_reached(self.global_step, self.max_steps) - stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs) + stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs) should_stop = False if self.trainer.should_stop: # early stopping - met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True + met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True met_min_steps = self.global_step >= self.min_steps if self.min_steps else True if met_min_epochs and met_min_steps: should_stop = True @@ -219,7 +205,7 @@ class FitLoop(Loop[None]): getattr(self.trainer.train_dataloader.sampler, "set_epoch", None) ): # set seed for distributed sampler (enables shuffling for each epoch) - self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch) + self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed) # changing gradient according accumulation_scheduler self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) @@ -307,7 +293,7 @@ class FitLoop(Loop[None]): # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit # To simulate that current behavior, we decrement here. # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007 - self.current_epoch = max(self.current_epoch - 1, 0) + self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0) # hook self.trainer._call_callback_hooks("on_train_end") diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5c437bfd88..7a9fed45d8 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -218,7 +218,9 @@ class CheckpointConnector: return self.trainer.fit_loop.global_step = self._loaded_checkpoint["global_step"] - self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"] + # set the `current_epoch` value for old checkpoints without the progress tracking state. + # it will be overwritten by the loop's state if it was also saved + self.trainer.fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ac01227fd0..6cf6dd51cf 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -336,7 +336,6 @@ class Trainer( To enable infinite training, set ``max_epochs = -1``. min_epochs: Force training for at least these many epochs. Disabled by default (None). - If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``. max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1`` and ``max_epochs = None``, will default to ``max_epochs = 1000``. 
To enable infinite training, set @@ -2349,7 +2348,8 @@ class Trainer( @property def current_epoch(self) -> int: - return self.fit_loop.current_epoch + """The current epoch, updated after the epoch end hooks are run.""" + return self.fit_loop.epoch_progress.current.completed @property def max_epochs(self) -> int: diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index 788395f676..3ce0dc03ec 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -60,10 +60,10 @@ def scale_batch_size( # Save initial model, that is loaded after batch size is found ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt") - trainer.fit_loop.current_epoch -= 1 + trainer.fit_loop.epoch_progress.current.completed -= 1 trainer.fit_loop.global_step -= 1 trainer.save_checkpoint(ckpt_path) - trainer.fit_loop.current_epoch += 1 + trainer.fit_loop.epoch_progress.current.completed += 1 trainer.fit_loop.global_step += 1 params = __scale_batch_dump_params(trainer) @@ -110,7 +110,6 @@ def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]: def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None: trainer.auto_scale_batch_size = None # prevent recursion trainer.auto_lr_find = False # avoid lr find being called multiple times - trainer.fit_loop.current_epoch = 0 trainer.fit_loop.max_steps = steps_per_trial # take few steps trainer.logger = DummyLogger() if trainer.logger is not None else None trainer.callbacks = [] # not needed before full run diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 0be49535e0..48387feb35 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -204,10 +204,10 @@ def lr_find( # Save initial model, that is loaded after learning rate is found ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt") - trainer.fit_loop.current_epoch -= 1 + trainer.fit_loop.epoch_progress.current.completed -= 1 trainer.fit_loop.global_step -= 1 trainer.save_checkpoint(ckpt_path) - trainer.fit_loop.current_epoch += 1 + trainer.fit_loop.epoch_progress.current.completed += 1 trainer.fit_loop.global_step += 1 params = __lr_finder_dump_params(trainer) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0371d02d4b..315c33bb6c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -162,9 +162,9 @@ def test_model_checkpoint_score_and_ckpt( for epoch in range(max_epochs): score = model.scores[epoch] expected_score = getattr(model, f"{monitor}s")[epoch].mean().item() - expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt" assert math.isclose(score, expected_score, rel_tol=1e-4) + expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt" chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename)) assert chk["epoch"] == epoch + 1 assert chk["global_step"] == limit_train_batches * (epoch + 1) @@ -462,7 +462,6 @@ class ModelCheckpointExtensionTest(ModelCheckpoint): def test_model_checkpoint_file_extension(tmpdir): """Test ModelCheckpoint with different file extension.""" - model = LogInTwoMethods() model_checkpoint = ModelCheckpointExtensionTest( monitor="early_stop_on", dirpath=tmpdir, save_top_k=1, save_last=True @@ -613,7 +612,7 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): 
) trainer.fit(model) - # check that the correct ckpts were created + # check that the correct ckpts were created, the modulo condition is checked in `ModelCheckpoint` expected = [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) @@ -967,15 +966,13 @@ def test_checkpoint_repeated_strategy_extended(tmpdir): assert_checkpoint_content(ckpt_dir) # load from checkpoint - trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)] trainer = pl.Trainer(**trainer_config) assert_trainer_init(trainer) model = ExtendedBoringModel() trainer.test(model) - assert trainer.global_step == 0 - assert trainer.current_epoch == 0 + assert_trainer_init(trainer) trainer.fit(model, ckpt_path=chk) assert trainer.global_step == epochs * limit_train_batches diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index e023a4a347..24268e3cfc 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from copy import deepcopy import torch @@ -53,8 +52,6 @@ def test_finetuning_with_ckpt_path(tmpdir): assert os.listdir(tmpdir) == ["epoch=00.ckpt"] best_model_paths = [checkpoint_callback.best_model_path] - results = [] - for idx in range(3, 6): # load from checkpoint trainer = pl.Trainer( @@ -67,7 +64,6 @@ def test_finetuning_with_ckpt_path(tmpdir): ) trainer.fit(model, ckpt_path=best_model_paths[-1]) trainer.test() - results.append(deepcopy(trainer.callback_metrics)) best_model_paths.append(trainer.checkpoint_callback.best_model_path) for idx, best_model_path in enumerate(best_model_paths): diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index cd320d44ee..05cdb4fe92 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -33,7 +33,6 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ( _ResultMetric, _Sync, ) -from pytorch_lightning.utilities.imports import _fault_tolerant_training from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -373,13 +372,9 @@ class DummyMeanMetric(Metric): def result_collection_reload(accelerator="auto", devices=1, **kwargs): - """This test is going to validate _ResultCollection is properly being reload and final accumulation with Fault Tolerant Training is correct.""" - if not _fault_tolerant_training(): - pytest.skip("Fault tolerant not available") - class CustomException(Exception): pass diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index ab1c05acb8..7f1a402329 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -222,6 +222,75 @@ def test_trainer_properties_restore_ckpt_path(tmpdir): trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt) +def test_correct_step_and_epoch(tmpdir): + model = BoringModel() + first_max_epochs = 2 + train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, max_epochs=first_max_epochs, limit_train_batches=train_batches, limit_val_batches=0 + ) + assert trainer.current_epoch == 0 + assert trainer.global_step == 0 + + trainer.fit(model) + # TODO(@carmocca): should not need `-1` + assert trainer.current_epoch == first_max_epochs - 1 + assert trainer.global_step == first_max_epochs * 
train_batches + + # save checkpoint after loop ends, training end called, epoch count increased + ckpt_path = str(tmpdir / "model.ckpt") + trainer.save_checkpoint(ckpt_path) + + ckpt = torch.load(ckpt_path) + assert ckpt["epoch"] == first_max_epochs + # TODO(@carmocca): should not need `+1` + assert ckpt["global_step"] == first_max_epochs * train_batches + 1 + + max_epochs = first_max_epochs + 2 + trainer = Trainer( + default_root_dir=tmpdir, max_epochs=max_epochs, limit_train_batches=train_batches, limit_val_batches=0 + ) + # the ckpt state is not loaded at this point + assert trainer.current_epoch == 0 + assert trainer.global_step == 0 + + class TestModel(BoringModel): + def on_pretrain_routine_end(self) -> None: + assert self.trainer.current_epoch == first_max_epochs + # TODO(@carmocca): should not need `+1` + assert self.trainer.global_step == first_max_epochs * train_batches + 1 + + trainer.fit(TestModel(), ckpt_path=ckpt_path) + # TODO(@carmocca): should not need `-1` + assert trainer.current_epoch == max_epochs - 1 + # TODO(@carmocca): should not need `+1` + assert trainer.global_step == max_epochs * train_batches + 1 + + +def test_fit_twice(tmpdir): + epochs = [] + + class TestModel(BoringModel): + def on_train_epoch_end(self, *_): + epochs.append(self.current_epoch) + + trainer = Trainer( + max_epochs=2, + limit_train_batches=1, + limit_val_batches=1, + default_root_dir=tmpdir, + logger=False, + enable_checkpointing=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + trainer.fit(TestModel()) + trainer.fit_loop.max_epochs = 4 + trainer.fit(TestModel()) + # TODO(@carmocca): 1 should not be duplicated + assert epochs == [0, 1, 1, 2, 3] + + def test_try_resume_from_non_existing_checkpoint(tmpdir): """Test that trying to resume from non-existing `ckpt_path` fails with an error.""" model = BoringModel() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 97793cbe39..3c3dc80e25 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -331,7 +331,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files) # emulate callback's calls during the training for i, loss in enumerate(losses): - trainer.fit_loop.current_epoch = i + trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch` trainer.fit_loop.global_step = i trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)}) checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
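
Note for downstream code: this patch removes the `FitLoop.current_epoch` getter and setter, so anything that previously assigned `trainer.fit_loop.current_epoch` directly must now write to the progress-tracking counter instead, as the tuner functions and tests above do. Read access through `trainer.current_epoch` (and `LightningModule.current_epoch`) is unchanged. A minimal sketch of the access pattern, illustrative only; it assumes a plain `Trainer` run and uses the repo's `BoringModel` test helper, but any `LightningModule` works:

    import pytorch_lightning as pl
    from tests.helpers import BoringModel  # test helper; any LightningModule works here

    trainer = pl.Trainer(max_epochs=2, limit_train_batches=2, limit_val_batches=0)
    trainer.fit(BoringModel())

    # Public read access is unchanged and is backed by the progress-tracking state:
    assert trainer.current_epoch == trainer.fit_loop.epoch_progress.current.completed

    # Before this patch:  trainer.fit_loop.current_epoch = 0
    # After this patch, write to the progress dataclass directly:
    trainer.fit_loop.epoch_progress.current.completed = 0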