added UserWarnings if max_epochs not set in the Trainer class (#10700)

Author: Rajath Bharadwaj (committed by GitHub)
Date:   2021-12-06 15:14:25 +05:30
Parent: 99bb62ae64
Commit: 7914e5c157
5 changed files with 56 additions and 8 deletions

CHANGELOG.md

@@ -39,6 +39,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860))
+- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))

 ### Changed
 - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))

pytorch_lightning/loops/utilities.py

@@ -13,8 +13,9 @@
 # limitations under the License.
 from collections import OrderedDict
 from contextlib import contextmanager
+from datetime import timedelta
 from functools import lru_cache
-from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -22,11 +23,13 @@ from torch.optim import Optimizer
 import pytorch_lightning as pl
 from pytorch_lightning.plugins import ParallelPlugin
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.warnings import PossibleUserWarning


 def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
@@ -61,6 +64,40 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: in
     return hiddens


+def _parse_loop_limits(
+    min_steps: Optional[int],
+    max_steps: int,
+    min_epochs: Optional[int],
+    max_epochs: int,
+    max_time: Optional[Union[str, timedelta, Dict[str, int]]],
+) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]:
+    """This utility computes the default values for the minimum and maximum number of steps and epochs given the
+    values the user has selected.
+
+    Args:
+        min_steps: Minimum number of steps.
+        max_steps: Maximum number of steps.
+        min_epochs: Minimum number of epochs.
+        max_epochs: Maximum number of epochs.
+        max_time: Maximum time for the training.
+
+    Returns:
+        The parsed limits, with default values being set for the ones that the user did not specify.
+    """
+    if max_epochs is None:
+        if max_steps == -1 and max_time is None:
+            rank_zero_warn(
+                "`max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit,"
+                " set `max_epochs=-1`.",
+                category=PossibleUserWarning,
+            )
+            max_epochs = 1000
+        else:
+            max_epochs = -1
+
+    min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
+
+    return min_steps, max_steps, min_epochs, max_epochs, max_time
+
+
 def _build_training_step_kwargs(
     lightning_module: "pl.LightningModule",
     optimizers: Sequence[Optimizer],

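For reference, here is a minimal standalone sketch of the `max_epochs` resolution that `_parse_loop_limits` performs above. `resolve_max_epochs` is a hypothetical name used only for illustration, and the standard-library `warnings.warn` stands in for Lightning's `rank_zero_warn`/`PossibleUserWarning`, so the snippet runs without Lightning installed:

    import warnings

    def resolve_max_epochs(max_epochs, max_steps=-1, max_time=None):
        # Mirrors the branch added in _parse_loop_limits above.
        if max_epochs is None:
            if max_steps == -1 and max_time is None:
                # No limit of any kind was given: warn and cap at 1000 epochs.
                warnings.warn("`max_epochs` was not set. Setting it to 1000 epochs.")
                return 1000
            # Some other limit (steps or time) already bounds training,
            # so leave epochs unbounded.
            return -1
        return max_epochs

    assert resolve_max_epochs(None) == 1000                        # nothing set: warn + default
    assert resolve_max_epochs(None, max_steps=100) == -1           # steps bound training
    assert resolve_max_epochs(None, max_time="00:01:00:00") == -1  # time bounds training
    assert resolve_max_epochs(5) == 5                              # explicit value is untouched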
pytorch_lightning/trainer/trainer.py

@@ -38,6 +38,7 @@ from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.loops.utilities import _parse_loop_limits
 from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
     DDPSpawnPlugin,
@@ -455,13 +456,11 @@ class Trainer(
         self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)

-        fit_loop = FitLoop(
-            min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
-            max_epochs=(
-                max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)
-            ),
-        )
-        training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
+        min_steps, max_steps, min_epochs, max_epochs, max_time = _parse_loop_limits(
+            min_steps, max_steps, min_epochs, max_epochs, max_time
+        )
+        fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
+        training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
         training_batch_loop = TrainingBatchLoop()
         training_validation_loop = EvaluationLoop()
         training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop)
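
In user-facing terms, the refactor above preserves the old defaults while routing them through `_parse_loop_limits`: a bare `Trainer()` now warns and trains for at most 1000 epochs, while any other limit suppresses both the warning and the epoch cap. A small sketch, assuming a Lightning build that contains this commit:

    from pytorch_lightning import Trainer

    trainer = Trainer()                # emits PossibleUserWarning about `max_epochs`
    assert trainer.max_epochs == 1000  # capped by the new default

    trainer = Trainer(max_steps=500)   # steps bound training: no warning,
    assert trainer.max_epochs == -1    # epochs are left unbounded

    trainer = Trainer(max_epochs=-1)   # explicit opt-in to no epoch limit, no warning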

tests/accelerators/test_ipu.py

@@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir):
 @RunIf(ipu=True)
 def test_no_warning_plugin(tmpdir):
     with pytest.warns(None) as record:
-        Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options()))
+        Trainer(default_root_dir=tmpdir, max_epochs=1, strategy=IPUPlugin(training_opts=poptorch.Options()))
     assert len(record) == 0
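
For context on this tweak: `pytest.warns(None)` (supported by pytest at the time) records every warning raised in the block, and the test asserts that the record is empty, so the new `PossibleUserWarning` from a limit-less `Trainer(...)` would have broken it; hence the explicit `max_epochs=1`. A simulated demonstration of the mechanism, using a stand-in warning rather than Lightning's:

    import warnings
    import pytest

    with pytest.warns(None) as record:               # records all warnings in the block
        warnings.warn("simulated `max_epochs` warning")
    assert len(record) == 1                          # the recorder caught it; the IPU test demands 0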

tests/trainer/flags/test_min_max_epochs.py

@@ -1,6 +1,7 @@
 import pytest

 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from tests.helpers import BoringModel
@@ -33,3 +34,10 @@ def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_ste
     # check training stopped at max_epochs or max_steps
     if trainer.max_steps and not trainer.max_epochs:
         assert trainer.global_step == trainer.max_steps
+
+
+def test_max_epochs_not_set_warning():
+    """Test that a warning is emitted when `max_epochs` was not set by the user."""
+    with pytest.warns(PossibleUserWarning, match="`max_epochs` was not set. Setting it to 1000 epochs."):
+        trainer = Trainer(max_epochs=None)
+        assert trainer.max_epochs == 1000
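
Projects that deliberately construct a `Trainer` without any limit can silence the new category with the standard `warnings` machinery (the import path matches the one used in the test above); a sketch, assuming a Lightning build that contains this commit:

    import warnings

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.warnings import PossibleUserWarning

    # Filter the whole category before constructing the Trainer:
    warnings.filterwarnings("ignore", category=PossibleUserWarning)
    trainer = Trainer()  # silent now, but still defaults to max_epochs=1000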