Fix EarlyStopping logic when min_epochs not met (#6705)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
f581411210
commit
127c52af74
10
CHANGELOG.md
10
CHANGELOG.md
|
@ -209,6 +209,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
|
||||
|
||||
|
||||
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
|
||||
|
||||
|
||||
- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
|
||||
|
||||
|
||||
## [1.2.6] - 2021-03-30
|
||||
|
||||
### Changed
|
||||
|
@ -241,12 +247,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434))
|
||||
- Fixed duplicate logs appearing in console when using the python logging module ([#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
|
||||
- Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
|
||||
|
||||
|
||||
- Fixed a bug with omegaconf and `xm.save` ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741))
|
||||
|
||||
- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
|
||||
|
||||
|
||||
## [1.2.4] - 2021-03-16
|
||||
|
||||
|
|
|
@ -614,6 +614,7 @@ class Trainer(
|
|||
f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
|
||||
' not been met. Training will continue...'
|
||||
)
|
||||
self.should_stop = False
|
||||
|
||||
# hook
|
||||
self.train_loop.on_train_end()
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
|
@ -528,6 +529,40 @@ def test_trainer_min_steps_and_epochs(tmpdir):
|
|||
assert trainer.global_step >= math.floor(num_train_samples * 1.5), "Model did not train for at least min_steps"
|
||||
|
||||
|
||||
def test_trainer_min_steps_and_min_epochs_not_reached(tmpdir, caplog):
    """ Test that min_epochs/min_steps in Trainer are enforced even if EarlyStopping is triggered. """

    class TestModel(BoringModel):
        # Counts how many times training_step ran, so we can verify that
        # training was NOT cut short by early stopping before min_epochs.
        training_step_invoked = 0

        def training_step(self, batch, batch_idx):
            output = super().training_step(batch, batch_idx)
            # Zero out the loss so it never improves: EarlyStopping with
            # patience=0 would normally stop training after the first epoch.
            output["loss"] = output["loss"] * 0.0  # force minimal loss to trigger early stopping
            self.log("loss", output["loss"])
            self.training_step_invoked += 1
            # should_stop must stay False while the min_epochs requirement is unmet
            assert not self.trainer.should_stop
            return output

    boring_model = TestModel()
    stopper = EarlyStopping(monitor="loss", patience=0)
    min_epochs = 5
    trainer = Trainer(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        min_epochs=min_epochs,
        limit_val_batches=0,
        limit_train_batches=2,
        callbacks=[stopper]
    )
    # Capture the INFO-level message the Trainer emits each time it overrides
    # an early-stop request because min_epochs has not been reached.
    with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"):
        trainer.fit(boring_model)

    message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue"
    num_messages = sum(1 for record in caplog.records if message in record.message)
    # Early stopping triggers after epoch 2 (patience=0), so the override
    # message appears once per remaining epoch until min_epochs is reached.
    assert num_messages == min_epochs - 2
    # 2 batches per epoch (limit_train_batches) times min_epochs total epochs.
    assert boring_model.training_step_invoked == min_epochs * 2
|
||||
|
||||
|
||||
def test_trainer_max_steps_accumulate_batches(tmpdir):
|
||||
"""Verify model trains according to specified max steps with grad accumulated batches"""
|
||||
model = BoringModel()
|
||||
|
|
Loading…
Reference in New Issue