[2/2] Remove training loop force calling early stopping callback (#7069)
* rebase * doc * Update training_loop.py * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md
This commit is contained in:
parent
a5ac3f8a16
commit
14b8dd479a
|
@ -141,6 +141,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Changed
|
||||
|
||||
|
||||
- Changed `EarlyStopping` callback from by default running `EarlyStopping.on_validation_end` if only training is run. Set `check_on_train_epoch_end` to run the callback at the end of the train epoch instead of at the end of the validation epoch ([#7069](https://github.com/PyTorchLightning/pytorch-lightning/pull/7069))
|
||||
|
||||
|
||||
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
|
||||
|
||||
|
||||
|
@ -224,6 +228,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
### Removed
|
||||
|
||||
|
||||
- Removed training loop explicitly calling `EarlyStopping.on_validation_end` if no validation is run ([#7069](https://github.com/PyTorchLightning/pytorch-lightning/pull/7069))
|
||||
|
||||
|
||||
- Removed `automatic_optimization` as a property from the training loop in favor of `LightningModule.automatic_optimization` ([#7130](https://github.com/PyTorchLightning/pytorch-lightning/pull/7130))
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@ from typing import Dict, List, Optional, Union
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||
from pytorch_lightning.core.step_result import Result
|
||||
from pytorch_lightning.plugins import ParallelPlugin
|
||||
|
@ -148,15 +147,6 @@ class TrainLoop:
|
|||
for cb in callbacks:
|
||||
cb.on_validation_end(self.trainer, model)
|
||||
|
||||
def check_early_stopping_callback(self, should_update):
|
||||
# TODO bake this logic into the EarlyStopping callback
|
||||
if should_update and self.trainer.checkpoint_connector.has_trained:
|
||||
callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
|
||||
model = self.trainer.lightning_module
|
||||
|
||||
for cb in callbacks:
|
||||
cb.on_validation_end(self.trainer, model)
|
||||
|
||||
def on_train_epoch_start(self, epoch):
|
||||
|
||||
# update training progress in trainer
|
||||
|
@ -556,7 +546,6 @@ class TrainLoop:
|
|||
|
||||
if should_train_only:
|
||||
self.check_checkpoint_callback(True)
|
||||
self.check_early_stopping_callback(True)
|
||||
|
||||
if should_check_val:
|
||||
self.trainer.validating = True
|
||||
|
|
|
@ -169,7 +169,9 @@ def test_early_stopping_patience_train(
|
|||
if validation_step_none:
|
||||
model.validation_step = None
|
||||
|
||||
early_stop_callback = EarlyStopping(monitor="train_loss", patience=patience, verbose=True)
|
||||
early_stop_callback = EarlyStopping(
|
||||
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=validation_step_none
|
||||
)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
callbacks=[early_stop_callback],
|
||||
|
@ -200,7 +202,7 @@ def test_early_stopping_no_val_step(tmpdir):
|
|||
model.validation_step = None
|
||||
model.val_dataloader = None
|
||||
|
||||
stopping = EarlyStopping(monitor='train_loss', min_delta=0.1, patience=0)
|
||||
stopping = EarlyStopping(monitor='train_loss', min_delta=0.1, patience=0, check_on_train_epoch_end=True)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
callbacks=[stopping],
|
||||
|
|
|
@ -548,7 +548,7 @@ def test_trainer_min_steps_and_min_epochs_not_reached(tmpdir, caplog):
|
|||
return output
|
||||
|
||||
model = TestModel()
|
||||
early_stop = EarlyStopping(monitor="loss", patience=0)
|
||||
early_stop = EarlyStopping(monitor="loss", patience=0, check_on_train_epoch_end=True)
|
||||
min_epochs = 5
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
|
|
Loading…
Reference in New Issue