Add logging messages to notify when `FitLoop` stopping conditions are met (#9749)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
4f53e7132f
commit
ae9803137a
|
@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `XLAEnvironment` cluster environment plugin ([#11330](https://github.com/PyTorchLightning/pytorch-lightning/pull/11330))
|
||||
|
||||
|
||||
- Added logging messages to notify when `FitLoop` stopping conditions are met ([#9749](https://github.com/PyTorchLightning/pytorch-lightning/pull/9749))
|
||||
|
||||
|
||||
- Added support for calling unknown methods with `DummyLogger` ([#13224](https://github.com/PyTorchLightning/pytorch-lightning/pull/13224))
|
||||
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ from pytorch_lightning.utilities.fetching import (
|
|||
InterBatchParallelDataFetcher,
|
||||
)
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -150,31 +150,40 @@ class FitLoop(Loop[None]):
|
|||
@property
|
||||
def done(self) -> bool:
|
||||
"""Evaluates when to leave the loop."""
|
||||
if self.trainer.num_training_batches == 0:
|
||||
rank_zero_info("`Trainer.fit` stopped: No training batches.")
|
||||
return True
|
||||
|
||||
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
|
||||
stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
|
||||
if stop_steps:
|
||||
rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
|
||||
return True
|
||||
|
||||
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
|
||||
# we use it here because the checkpoint data won't have `completed` increased yet
|
||||
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
|
||||
if stop_epochs:
|
||||
# in case they are not equal, override so `trainer.current_epoch` has the expected value
|
||||
self.epoch_progress.current.completed = self.epoch_progress.current.processed
|
||||
rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
|
||||
return True
|
||||
|
||||
should_stop = False
|
||||
if self.trainer.should_stop:
|
||||
# early stopping
|
||||
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
|
||||
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
|
||||
if met_min_epochs and met_min_steps:
|
||||
should_stop = True
|
||||
self.trainer.should_stop = True
|
||||
rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
|
||||
return True
|
||||
else:
|
||||
log.info(
|
||||
"Trainer was signaled to stop but required minimum epochs"
|
||||
f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
|
||||
" not been met. Training will continue..."
|
||||
rank_zero_info(
|
||||
f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or"
|
||||
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
|
||||
)
|
||||
self.trainer.should_stop = should_stop
|
||||
|
||||
return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
|
||||
self.trainer.should_stop = False
|
||||
return False
|
||||
|
||||
@property
|
||||
def skip(self) -> bool:
|
||||
|
|
|
@ -11,11 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import seed_everything, Trainer
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel
|
||||
from pytorch_lightning.loops import FitLoop
|
||||
|
||||
|
||||
def test_outputs_format(tmpdir):
|
||||
|
@ -136,6 +140,49 @@ def test_should_stop_mid_epoch(tmpdir):
|
|||
assert model.validation_called_at == (0, 5)
|
||||
|
||||
|
||||
def test_fit_loop_done_log_messages(caplog):
|
||||
fit_loop = FitLoop()
|
||||
trainer = Mock(spec=Trainer)
|
||||
fit_loop.trainer = trainer
|
||||
|
||||
trainer.should_stop = False
|
||||
trainer.num_training_batches = 5
|
||||
assert not fit_loop.done
|
||||
assert not caplog.messages
|
||||
|
||||
trainer.num_training_batches = 0
|
||||
assert fit_loop.done
|
||||
assert "No training batches" in caplog.text
|
||||
caplog.clear()
|
||||
trainer.num_training_batches = 5
|
||||
|
||||
epoch_loop = Mock()
|
||||
epoch_loop.global_step = 10
|
||||
fit_loop.connect(epoch_loop=epoch_loop)
|
||||
fit_loop.max_steps = 10
|
||||
assert fit_loop.done
|
||||
assert "max_steps=10` reached" in caplog.text
|
||||
caplog.clear()
|
||||
fit_loop.max_steps = 20
|
||||
|
||||
fit_loop.epoch_progress.current.processed = 3
|
||||
fit_loop.max_epochs = 3
|
||||
trainer.should_stop = True
|
||||
assert fit_loop.done
|
||||
assert "max_epochs=3` reached" in caplog.text
|
||||
caplog.clear()
|
||||
fit_loop.max_epochs = 5
|
||||
|
||||
fit_loop.epoch_loop.min_steps = 0
|
||||
with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
|
||||
assert fit_loop.done
|
||||
assert "should_stop` was set" in caplog.text
|
||||
|
||||
fit_loop.epoch_loop.min_steps = 100
|
||||
assert not fit_loop.done
|
||||
assert "was signaled to stop but" in caplog.text
|
||||
|
||||
|
||||
def test_warning_valid_train_step_end(tmpdir):
|
||||
class ValidTrainStepEndModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
|
|
|
@ -616,7 +616,7 @@ def test_trainer_min_steps_and_min_epochs_not_reached(tmpdir, caplog):
|
|||
with caplog.at_level(logging.INFO, logger="pytorch_lightning.trainer.trainer"):
|
||||
trainer.fit(model)
|
||||
|
||||
message = f"minimum epochs ({min_epochs}) or minimum steps (None) has not been met. Training will continue"
|
||||
message = f"min_epochs={min_epochs}` or `min_steps=None` has not been met. Training will continue"
|
||||
num_messages = sum(1 for record in caplog.records if message in record.message)
|
||||
assert num_messages == min_epochs - 2
|
||||
assert model.training_step_invoked == min_epochs * 2
|
||||
|
|
Loading…
Reference in New Issue