Changes in preparation for #8578 (#11562)

Carlos Mocholí 2022-02-02 20:57:08 +01:00 committed by GitHub
parent 79a3ff690b
commit a44881cd90
15 changed files with 95 additions and 56 deletions

View File

@@ -419,6 +419,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `Strategy.on_tpu` property ([#11536](https://github.com/PyTorchLightning/pytorch-lightning/pull/11536))
+- Removed `FitLoop.current_epoch` getter and setter ([#11562](https://github.com/PyTorchLightning/pytorch-lightning/pull/11562))
 - Removed access to `_short_id` in `NeptuneLogger` ([#11517](https://github.com/PyTorchLightning/pytorch-lightning/pull/11517))

View File

@@ -366,16 +366,14 @@ class ModelCheckpoint(Callback):
         This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
         behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
         """
-        epoch = trainer.current_epoch
-        global_step = trainer.global_step
         self._validate_monitor_key(trainer)

         # track epoch when ckpt was last checked
+        global_step = trainer.global_step
         self._last_global_step_saved = global_step

         # what can be monitored
-        monitor_candidates = self._monitor_candidates(trainer, epoch=epoch, step=global_step)
+        monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=global_step)

         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
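With this reordering, `ModelCheckpoint` no longer caches the epoch in a local variable: both values are read from the public `Trainer` properties at the moment the callback runs. A minimal sketch of the two inputs it forwards, assuming a PyTorch Lightning build from this era and running no training:

    import pytorch_lightning as pl

    trainer = pl.Trainer(logger=False, enable_checkpointing=False)
    # The keyword names mirror the call in the hunk above; before any training both values are 0.
    monitor_inputs = {"epoch": trainer.current_epoch, "step": trainer.global_step}
    assert monitor_inputs == {"epoch": 0, "step": 0}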

View File

@@ -199,10 +199,7 @@ class LightningModule(
     @property
     def current_epoch(self) -> int:
-        """The current epoch in the Trainer.
-
-        If no Trainer is attached, this propery is 0.
-        """
+        """The current epoch in the ``Trainer``, or 0 if not attached."""
         return self.trainer.current_epoch if self.trainer else 0

     @property

View File

@ -206,6 +206,7 @@ class Loop(ABC, Generic[T]):
self._restarting = False self._restarting = False
except StopIteration: except StopIteration:
break break
self._restarting = False
output = self.on_run_end() output = self.on_run_end()
return output return output
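A stripped-down sketch of the control flow being patched, using a toy class rather than the real `Loop` (the iterator is deliberately empty so the loop body never runs): without the added reset, a run that ends immediately via `StopIteration` would leave the loop still flagged as restarting.

    class ToyLoop:
        def __init__(self):
            self._restarting = True   # pretend we are resuming from a checkpoint
            self._steps = iter(())    # empty: the first next() raises StopIteration

        def run(self):
            while True:
                try:
                    next(self._steps)
                    self._restarting = False  # never reached for an empty iterator
                except StopIteration:
                    break
            self._restarting = False          # the reset this commit adds after the loop
            return "done"

    loop = ToyLoop()
    assert loop.run() == "done"
    assert loop._restarting is False          # cleared even though the body never ran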

View File

@@ -98,11 +98,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     @property
     def done(self) -> bool:
-        """Returns whether the training should be stopped.
-
-        The criteria are that the number of steps reached the max steps, the last batch is reached or the trainer
-        signals to stop (e.g. by early stopping).
-        """
+        """Evaluates when to leave the loop."""
         return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop

     def connect(  # type: ignore[override]

View File

@@ -56,16 +56,6 @@ class FitLoop(Loop[None]):
         self._is_fresh_start_epoch: bool = True
         self._outputs: _EPOCH_OUTPUTS_TYPE = []

-    @property
-    def current_epoch(self) -> int:
-        """Return the current epoch."""
-        return self.epoch_progress.current.completed
-
-    @current_epoch.setter
-    def current_epoch(self, value: int) -> None:
-        """Setter for the current epoch."""
-        self.epoch_progress.current.completed = value
-
     @property
     def global_step(self) -> int:
         """Returns the global step."""
@@ -149,19 +139,15 @@ class FitLoop(Loop[None]):
     @property
     def done(self) -> bool:
-        """Evaluates when to leave the loop.
-
-        Returns True if trainer.should_stop was set (e.g. by early stopping) or if the maximum number of steps or epochs
-        is reached.
-        """
+        """Evaluates when to leave the loop."""
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
-        stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs)
+        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs)

         should_stop = False
         if self.trainer.should_stop:
             # early stopping
-            met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
+            met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True
             met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
             if met_min_epochs and met_min_steps:
                 should_stop = True
@@ -219,7 +205,7 @@ class FitLoop(Loop[None]):
             getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
         ):
             # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
+            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed)

         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
@@ -307,7 +293,7 @@ class FitLoop(Loop[None]):
         # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
         # To simulate that current behavior, we decrement here.
         # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
-        self.current_epoch = max(self.current_epoch - 1, 0)
+        self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0)

         # hook
         self.trainer._call_callback_hooks("on_train_end")
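With the property removed, the completed-epoch count lives only on the progress-tracking object; the public `Trainer.current_epoch` property (updated further down in this commit) keeps forwarding to the same counter. A minimal migration sketch, assuming a PyTorch Lightning build that includes this change and running no training:

    import pytorch_lightning as pl

    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    # The removed FitLoop.current_epoch was a thin alias for this counter.
    completed = trainer.fit_loop.epoch_progress.current.completed
    assert completed == trainer.current_epoch == 0  # nothing has run yet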

View File

@@ -218,7 +218,9 @@ class CheckpointConnector:
             return

         self.trainer.fit_loop.global_step = self._loaded_checkpoint["global_step"]
-        self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"]
+        # set the `current_epoch` value for old checkpoints without the progress tracking state.
+        # it will be overwritten by the loop's state if it was also saved
+        self.trainer.fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]

         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
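A hedged sketch of the two checkpoint layouts this restore path has to handle, built as plain dictionaries in memory; the key names come from the hunk above, the values are invented for illustration:

    # Legacy layout: only flat counters, so the flat "epoch" key seeds the progress counter.
    legacy_ckpt = {"epoch": 3, "global_step": 700}
    # Newer layout: also carries the loops' own state, which takes precedence when restoring.
    newer_ckpt = {"epoch": 3, "global_step": 700, "loops": {"fit_loop": {"...": "..."}}}

    for ckpt in (legacy_ckpt, newer_ckpt):
        seeded_epoch = ckpt["epoch"]                    # always applied first
        has_loop_state = ckpt.get("loops") is not None  # would overwrite the seed if True
        print(seeded_epoch, has_loop_state)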

View File

@@ -336,7 +336,6 @@ class Trainer(
                 To enable infinite training, set ``max_epochs = -1``.

             min_epochs: Force training for at least these many epochs. Disabled by default (None).
-                If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.

             max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
                 and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
@@ -2349,7 +2348,8 @@ class Trainer(
     @property
     def current_epoch(self) -> int:
-        return self.fit_loop.current_epoch
+        """The current epoch, updated after the epoch end hooks are run."""
+        return self.fit_loop.epoch_progress.current.completed

     @property
     def max_epochs(self) -> int:
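The new docstring's note that the value is updated only after the epoch-end hooks run can be observed from a hook: during `on_train_epoch_end` the property still reports the index of the epoch that is finishing. A sketch mirroring the `test_fit_twice` test added below, assuming the repository's `BoringModel` test helper is importable:

    import pytorch_lightning as pl
    from tests.helpers import BoringModel  # assumption: run from the repository's test tree

    seen = []

    class EpochRecorder(BoringModel):
        def on_train_epoch_end(self, *_):
            seen.append(self.trainer.current_epoch)

    trainer = pl.Trainer(max_epochs=2, limit_train_batches=1, limit_val_batches=1, logger=False,
                         enable_checkpointing=False, enable_model_summary=False, enable_progress_bar=False)
    trainer.fit(EpochRecorder())
    assert seen == [0, 1]  # the counter is bumped only after these hooks return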

View File

@@ -60,10 +60,10 @@ def scale_batch_size(
     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.current_epoch -= 1
+    trainer.fit_loop.epoch_progress.current.completed -= 1
     trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.current_epoch += 1
+    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1

     params = __scale_batch_dump_params(trainer)
@@ -110,7 +110,6 @@ def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
 def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None:
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
-    trainer.fit_loop.current_epoch = 0
     trainer.fit_loop.max_steps = steps_per_trial  # take few steps
     trainer.logger = DummyLogger() if trainer.logger is not None else None
     trainer.callbacks = []  # not needed before full run

View File

@@ -204,10 +204,10 @@ def lr_find(
     # Save initial model, that is loaded after learning rate is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.current_epoch -= 1
+    trainer.fit_loop.epoch_progress.current.completed -= 1
     trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.current_epoch += 1
+    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1

     params = __lr_finder_dump_params(trainer)
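Both tuner hunks above wrap the temporary checkpoint in the same rewind, snapshot, restore bracket. A generic sketch of that pattern in plain Python; the dictionary and key names are invented for illustration and no Lightning objects are involved:

    from contextlib import contextmanager

    @contextmanager
    def rewound(counters, key, by=1):
        """Temporarily decrement a counter so a snapshot looks pre-step, then restore it."""
        counters[key] -= by
        try:
            yield counters
        finally:
            counters[key] += by

    progress = {"completed_epochs": 5, "global_step": 500}
    with rewound(progress, "completed_epochs"), rewound(progress, "global_step"):
        snapshot = dict(progress)  # stands in for trainer.save_checkpoint(ckpt_path)

    assert snapshot == {"completed_epochs": 4, "global_step": 499}
    assert progress == {"completed_epochs": 5, "global_step": 500}  # counters restored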

View File

@@ -162,9 +162,9 @@ def test_model_checkpoint_score_and_ckpt(
     for epoch in range(max_epochs):
         score = model.scores[epoch]
         expected_score = getattr(model, f"{monitor}s")[epoch].mean().item()
-        expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
         assert math.isclose(score, expected_score, rel_tol=1e-4)

+        expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
         chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
         assert chk["epoch"] == epoch + 1
         assert chk["global_step"] == limit_train_batches * (epoch + 1)
@@ -462,7 +462,6 @@ class ModelCheckpointExtensionTest(ModelCheckpoint):
 def test_model_checkpoint_file_extension(tmpdir):
     """Test ModelCheckpoint with different file extension."""
-
     model = LogInTwoMethods()
     model_checkpoint = ModelCheckpointExtensionTest(
         monitor="early_stop_on", dirpath=tmpdir, save_top_k=1, save_last=True
@@ -613,7 +612,7 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
     )
     trainer.fit(model)
-    # check that the correct ckpts were created
+    # check that the correct ckpts were created, the modulo condition is checked in `ModelCheckpoint`
     expected = [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)
@@ -967,15 +966,13 @@ def test_checkpoint_repeated_strategy_extended(tmpdir):
     assert_checkpoint_content(ckpt_dir)

     # load from checkpoint
-    trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
     trainer = pl.Trainer(**trainer_config)
     assert_trainer_init(trainer)

     model = ExtendedBoringModel()
     trainer.test(model)
-    assert trainer.global_step == 0
-    assert trainer.current_epoch == 0
+    assert_trainer_init(trainer)

     trainer.fit(model, ckpt_path=chk)
     assert trainer.global_step == epochs * limit_train_batches

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from copy import deepcopy

 import torch
@@ -53,8 +52,6 @@ def test_finetuning_with_ckpt_path(tmpdir):
     assert os.listdir(tmpdir) == ["epoch=00.ckpt"]

     best_model_paths = [checkpoint_callback.best_model_path]
-    results = []
-
     for idx in range(3, 6):
         # load from checkpoint
         trainer = pl.Trainer(
@@ -67,7 +64,6 @@ def test_finetuning_with_ckpt_path(tmpdir):
         )
         trainer.fit(model, ckpt_path=best_model_paths[-1])
         trainer.test()
-        results.append(deepcopy(trainer.callback_metrics))
         best_model_paths.append(trainer.checkpoint_callback.best_model_path)

     for idx, best_model_path in enumerate(best_model_paths):
for idx, best_model_path in enumerate(best_model_paths): for idx, best_model_path in enumerate(best_model_paths):

View File

@@ -33,7 +33,6 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import (
     _ResultMetric,
     _Sync,
 )
-from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -373,13 +372,9 @@ class DummyMeanMetric(Metric):
 def result_collection_reload(accelerator="auto", devices=1, **kwargs):
     """This test is going to validate _ResultCollection is properly being reload and final accumulation with Fault
     Tolerant Training is correct."""

-    if not _fault_tolerant_training():
-        pytest.skip("Fault tolerant not available")
-
     class CustomException(Exception):
         pass

View File

@@ -222,6 +222,75 @@ def test_trainer_properties_restore_ckpt_path(tmpdir):
         trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)


+def test_correct_step_and_epoch(tmpdir):
+    model = BoringModel()
+    first_max_epochs = 2
+    train_batches = 2
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=first_max_epochs, limit_train_batches=train_batches, limit_val_batches=0
+    )
+    assert trainer.current_epoch == 0
+    assert trainer.global_step == 0
+
+    trainer.fit(model)
+    # TODO(@carmocca): should not need `-1`
+    assert trainer.current_epoch == first_max_epochs - 1
+    assert trainer.global_step == first_max_epochs * train_batches
+
+    # save checkpoint after loop ends, training end called, epoch count increased
+    ckpt_path = str(tmpdir / "model.ckpt")
+    trainer.save_checkpoint(ckpt_path)
+
+    ckpt = torch.load(ckpt_path)
+    assert ckpt["epoch"] == first_max_epochs
+    # TODO(@carmocca): should not need `+1`
+    assert ckpt["global_step"] == first_max_epochs * train_batches + 1
+
+    max_epochs = first_max_epochs + 2
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=max_epochs, limit_train_batches=train_batches, limit_val_batches=0
+    )
+    # the ckpt state is not loaded at this point
+    assert trainer.current_epoch == 0
+    assert trainer.global_step == 0
+
+    class TestModel(BoringModel):
+        def on_pretrain_routine_end(self) -> None:
+            assert self.trainer.current_epoch == first_max_epochs
+            # TODO(@carmocca): should not need `+1`
+            assert self.trainer.global_step == first_max_epochs * train_batches + 1
+
+    trainer.fit(TestModel(), ckpt_path=ckpt_path)
+    # TODO(@carmocca): should not need `-1`
+    assert trainer.current_epoch == max_epochs - 1
+    # TODO(@carmocca): should not need `+1`
+    assert trainer.global_step == max_epochs * train_batches + 1
+
+
+def test_fit_twice(tmpdir):
+    epochs = []
+
+    class TestModel(BoringModel):
+        def on_train_epoch_end(self, *_):
+            epochs.append(self.current_epoch)
+
+    trainer = Trainer(
+        max_epochs=2,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        default_root_dir=tmpdir,
+        logger=False,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(TestModel())
+    trainer.fit_loop.max_epochs = 4
+    trainer.fit(TestModel())
+    # TODO(@carmocca): 1 should not be duplicated
+    assert epochs == [0, 1, 1, 2, 3]
+
+
 def test_try_resume_from_non_existing_checkpoint(tmpdir):
     """Test that trying to resume from non-existing `ckpt_path` fails with an error."""
     model = BoringModel()
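The TODOs in `test_correct_step_and_epoch` pin down the off-by-one bookkeeping that #8578 aims to remove; the arithmetic they encode, written out in plain Python with the same numbers the test uses:

    first_max_epochs, train_batches = 2, 2

    ckpt_epoch = first_max_epochs                     # 2: what ckpt["epoch"] holds after fit
    ckpt_step = first_max_epochs * train_batches + 1  # 5: the `+1` the TODOs want gone
    trainer_epoch = first_max_epochs - 1              # 1: the `-1` the TODOs want gone

    assert (ckpt_epoch, ckpt_step, trainer_epoch) == (2, 5, 1)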

View File

@@ -331,7 +331,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
     # emulate callback's calls during the training
     for i, loss in enumerate(losses):
-        trainer.fit_loop.current_epoch = i
+        trainer.fit_loop.epoch_progress.current.completed = i  # sets `trainer.current_epoch`
         trainer.fit_loop.global_step = i
         trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
         checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)