Add logging messages to notify when `FitLoop` stopping conditions are met (#9749)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Jinyoung Lim 2022-07-23 05:07:47 -07:00 committed by GitHub
parent 4f53e7132f
commit ae9803137a
4 changed files with 70 additions and 11 deletions

CHANGELOG.md

@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330))
+- Added logging messages to notify when `FitLoop` stopping conditions are met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749))
 - Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224))

pytorch_lightning/loops/fit_loop.py

@@ -33,7 +33,7 @@ from pytorch_lightning.utilities.fetching import (
     InterBatchParallelDataFetcher,
 )
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 
 log = logging.getLogger(__name__)
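
The switch from the module-level `log` to the `rank_zero_*` helpers means each message is emitted once per run rather than once per process in distributed training. As a rough, self-contained sketch of that pattern (the rank detection and names below are illustrative, not Lightning's actual implementation):

```python
import logging
import os
from functools import wraps

log = logging.getLogger(__name__)


def rank_zero_only(fn):
    """Run the wrapped function only on the process with global rank 0."""

    @wraps(fn)
    def wrapped(*args, **kwargs):
        # Assumes the launcher exports a RANK-style env var; default to 0
        # so single-process runs always log.
        if int(os.environ.get("RANK", 0)) == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped


@rank_zero_only
def rank_zero_info(message: str) -> None:
    log.info(message)


@rank_zero_only
def rank_zero_debug(message: str) -> None:
    log.debug(message)
```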
@@ -150,31 +150,40 @@ class FitLoop(Loop[None]):
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop."""
+        if self.trainer.num_training_batches == 0:
+            rank_zero_info("`Trainer.fit` stopped: No training batches.")
+            return True
+
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
+        if stop_steps:
+            rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
+            return True
+
         # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
         # we use it here because the checkpoint data won't have `completed` increased yet
         stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
         if stop_epochs:
             # in case they are not equal, override so `trainer.current_epoch` has the expected value
             self.epoch_progress.current.completed = self.epoch_progress.current.processed
+            rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
+            return True
 
-        should_stop = False
         if self.trainer.should_stop:
             # early stopping
             met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
             met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
             if met_min_epochs and met_min_steps:
-                should_stop = True
+                self.trainer.should_stop = True
+                rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
+                return True
             else:
-                log.info(
-                    "Trainer was signaled to stop but required minimum epochs"
-                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
-                    " not been met. Training will continue..."
+                rank_zero_info(
+                    f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or"
+                    f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
                 )
-        self.trainer.should_stop = should_stop
+                self.trainer.should_stop = False
 
-        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
+        return False
 
     @property
     def skip(self) -> bool:
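
For context, a minimal usage sketch (hypothetical, not part of this commit) showing where a user would now see one of these messages: with the default INFO-level logging, hitting the `max_epochs` limit prints the stop reason once on rank zero.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

# Train a toy model until the epoch limit is hit.
trainer = Trainer(max_epochs=3)
trainer.fit(BoringModel())
# Expected rank-zero INFO log line when fit returns:
#   `Trainer.fit` stopped: `max_epochs=3` reached.
```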


@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+from unittest.mock import Mock
+
 import pytest
 import torch
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.loops import FitLoop
 
 
 def test_outputs_format(tmpdir):
@@ -136,6 +140,49 @@ def test_should_stop_mid_epoch(tmpdir):
     assert model.validation_called_at == (0, 5)
 
 
+def test_fit_loop_done_log_messages(caplog):
+    fit_loop = FitLoop()
+    trainer = Mock(spec=Trainer)
+    fit_loop.trainer = trainer
+
+    trainer.should_stop = False
+    trainer.num_training_batches = 5
+    assert not fit_loop.done
+    assert not caplog.messages
+
+    trainer.num_training_batches = 0
+    assert fit_loop.done
+    assert "No training batches" in caplog.text
+    caplog.clear()
+    trainer.num_training_batches = 5
+
+    epoch_loop = Mock()
+    epoch_loop.global_step = 10
+    fit_loop.connect(epoch_loop=epoch_loop)
+    fit_loop.max_steps = 10
+    assert fit_loop.done
+    assert "max_steps=10` reached" in caplog.text
+    caplog.clear()
+    fit_loop.max_steps = 20
+
+    fit_loop.epoch_progress.current.processed = 3
+    fit_loop.max_epochs = 3
+    trainer.should_stop = True
+    assert fit_loop.done
+    assert "max_epochs=3` reached" in caplog.text
+    caplog.clear()
+    fit_loop.max_epochs = 5
+
+    fit_loop.epoch_loop.min_steps = 0
+    with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
+        assert fit_loop.done
+    assert "should_stop` was set" in caplog.text
+
+    fit_loop.epoch_loop.min_steps = 100
+    assert not fit_loop.done
+    assert "was signaled to stop but" in caplog.text
+
+
 def test_warning_valid_train_step_end(tmpdir):
     class ValidTrainStepEndModel(BoringModel):
         def training_step(self, batch, batch_idx):
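
A note on the `caplog` handling in the new test above: Lightning sets its package logger to INFO at import time, so the INFO-level stop messages are captured without extra setup, while the DEBUG record from `rank_zero_debug` only shows up once the capture level for that logger is lowered. A self-contained sketch of the pattern (the logger name is real, the test body is illustrative):

```python
import logging


def test_caplog_level_sketch(caplog):
    # Mimic Lightning's package-level default: the logger sits at INFO.
    logger = logging.getLogger("pytorch_lightning.utilities.rank_zero")
    logger.setLevel(logging.INFO)

    logger.debug("hidden")  # below the logger's INFO threshold, never emitted
    assert "hidden" not in caplog.text

    # Temporarily lower the level so DEBUG records reach caplog's handler.
    with caplog.at_level(logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
        logger.debug("shown")
    assert "shown" in caplog.text
```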


@@ -616,7 +616,7 @@ def test_trainer_min_steps_and_min_epochs_not_reached(tmpdir, caplog):
     with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"):
         trainer.fit(model)
 
-    message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue"
+    message = f"min_epochs={min_epochs}` or `min_steps=None` has not been met. Training will continue"
     num_messages = sum(1 for record in caplog.records if message in record.message)
     assert num_messages == min_epochs - 2
     assert model.training_step_invoked == min_epochs * 2