Revert "Revert "checkpoint consolidation""

This reverts commit 3a9fde915a.

parent 3a9fde915a
commit 7a369f47e1

@@ -109,6 +109,10 @@ class Callback(abc.ABC):
         """Called when the epoch ends."""
         pass
 
+    def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
+        """Called at the very end of the train epoch."""
+        pass
+
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training batch begins."""
         pass
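
The new hook gives user callbacks a single point that runs after all end-of-epoch validation and checkpointing work. A minimal sketch of a custom callback using it (the class name and print body are illustrative, not part of this commit):

from pytorch_lightning.callbacks import Callback

class EpochAuditCallback(Callback):  # hypothetical example callback
    def on_train_epoch_final_end(self, trainer, pl_module):
        # runs at the very end of the train epoch, after any validation
        # or checkpoint work the epoch triggered
        print(f"epoch {trainer.current_epoch} done at step {trainer.global_step}")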

@@ -143,6 +143,21 @@ class EarlyStopping(Callback):
 
         self._run_early_stopping_check(trainer)
 
+    def on_train_epoch_final_end(self, trainer, pl_module):
+        from pytorch_lightning.trainer.states import TrainerState
+        if (
+            trainer.state != TrainerState.FITTING or trainer.sanity_checking
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should be skipped, run early stopping
+        # at the end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self._run_early_stopping_check(trainer)
+
     def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met
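
With this override, early stopping can still fire on runs that never enter the validation loop. A sketch of such a configuration (assumes the LightningModule actually logs a train_loss metric for the monitor):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# limit_val_batches=0 disables validation entirely, so the hook above
# is what runs the early-stopping check at the end of each train epoch
trainer = Trainer(
    limit_val_batches=0,
    callbacks=[EarlyStopping(monitor="train_loss")],
)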

@@ -53,6 +53,7 @@ class LambdaCallback(Callback):
         on_train_batch_end: Optional[Callable] = None,
         on_train_epoch_start: Optional[Callable] = None,
         on_train_epoch_end: Optional[Callable] = None,
+        on_train_epoch_final_end: Optional[Callable] = None,
         on_validation_epoch_start: Optional[Callable] = None,
         on_validation_epoch_end: Optional[Callable] = None,
         on_test_epoch_start: Optional[Callable] = None,

@@ -155,3 +156,5 @@ class LambdaCallback(Callback):
             self.on_after_backward = on_after_backward
         if on_before_zero_grad is not None:
             self.on_before_zero_grad = on_before_zero_grad
+        if on_train_epoch_final_end is not None:
+            self.on_train_epoch_final_end = on_train_epoch_final_end
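
Usage follows the existing LambdaCallback pattern; for example (the lambda body is illustrative):

from pytorch_lightning.callbacks import LambdaCallback

cb = LambdaCallback(
    on_train_epoch_final_end=lambda trainer, pl_module: print(
        f"final end of epoch {trainer.current_epoch}"
    )
)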

@@ -238,6 +238,37 @@ class ModelCheckpoint(Callback):
             return
         self.save_checkpoint(trainer)
 
+    def on_train_epoch_final_end(self, trainer, pl_module):
+        """
+        At the end of each training epoch, checkpoint only when validation is skipped or disabled.
+        """
+        if (
+            self._should_skip_saving_checkpoint(trainer)
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should be skipped, checkpoint at the end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self.save_checkpoint(trainer)
+
+    def on_train_end(self, trainer, *args, **kwargs) -> None:
+        """
+        Checkpoints can be saved at the end of training.
+        """
+        # temporarily decrease the global step to avoid saving a duplicate
+        # when a checkpoint was already saved at the last step
+        trainer.global_step -= 1
+        if (
+            not self._should_skip_saving_checkpoint(trainer)
+            and trainer.checkpoint_connector.has_trained
+        ):
+            if self.save_last and self.verbose:
+                rank_zero_info("Saving latest checkpoint...")
+            self.save_checkpoint(trainer)
+        trainer.global_step += 1
 
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
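
A training-only setup that exercises the new checkpoint path might look like this (a sketch; the monitored key is assumed to be logged during training):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# with validation disabled, checkpoints are now written from
# on_train_epoch_final_end rather than from on_validation_end
trainer = Trainer(
    limit_val_batches=0,
    callbacks=[ModelCheckpoint(monitor="train_loss", save_last=True, verbose=True)],
)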

@@ -92,6 +92,13 @@ class TrainerCallbackHookMixin(ABC):
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)
 
+    def on_train_epoch_final_end(self) -> None:
+        """
+        Called at the very end of the train epoch.
+        """
+        for callback in self.callbacks:
+            callback.on_train_epoch_final_end(self, self.lightning_module)
+
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:

@@ -100,6 +100,11 @@ class CallbackHookNameValidator:
         """Called when the epoch ends."""
         return {"on_step": [False], "on_epoch": [False, True]}
 
+    @staticmethod
+    def _on_train_epoch_final_end_log():
+        """Called at the very end of the train epoch."""
+        return {"on_step": [False], "on_epoch": [False, True]}
+
     @staticmethod
     def _on_validation_epoch_start_log():
         """Called when the epoch begins."""
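
The returned flags mean that metrics logged from this hook must use on_step=False, while on_epoch may be either value. For instance (a hypothetical hook body, not code from this commit):

def on_train_epoch_final_end(self, trainer, pl_module):
    # the validator would reject on_step=True for this hook
    pl_module.log("final_epoch_metric", 1.0, on_step=False, on_epoch=True)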

@@ -121,12 +121,6 @@ class TrainLoop:
             return
         self._teardown_already_run = True
 
-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_update=True, is_last=True)
-        self.trainer.global_step += 1
-
         # hook
         self.trainer.call_hook("on_train_end")

@@ -145,28 +139,6 @@ class TrainLoop:
         # reset bookkeeping
         self.trainer._running_stage = None
 
-    def check_checkpoint_callback(self, should_update, is_last=False):
-        # TODO bake this logic into the ModelCheckpoint callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = self.trainer.checkpoint_callbacks
-
-            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
-    def check_early_stopping_callback(self, should_update):
-        # TODO bake this logic into the EarlyStopping callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
     def on_train_epoch_start(self, epoch):
 
         # update training progress in trainer

@@ -562,15 +534,14 @@ class TrainLoop:
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
 
-            if should_train_only:
-                self.check_checkpoint_callback(True)
-                self.check_early_stopping_callback(True)
-
         if should_check_val:
             self.trainer.validating = True
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True
 
+        if should_train_only:
+            self.trainer.call_hook('on_train_epoch_final_end')
+
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()
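
Taken together, the epoch-end flow for a training-only run becomes: update learning rates, skip validation, then fire the consolidated hook. A condensed sketch (names mirror the hunk above; this is not the actual TrainLoop code):

def run_training_epoch_end_sketch(trainer, should_train_only, should_check_val):
    if should_train_only:
        trainer.optimizer_connector.update_learning_rates(interval='epoch')
    if should_check_val:
        trainer.validating = True
        trainer.run_evaluation(on_epoch=True)
        trainer.training = True
    if should_train_only:
        # one consolidated hook replaces the removed check_checkpoint_callback
        # and check_early_stopping_callback calls
        trainer.call_hook('on_train_epoch_final_end')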

@@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
+        if period > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
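
For concreteness, with epochs=5 and period=2 (assumed values for illustration), the old expectation was epoch=1 and epoch=3 only; the final-epoch checkpoint now always exists, so the new expectation adds epoch=4:

# worked example with assumed values epochs=5, period=2
epochs, period = 5, 2
final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1)
expected = [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
expected.append(final_epoch_ckpt)
assert expected == ["epoch=1.ckpt", "epoch=3.ckpt", "epoch=4.ckpt"]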

@@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)

@@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
     trainer.fit(model)
 
    # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs - 1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)

@@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning(
         default_root_dir=tmpdir,
         callbacks=[ckpt],
         max_epochs=max_epochs,
         val_check_interval=0.1,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
+    if verbose and save_last and not should_validate:
+        # no validation, so the checkpoint was already saved at the end of
+        # each training epoch and on_train_end logs nothing
+        assert caplog.messages.count('Saving latest checkpoint...') == 0
+    else:
+        assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
 
 
 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_PORT'] = "29501"
 
 
 def init_checkpoint_callback(logger):

@@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir):
         'on_train_batch_start',
         'on_train_end',
         'on_train_epoch_end',
+        'on_train_epoch_final_end',
         'on_train_epoch_start',
         'on_train_start',
         'on_validation_batch_end',