added UserWarnings if max_epochs not set in the Trainer class (#10700)
This commit is contained in:
parent 99bb62ae64
commit 7914e5c157
@@ -39,6 +39,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860))
 
+- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))
+
 ### Changed
 
 - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
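For context, a minimal sketch of the user-facing behavior this entry adds, mirroring the new test further down; `Trainer` arguments other than `max_epochs` are left at their defaults:

```python
from pytorch_lightning import Trainer

# No max_epochs/max_steps/max_time given: a PossibleUserWarning is emitted
# and training falls back to the 1000-epoch cap.
trainer = Trainer(max_epochs=None)
assert trainer.max_epochs == 1000

# Opting out of the epoch limit explicitly avoids the warning.
trainer = Trainer(max_epochs=-1)
assert trainer.max_epochs == -1
```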
@@ -13,8 +13,9 @@
 # limitations under the License.
 from collections import OrderedDict
 from contextlib import contextmanager
+from datetime import timedelta
 from functools import lru_cache
-from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -22,11 +23,13 @@ from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins import ParallelPlugin
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 
 
 def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
@@ -61,6 +64,40 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: in
     return hiddens
 
 
+def _parse_loop_limits(
+    min_steps: Optional[int],
+    max_steps: int,
+    min_epochs: Optional[int],
+    max_epochs: int,
+    max_time: Optional[Union[str, timedelta, Dict[str, int]]],
+) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]:
+    """This utility computes the default values for the minimum and maximum number of steps and epochs given the
+    values the user has selected.
+
+    Args:
+        min_steps: Minimum number of steps.
+        max_steps: Maximum number of steps.
+        min_epochs: Minimum number of epochs.
+        max_epochs: Maximum number of epochs.
+        max_time: Maximum time for the training.
+
+    Returns:
+        The parsed limits, with default values being set for the ones that the user did not specify.
+    """
+    if max_epochs is None:
+        if max_steps == -1 and max_time is None:
+            rank_zero_warn(
+                "`max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit,"
+                " set `max_epochs=-1`.",
+                category=PossibleUserWarning,
+            )
+            max_epochs = 1000
+        else:
+            max_epochs = -1
+    min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
+    return min_steps, max_steps, min_epochs, max_epochs, max_time
+
+
 def _build_training_step_kwargs(
     lightning_module: "pl.LightningModule",
     optimizers: Sequence[Optimizer],
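To make the docstring's default resolution concrete, here is a small sketch of the return values implied by the logic above; the argument order is as in the new function and no behavior beyond this diff is assumed:

```python
from pytorch_lightning.loops.utilities import _parse_loop_limits

# Nothing set at all: warn and fall back to 1000 epochs; min_epochs defaults to 1.
assert _parse_loop_limits(None, -1, None, None, None) == (None, -1, 1, 1000, None)

# A step limit is given: epochs stay unbounded (-1) and no warning is emitted.
assert _parse_loop_limits(None, 100, None, None, None) == (None, 100, 1, -1, None)
```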
@@ -38,6 +38,7 @@ from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.loops.utilities import _parse_loop_limits
 from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
     DDPSpawnPlugin,
@@ -455,13 +456,11 @@ class Trainer(
         self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)
 
-        fit_loop = FitLoop(
-            min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
-            max_epochs=(
-                max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)
-            ),
+        min_steps, max_steps, min_epochs, max_epochs, max_time = _parse_loop_limits(
+            min_steps, max_steps, min_epochs, max_epochs, max_time
         )
-        training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
+        fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
+        training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
         training_batch_loop = TrainingBatchLoop()
         training_validation_loop = EvaluationLoop()
         training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop)
@@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir):
 @RunIf(ipu=True)
 def test_no_warning_plugin(tmpdir):
     with pytest.warns(None) as record:
-        Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options()))
+        Trainer(default_root_dir=tmpdir, max_epochs=1, strategy=IPUPlugin(training_opts=poptorch.Options()))
     assert len(record) == 0
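The test above keeps its "no warnings" assertion green by passing `max_epochs=1`. Downstream suites with similar assertions can do the same, or filter the category with the standard library; a sketch, with nothing Lightning-specific beyond the warning class introduced in this commit:

```python
import warnings

from pytorch_lightning.utilities.warnings import PossibleUserWarning

# Silence only this advisory category, e.g. when the 1000-epoch fallback is intended.
warnings.filterwarnings("ignore", category=PossibleUserWarning)
```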
@@ -1,6 +1,7 @@
 import pytest
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from tests.helpers import BoringModel
@@ -33,3 +34,10 @@ def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_ste
     # check training stopped at max_epochs or max_steps
     if trainer.max_steps and not trainer.max_epochs:
         assert trainer.global_step == trainer.max_steps
+
+
+def test_max_epochs_not_set_warning():
+    """Test that a warning is emitted when `max_epochs` was not set by the user."""
+    with pytest.warns(PossibleUserWarning, match="`max_epochs` was not set. Setting it to 1000 epochs."):
+        trainer = Trainer(max_epochs=None)
+        assert trainer.max_epochs == 1000