diff --git a/CHANGELOG.md b/CHANGELOG.md index f49707b296..dc11391e98 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860)) + +- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700)) + + ### Changed - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418)) diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 3396f75057..dfbccfadcc 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -13,8 +13,9 @@ # limitations under the License. from collections import OrderedDict from contextlib import contextmanager +from datetime import timedelta from functools import lru_cache -from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple +from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -22,11 +23,13 @@ from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.plugins import ParallelPlugin +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.warnings import PossibleUserWarning def check_finite_loss(loss: Optional[torch.Tensor]) -> None: @@ -61,6 +64,40 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: in return hiddens +def _parse_loop_limits( + min_steps: Optional[int], + max_steps: int, + min_epochs: Optional[int], + max_epochs: int, + max_time: Optional[Union[str, timedelta, Dict[str, int]]], +) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]: + """This utility computes the default values for the minimum and maximum number of steps and epochs given the + values the user has selected. + + Args: + min_steps: Minimum number of steps. + max_steps: Maximum number of steps. + min_epochs: Minimum number of epochs. + max_epochs: Maximum number of epochs. + max_time: Maximum time for the training. + + Returns: + The parsed limits, with default values being set for the ones that the user did not specify. + """ + if max_epochs is None: + if max_steps == -1 and max_time is None: + rank_zero_warn( + "`max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit," + " set `max_epochs=-1`.", + category=PossibleUserWarning, + ) + max_epochs = 1000 + else: + max_epochs = -1 + min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs + return min_steps, max_steps, min_epochs, max_epochs, max_time + + def _build_training_step_kwargs( lightning_module: "pl.LightningModule", optimizers: Sequence[Optimizer], diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4fc7d2f394..40368737e2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -38,6 +38,7 @@ from pytorch_lightning.loggers.tensorboard import TensorBoardLogger from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop +from pytorch_lightning.loops.utilities import _parse_loop_limits from pytorch_lightning.plugins import ( ApexMixedPrecisionPlugin, DDPSpawnPlugin, @@ -455,13 +456,11 @@ class Trainer( self.signal_connector = SignalConnector(self) self.tuner = Tuner(self) - fit_loop = FitLoop( - min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs), - max_epochs=( - max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1) - ), + min_steps, max_steps, min_epochs, max_epochs, max_time = _parse_loop_limits( + min_steps, max_steps, min_epochs, max_epochs, max_time ) - training_epoch_loop = TrainingEpochLoop(min_steps, max_steps) + fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs) + training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps) training_batch_loop = TrainingBatchLoop() training_validation_loop = EvaluationLoop() training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop) diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index bb9ef53e96..87154efbd4 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir): @RunIf(ipu=True) def test_no_warning_plugin(tmpdir): with pytest.warns(None) as record: - Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options())) + Trainer(default_root_dir=tmpdir, max_epochs=1, strategy=IPUPlugin(training_opts=poptorch.Options())) assert len(record) == 0 diff --git a/tests/trainer/flags/test_min_max_epochs.py b/tests/trainer/flags/test_min_max_epochs.py index 989dde6e79..e410e806ea 100644 --- a/tests/trainer/flags/test_min_max_epochs.py +++ b/tests/trainer/flags/test_min_max_epochs.py @@ -1,6 +1,7 @@ import pytest from pytorch_lightning import Trainer +from pytorch_lightning.utilities.warnings import PossibleUserWarning from tests.helpers import BoringModel @@ -33,3 +34,10 @@ def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_ste # check training stopped at max_epochs or max_steps if trainer.max_steps and not trainer.max_epochs: assert trainer.global_step == trainer.max_steps + + +def test_max_epochs_not_set_warning(): + """Test that a warning is emitted when `max_epochs` was not set by the user.""" + with pytest.warns(PossibleUserWarning, match="`max_epochs` was not set. Setting it to 1000 epochs."): + trainer = Trainer(max_epochs=None) + assert trainer.max_epochs == 1000