Added `UserWarning` if `max_epochs` is not set in the `Trainer` class (#10700)
parent 99bb62ae64
commit 7914e5c157
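The diff below adds a `PossibleUserWarning` when a `Trainer` is constructed with no epoch, step, or time budget, and defaults `max_epochs` to 1000 in that case. A minimal sketch of the resulting behavior, assuming an environment with this commit applied:

import warnings

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.warnings import PossibleUserWarning

# No max_epochs, max_steps, or max_time given: the Trainer warns and caps training at 1000 epochs.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    trainer = Trainer(max_epochs=None)

assert any(issubclass(w.category, PossibleUserWarning) for w in caught)
assert trainer.max_epochs == 1000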
@@ -39,6 +39,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/issues/10860))
 
+
+- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))
+
+
 ### Changed
 
 - Raised exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
@@ -13,8 +13,9 @@
 # limitations under the License.
 from collections import OrderedDict
 from contextlib import contextmanager
+from datetime import timedelta
 from functools import lru_cache
-from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -22,11 +23,13 @@ from torch.optim import Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins import ParallelPlugin
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 
 
 def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
@@ -61,6 +64,40 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int
     return hiddens
 
 
+def _parse_loop_limits(
+    min_steps: Optional[int],
+    max_steps: int,
+    min_epochs: Optional[int],
+    max_epochs: int,
+    max_time: Optional[Union[str, timedelta, Dict[str, int]]],
+) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]:
+    """This utility computes the default values for the minimum and maximum number of steps and epochs given the
+    values the user has selected.
+
+    Args:
+        min_steps: Minimum number of steps.
+        max_steps: Maximum number of steps.
+        min_epochs: Minimum number of epochs.
+        max_epochs: Maximum number of epochs.
+        max_time: Maximum time for the training.
+
+    Returns:
+        The parsed limits, with default values being set for the ones that the user did not specify.
+    """
+    if max_epochs is None:
+        if max_steps == -1 and max_time is None:
+            rank_zero_warn(
+                "`max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit,"
+                " set `max_epochs=-1`.",
+                category=PossibleUserWarning,
+            )
+            max_epochs = 1000
+        else:
+            max_epochs = -1
+    min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
+    return min_steps, max_steps, min_epochs, max_epochs, max_time
+
+
 def _build_training_step_kwargs(
     lightning_module: "pl.LightningModule",
     optimizers: Sequence[Optimizer],
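To make the defaulting rules concrete, here is a hedged sketch of what `_parse_loop_limits` (a private helper, so subject to change) returns for a few input combinations; the expected tuples are read directly off the branches above:

from pytorch_lightning.loops.utilities import _parse_loop_limits

# Nothing set: warns, caps training at 1000 epochs, and defaults min_epochs to 1.
assert _parse_loop_limits(None, -1, None, None, None) == (None, -1, 1, 1000, None)

# A step budget is given: no epoch cap is imposed (max_epochs becomes -1), no warning.
assert _parse_loop_limits(None, 100, None, None, None) == (None, 100, 1, -1, None)

# Explicit user values pass through untouched.
assert _parse_loop_limits(5, -1, 2, 10, None) == (5, -1, 2, 10, None)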
@@ -38,6 +38,7 @@ from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
 from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop
 from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.loops.utilities import _parse_loop_limits
 from pytorch_lightning.plugins import (
     ApexMixedPrecisionPlugin,
     DDPSpawnPlugin,
@@ -455,13 +456,11 @@ class Trainer(
         self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)
 
-        fit_loop = FitLoop(
-            min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
-            max_epochs=(
-                max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)
-            ),
-        )
-        training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
+        min_steps, max_steps, min_epochs, max_epochs, max_time = _parse_loop_limits(
+            min_steps, max_steps, min_epochs, max_epochs, max_time
+        )
+        fit_loop = FitLoop(min_epochs=min_epochs, max_epochs=max_epochs)
+        training_epoch_loop = TrainingEpochLoop(min_steps=min_steps, max_steps=max_steps)
         training_batch_loop = TrainingBatchLoop()
         training_validation_loop = EvaluationLoop()
         training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop)
@@ -120,7 +120,7 @@ def test_warning_if_ipus_not_used(tmpdir):
 @RunIf(ipu=True)
 def test_no_warning_plugin(tmpdir):
     with pytest.warns(None) as record:
-        Trainer(default_root_dir=tmpdir, strategy=IPUPlugin(training_opts=poptorch.Options()))
+        Trainer(default_root_dir=tmpdir, max_epochs=1, strategy=IPUPlugin(training_opts=poptorch.Options()))
     assert len(record) == 0
 
 
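Note: `test_no_warning_plugin` asserts that no warnings at all are recorded, so the `Trainer` construction above gains an explicit `max_epochs=1` to keep the newly added `PossibleUserWarning` from failing the test.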
@@ -1,6 +1,7 @@
 import pytest
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from tests.helpers import BoringModel
 
 
@@ -33,3 +34,10 @@ def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_steps):
     # check training stopped at max_epochs or max_steps
     if trainer.max_steps and not trainer.max_epochs:
         assert trainer.global_step == trainer.max_steps
+
+
+def test_max_epochs_not_set_warning():
+    """Test that a warning is emitted when `max_epochs` was not set by the user."""
+    with pytest.warns(PossibleUserWarning, match="`max_epochs` was not set. Setting it to 1000 epochs."):
+        trainer = Trainer(max_epochs=None)
+    assert trainer.max_epochs == 1000
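For users who want to avoid the new warning, any explicit training budget suffices; a brief illustration (assuming this commit is applied; the `timedelta` form of `max_time` is taken from the annotation in `_parse_loop_limits`):

from datetime import timedelta

from pytorch_lightning import Trainer

Trainer(max_epochs=10)                # explicit epoch budget: no warning
Trainer(max_epochs=-1)                # train without an epoch limit, as the warning message suggests
Trainer(max_steps=100)                # a step budget also suppresses it; max_epochs defaults to -1
Trainer(max_time=timedelta(hours=1))  # a time budget suppresses it as well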