Fix EarlyStopping logic when min_epochs not met (#6705)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Adrian Wälchli 2021-04-06 13:41:07 +02:00 committed by GitHub
parent f581411210
commit 127c52af74
3 changed files with 42 additions and 4 deletions

CHANGELOG.md

@@ -209,6 +209,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
## [1.2.6] - 2021-03-30
### Changed
@@ -241,12 +247,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434))
- Fixed duplicate logs appearing in console when using the python logging module ([#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
- Fixed a bug with `omegaconf` and `xm.save` ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741))
- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
## [1.2.4] - 2021-03-16

pytorch_lightning/trainer/trainer.py

@@ -614,6 +614,7 @@ class Trainer(
                    f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                    ' not been met. Training will continue...'
                )
                self.should_stop = False

        # hook
        self.train_loop.on_train_end()
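
The added line (`self.should_stop = False`) is the whole fix: when a callback such as `EarlyStopping` requests a stop before the `min_epochs`/`min_steps` requirement is satisfied, the Trainer logs the message and clears the flag so the epoch loop keeps going. A minimal self-contained sketch of that control flow, using an illustrative `MiniTrainer` stand-in rather than the actual Trainer source:

import logging
from dataclasses import dataclass
from typing import Optional

log = logging.getLogger(__name__)


@dataclass
class MiniTrainer:
    # illustrative stand-ins for the real Trainer attributes
    max_epochs: int = 10
    min_epochs: Optional[int] = 5
    min_steps: Optional[int] = None
    global_step: int = 0
    should_stop: bool = False

    def run(self):
        for epoch in range(self.max_epochs):
            self.global_step += 2  # pretend each epoch runs two batches
            if epoch >= 1:
                self.should_stop = True  # pretend EarlyStopping fired
            if self.should_stop:
                met_min_epochs = self.min_epochs is None or epoch + 1 >= self.min_epochs
                met_min_steps = self.min_steps is None or self.global_step >= self.min_steps
                if met_min_epochs and met_min_steps:
                    break  # minimum duration reached: honor the stop request
                log.info(
                    'Trainer was signaled to stop but required minimum epochs'
                    f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                    ' not been met. Training will continue...'
                )
                # the fix: clear the stale flag so it does not end training
                # on the next iteration while the minimums are still unmet
                self.should_stop = False

Running `MiniTrainer(min_epochs=5).run()` completes all five epochs before honoring the stop and logs the continue message three times, mirroring the new test below.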

tests/trainer/test_trainer.py

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
import pickle
@@ -528,6 +529,40 @@ def test_trainer_min_steps_and_epochs(tmpdir):
    assert trainer.global_step >= math.floor(num_train_samples * 1.5), "Model did not train for at least min_steps"


def test_trainer_min_steps_and_min_epochs_not_reached(tmpdir, caplog):
    """ Test that min_epochs/min_steps in Trainer are enforced even if EarlyStopping is triggered. """

    class TestModel(BoringModel):
        training_step_invoked = 0

        def training_step(self, batch, batch_idx):
            output = super().training_step(batch, batch_idx)
            output["loss"] = output["loss"] * 0.0  # force minimal loss to trigger early stopping
            self.log("loss", output["loss"])
            self.training_step_invoked += 1
            assert not self.trainer.should_stop
            return output

    model = TestModel()
    early_stop = EarlyStopping(monitor="loss", patience=0)
    min_epochs = 5
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        min_epochs=min_epochs,
        limit_val_batches=0,
        limit_train_batches=2,
        callbacks=[early_stop]
    )
    with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"):
        trainer.fit(model)

    message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue"
    num_messages = len([record.message for record in caplog.records if message in record.message])
    assert num_messages == min_epochs - 2
    assert model.training_step_invoked == min_epochs * 2
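
The assertion arithmetic follows from the setup: with a constant loss and `patience=0`, `EarlyStopping` first requests a stop at the end of the second epoch, so the "Training will continue" message should be logged once per epoch until `min_epochs` is reached, giving `min_epochs - 2` matching records, while all `min_epochs * 2` training steps (two batches per epoch via `limit_train_batches=2`) still execute.
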
def test_trainer_max_steps_accumulate_batches(tmpdir):
"""Verify model trains according to specified max steps with grad accumulated batches"""
model = BoringModel()