added UserWarnings if max_epochs not set in the Trainer class (#10700)

Rajath Bharadwaj 2021-12-06 15:14:25 +05:30 committed by GitHub
parent 99bb62ae64
commit 7914e5c157
5 changed files with 56 additions and 8 deletions


@@ -39,6 +39,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860))
+- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))

 ### Changed

 - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))


@@ -13,8 +13,9 @@
 # limitations under the License.
 from collections import OrderedDict
 from contextlib import contextmanager
+from datetime import timedelta
 from functools import lru_cache
-from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union

 import numpy as np
 import torch
@@ -22,11 +23,13 @@ from torch.optim import Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.plugins import ParallelPlugin
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.warnings import PossibleUserWarning


 def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
@@ -61,6 +64,40 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: in
     return hiddens


+def _parse_loop_limits(
+    min_steps: Optional[int],
+    max_steps: int,
+    min_epochs: Optional[int],
+    max_epochs: int,
+    max_time: Optional[Union[str, timedelta, Dict[str, int]]],
+) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]:
+    """This utility computes the default values for the minimum and maximum number of steps and epochs given the
+    values the user has selected.
+
+    Args:
+        min_steps: Minimum number of steps.
+        max_steps: Maximum number of steps.
+        min_epochs: Minimum number of epochs.
+        max_epochs: Maximum number of epochs.
+        max_time: Maximum time for the training.
+
+    Returns:
+        The parsed limits, with default values being set for the ones that the user did not specify.
+    """
+    if max_epochs is None:
+        if max_steps == -1 and max_time is None:
+            rank_zero_warn(
+                "`max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit,"
+                " set `max_epochs=-1`.",
+                category=PossibleUserWarning,
+            )
+            max_epochs = 1000
+        else:
+            max_epochs = -1
+    min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
+    return min_steps, max_steps, min_epochs, max_epochs, max_time
+
+
 def _build_training_step_kwargs(
     lightning_module: "pl.LightningModule",
     optimizers: Sequence[Optimizer],
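For reference, a minimal sketch (not part of the diff) of how the new helper resolves the limits. The import path follows the trainer.py change below; the return values are inferred from the function body above, and calling this private helper directly is for illustration only:

    from pytorch_lightning.loops.utilities import _parse_loop_limits

    # Nothing set by the user: warns via PossibleUserWarning and falls back to 1000 epochs.
    _parse_loop_limits(min_steps=None, max_steps=-1, min_epochs=None, max_epochs=None, max_time=None)
    # -> (None, -1, 1, 1000, None)

    # A step budget is given instead: no warning, and max_epochs resolves to -1 (unlimited).
    _parse_loop_limits(min_steps=None, max_steps=500, min_epochs=None, max_epochs=None, max_time=None)
    # -> (None, 500, 1, -1, None)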


@@ -38,6 +38,7 @@ from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.loops.utilities import _parse_loop_limits
 from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
     DDPSpawnPlugin,
@@ -455,13 +456,11 @@ class Trainer(
         self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)

-        fit_loop = FitLoop(
-            min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
-            max_epochs=(
-                max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)
-            ),
-        )
-        training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
+        min_steps, max_steps, min_epochs, max_epochs, max_time = _parse_loop_limits(
+            min_steps, max_steps, min_epochs, max_epochs, max_time
+        )
+        fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
+        training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
         training_batch_loop = TrainingBatchLoop()
         training_validation_loop = EvaluationLoop()
         training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop)
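In practical terms, the new warning only fires when no other stopping criterion is given. A hedged usage sketch of the ways a user can avoid it, using existing Trainer arguments (the specific values are illustrative):

    from pytorch_lightning import Trainer

    Trainer(max_epochs=10)              # explicit epoch limit: no warning
    Trainer(max_epochs=-1)              # explicitly unlimited epochs: no warning
    Trainer(max_steps=1000)             # a step budget also prevents the 1000-epoch fallback
    Trainer(max_time={"minutes": 30})   # so does a time budget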


@@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir):


 @RunIf(ipu=True)
 def test_no_warning_plugin(tmpdir):
     with pytest.warns(None) as record:
-        Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options()))
+        Trainer(default_root_dir=tmpdir, max_epochs=1, strategy=IPUPlugin(training_opts=poptorch.Options()))
     assert len(record) == 0


@@ -1,6 +1,7 @@
 import pytest

 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from tests.helpers import BoringModel
@@ -33,3 +34,10 @@ def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_ste
     # check training stopped at max_epochs or max_steps
     if trainer.max_steps and not trainer.max_epochs:
         assert trainer.global_step == trainer.max_steps
+
+
+def test_max_epochs_not_set_warning():
+    """Test that a warning is emitted when `max_epochs` was not set by the user."""
+    with pytest.warns(PossibleUserWarning, match="`max_epochs` was not set. Setting it to 1000 epochs."):
+        trainer = Trainer(max_epochs=None)
+    assert trainer.max_epochs == 1000
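Not part of this commit, but a natural companion check, sketched here under the same assumptions as the test above: when a step budget is supplied and `max_epochs` is left unset, the helper resolves `max_epochs` to -1 and no warning should fire.

    def test_max_epochs_unlimited_when_max_steps_set():
        # Sketch only: a step limit is present, so there is no 1000-epoch fallback and no warning.
        trainer = Trainer(max_steps=100)
        assert trainer.max_epochs == -1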