Fix mypy typing errors in pytorch_lightning/trainer/trainer.py (#14204)

Co-authored-by: otaj <ota@lightning.ai>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Lee Jungwon 2022-09-30 19:50:42 +09:00 committed by GitHub
parent 021c2f1447
commit a9142d637a
17 changed files with 125 additions and 97 deletions

View File

@ -57,7 +57,6 @@ warn_no_return = "False"
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.trainer.trainer",
"lightning_app.api.http_methods",
"lightning_app.api.request_types",
"lightning_app.cli.app-template.app",

View File

@ -82,7 +82,7 @@ class _StrategyRegistry(dict):
return do_register
def get(self, name: str, default: Optional[Any] = None) -> Any:
def get(self, name: str, default: Optional[Strategy] = None) -> Strategy: # type: ignore[override]
"""Calls the registered strategy with the required parameters and returns the strategy object.
Args:

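The new signature narrows both the `default` parameter and the return type relative to the inherited `dict.get`, which is why the line carries `# type: ignore[override]`: mypy checks that an override stays compatible with the signature it replaces. A minimal sketch of that rule with a plain base class (not Lightning code, and simplified relative to the overloaded `dict.get` from typeshed):

    from typing import Optional


    class Base:
        def get(self, name: str, default: Optional[object] = None) -> Optional[object]:
            return default


    class Child(Base):
        # Narrowing the `default` parameter (object -> str) breaks the Liskov
        # substitution rule that mypy enforces for overrides, so the line needs
        # "# type: ignore[override]"; the _StrategyRegistry.get override above
        # carries the same ignore after its signature was tightened.
        def get(self, name: str, default: Optional[str] = None) -> str:  # type: ignore[override]
            return default if default is not None else name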
View File

@ -222,7 +222,9 @@ class ProgressBarBase(Callback):
if not trainer.is_global_zero:
self.disable()
def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
def get_metrics(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
r"""
Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
Implement this to override the items displayed in the progress bar.

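The return annotation is widened because progress-bar metric values may be ints, formatted strings, floats, or nested dicts of floats. The override the docstring suggests typically looks like the sketch below (assuming the stock TQDMProgressBar callback; dropping "v_num" is the usual example):

    from typing import Dict, Union

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import TQDMProgressBar


    class NoVersionNumberBar(TQDMProgressBar):
        def get_metrics(
            self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
        ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
            # Start from the combined trainer/standard metrics, then drop the
            # "v_num" entry so it no longer shows up in the bar.
            items = super().get_metrics(trainer, pl_module)
            items.pop("v_num", None)
            return items


    # trainer = pl.Trainer(callbacks=[NoVersionNumberBar()])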
View File

@ -160,6 +160,7 @@ class StochasticWeightAveraging(Callback):
if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
assert trainer.max_epochs is not None
if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
@ -189,6 +190,7 @@ class StochasticWeightAveraging(Callback):
for lr, group in zip(self._swa_lrs, optimizer.param_groups):
group["initial_lr"] = lr
assert trainer.max_epochs is not None
self._swa_scheduler = cast(
_LRScheduler,
SWALR(

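Several hunks in this commit insert bare `assert ... is not None` statements (here for `trainer.max_epochs`, later for `trainer.model`, `self.state.fn`, and others). They exist only so mypy can narrow an Optional attribute to its non-None type before it is used; at runtime they are essentially free. A minimal illustration, independent of Lightning:

    from typing import Optional


    def swa_start_epoch(max_epochs: Optional[int], swa_epoch_start: float) -> int:
        # Without the assert, mypy rejects the multiplication because
        # "Optional[int]" might be None; the assert narrows the type to "int".
        assert max_epochs is not None
        return int(max_epochs * swa_epoch_start)


    print(swa_start_epoch(10, 0.8))  # 8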
View File

@ -56,7 +56,7 @@ class EvaluationLoop(DataLoaderLoop):
self._results = _ResultCollection(training=False)
self._outputs: List[EPOCH_OUTPUT] = []
self._logged_outputs: List[_OUT_DICT] = []
self._max_batches: List[int] = []
self._max_batches: List[Union[int, float]] = []
self._has_run: bool = False
self._data_fetcher: Optional[AbstractDataFetcher] = None
@ -213,7 +213,7 @@ class EvaluationLoop(DataLoaderLoop):
self._results.cpu()
self.epoch_loop.teardown()
def _get_max_batches(self) -> List[int]:
def _get_max_batches(self) -> List[Union[int, float]]:
"""Returns the max number of batches for each dataloader."""
if self.trainer.testing:
max_batches = self.trainer.num_test_batches

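`_max_batches` and `_get_max_batches` switch to `List[Union[int, float]]` because the per-dataloader limits are not always integers: the `limit_*_batches` Trainer arguments accept a fraction, and a dataloader without a usable length resolves to `float("inf")`. A small illustration with the public arguments (only the constructor is shown; no model is attached):

    from pytorch_lightning import Trainer

    # An int means "at most this many batches"; a float means "this fraction of
    # the dataloader". With an IterableDataset that has no __len__, the resolved
    # limit becomes float("inf"), hence List[Union[int, float]] in the loops.
    trainer_abs = Trainer(limit_val_batches=100)
    trainer_frac = Trainer(limit_val_batches=0.25)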
View File

@ -1,4 +1,4 @@
from typing import Any, List, Optional, Sequence
from typing import Any, List, Optional, Sequence, Union
from torch.utils.data import DataLoader
@ -52,7 +52,7 @@ class PredictionLoop(DataLoaderLoop):
return length
@property
def max_batches(self) -> List[int]:
def max_batches(self) -> List[Union[int, float]]:
"""The max number of batches this loop will run for each dataloader."""
return self.trainer.num_predict_batches

View File

@ -292,6 +292,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
# TODO: fault-tolerance requires a minimum number of batches so probably should be > 0
and self.batch_progress.current.ready # did start
):
assert isinstance(trainer.train_dataloader, CombinedLoader)
loader: CombinedLoader = trainer.train_dataloader
state = loader.state_dict(has_completed=self._has_completed())
if state:

View File

@ -23,7 +23,7 @@ from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _
from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
@ -48,7 +48,7 @@ class FitLoop(Loop[None]):
def __init__(
self,
min_epochs: int = 0,
min_epochs: Optional[int] = 0,
max_epochs: Optional[int] = None,
) -> None:
super().__init__()
@ -233,6 +233,7 @@ class FitLoop(Loop[None]):
self._outputs = []
if self.trainer.train_dataloader is not None:
assert isinstance(self.trainer.train_dataloader, CombinedLoader)
_set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed)
# changing gradient according accumulation_scheduler

View File

@ -14,7 +14,7 @@
from collections import OrderedDict
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Generator, List, Optional, Sequence, Tuple
from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@ -29,6 +29,7 @@ from pytorch_lightning.loops import Loop
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@ -69,7 +70,11 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: in
def _parse_loop_limits(
min_steps: Optional[int], max_steps: int, min_epochs: Optional[int], max_epochs: int, trainer: "pl.Trainer"
min_steps: Optional[int],
max_steps: int,
min_epochs: Optional[int],
max_epochs: Optional[int],
trainer: "pl.Trainer",
) -> Tuple[int, int]:
"""This utility computes the default values for the minimum and maximum number of steps and epochs given the
values the user has selected.
@ -216,13 +221,14 @@ def _reset_progress(loop: Loop) -> None:
_reset_progress(v)
def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
def _set_sampler_epoch(dataloader: Union[DataLoader, CombinedLoader], epoch: int) -> None:
"""Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
:class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
"""
for sampler_name in ("sampler", "batch_sampler"):
sampler = getattr(dataloader, sampler_name, None)
if sampler is not None and callable(getattr(sampler, "set_epoch", None)):

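`_set_sampler_epoch` now also accepts a `CombinedLoader`; because it uses `getattr(..., None)` and checks for a callable `set_epoch`, the call simply becomes a no-op on objects without a (batch) sampler. The underlying contract the docstring describes is the usual `DistributedSampler.set_epoch` pattern, sketched below in plain PyTorch (the explicit num_replicas/rank arguments are only there to make the snippet runnable on a single process):

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(8).float())
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    loader = DataLoader(dataset, sampler=sampler, batch_size=2)

    for epoch in range(3):
        # Without set_epoch the sampler reuses the same shuffled order every epoch.
        sampler.set_epoch(epoch)
        for (batch,) in loader:
            pass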
View File

@ -101,6 +101,7 @@ class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
# `require_backward_grad_sync` will be reset in the
# ddp_strategy `post_training_step` hook
if not pl_module.automatic_optimization:
assert trainer.model is not None
trainer.model.require_backward_grad_sync = False # type: ignore[assignment]
return output
if trainer.testing:

View File

@ -109,7 +109,7 @@ class AcceleratorConnector:
sync_batchnorm: bool = False,
benchmark: Optional[bool] = None,
replace_sampler_ddp: bool = True,
deterministic: Union[bool, _LITERAL_WARN] = False,
deterministic: Optional[Union[bool, _LITERAL_WARN]] = False,
auto_select_gpus: bool = False,
num_processes: Optional[int] = None, # deprecated
tpu_cores: Optional[Union[List[int], str, int]] = None, # deprecated
@ -663,7 +663,8 @@ class AcceleratorConnector:
if isinstance(self._strategy_flag, str):
self.strategy = StrategyRegistry.get(self._strategy_flag)
elif isinstance(self._strategy_flag, Strategy):
self.strategy = self._strategy_flag
# TODO(lite): remove ignore after merging lite and PL strategies
self.strategy = self._strategy_flag # type: ignore[assignment]
else:
raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")
@ -687,9 +688,7 @@ class AcceleratorConnector:
)
return TPUBf16PrecisionPlugin()
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecisionPlugin(
self._precision_flag, self._amp_type_flag, self._amp_level_flag # type: ignore
)
return DeepSpeedPrecisionPlugin(self._precision_flag, self._amp_type_flag, self._amp_level_flag)
if self._precision_flag == 32:
return PrecisionPlugin()

View File

@ -57,23 +57,17 @@ class DataConnector:
def _should_reload_train_dl(self) -> bool:
"""Check if train dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return n_epochs and (
self.trainer._last_train_dl_reload_epoch is None
or self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs
)
return n_epochs and self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs
@property
def _should_reload_val_dl(self) -> bool:
"""Check if validation dataloader should be reloaded."""
n_epochs = self.trainer.reload_dataloaders_every_n_epochs
return n_epochs and (
self.trainer._last_val_dl_reload_epoch is None
or self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs
)
return n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs
def on_trainer_init(
self,
val_check_interval: Union[int, float],
val_check_interval: Optional[Union[int, float]],
reload_dataloaders_every_n_epochs: int,
check_val_every_n_epoch: Optional[int],
) -> None:
@ -347,7 +341,7 @@ class DataConnector:
def _reset_eval_dataloader(
self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
) -> Tuple[List[Union[int, float]], List[DataLoader]]:
) -> Tuple[List[Union[float, int]], List[DataLoader]]:
"""Generic method to reset a dataloader for evaluation.
Args:
@ -387,7 +381,7 @@ class DataConnector:
dataloaders, dtype=DataLoader, function=_auto_add_worker_init_fn, rank=self.trainer.global_rank
)
loader_num_batches = []
loader_num_batches: List[Union[int, float]] = []
# determine number of batches
module = model or self.trainer.lightning_module or self.datamodule
@ -398,6 +392,7 @@ class DataConnector:
)
if orig_num_batches == 0:
assert isinstance(orig_num_batches, int)
loader_num_batches.append(orig_num_batches)
continue

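The simplified `_should_reload_train_dl`/`_should_reload_val_dl` bodies rely on the Trainer change later in this commit, where `_last_train_dl_reload_epoch` and `_last_val_dl_reload_epoch` are initialised to `float("-inf")` instead of `None`: with a numeric sentinel the "never reloaded yet" case needs no separate None branch, and the attribute keeps a single type for mypy. A minimal sketch of the sentinel idea (illustrative function name):

    def should_reload(current_epoch: int, last_reload_epoch: float, every_n_epochs: int) -> bool:
        # last_reload_epoch starts at float("-inf"), so the subtraction is always
        # >= every_n_epochs on the first check and no None handling is required.
        return every_n_epochs > 0 and current_epoch - last_reload_epoch >= every_n_epochs


    print(should_reload(0, float("-inf"), 2))  # True  (first ever check)
    print(should_reload(1, 0, 2))              # False (reloaded at epoch 0)
    print(should_reload(2, 0, 2))              # True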
View File

@ -28,7 +28,6 @@ from pytorch_lightning.utilities.auto_restart import (
MergedIteratorState,
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
@ -457,6 +456,8 @@ class CombinedLoader:
Returns:
the wrapped loaders
"""
from pytorch_lightning.utilities.data import get_len
all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
length = _nested_calc_num_data(all_lengths, max)
@ -473,6 +474,8 @@ class CombinedLoader:
def _apply_cycle_iterator_length(self) -> None:
"""When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to
all dataloaders."""
from pytorch_lightning.utilities.data import get_len
if self.mode != "max_size_cycle":
return
@ -509,6 +512,8 @@ class CombinedLoader:
Returns:
length: the minimum length of loaders
"""
from pytorch_lightning.utilities.data import get_len
all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
if isinstance(all_lengths, (int, float)):

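Moving the `get_len` imports from module level into the functions that need them is presumably done to break an import cycle: `pytorch_lightning.utilities.data` now imports `CombinedLoader` from this module (see the `data.py` hunk further down), so `supporters.py` can no longer import from `utilities.data` at import time. The general deferred-import pattern, with hypothetical module names:

    # module_a.py (hypothetical names, for illustration only)
    def length_of(loader):
        # Imported lazily: module_b imports module_a at import time, so a
        # top-level "from module_b import get_len" here would form a cycle.
        from module_b import get_len
        return get_len(loader)

    # module_b.py
    from module_a import length_of  # safe: module_a defers its import of module_b

    def get_len(loader):
        try:
            return len(loader)
        except TypeError:
            return float("inf")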
View File

@ -25,7 +25,7 @@ import logging
import math
import os
import warnings
from argparse import ArgumentParser, Namespace
from argparse import _ArgumentGroup, ArgumentParser, Namespace
from contextlib import contextmanager
from copy import deepcopy
from datetime import timedelta
@ -73,7 +73,7 @@ from pytorch_lightning.trainer.connectors.callback_connector import CallbackConn
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _PBAR_DICT, _ResultCollection
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.trainer.supporters import CombinedLoader
@ -457,7 +457,7 @@ class Trainer:
self._call_callback_hooks("on_init_start")
# init data flags
self.check_val_every_n_epoch: int
self.check_val_every_n_epoch: Optional[int]
self._data_connector.on_trainer_init(
val_check_interval,
reload_dataloaders_every_n_epochs,
@ -484,7 +484,7 @@ class Trainer:
f"`track_grad_norm` must be a positive number or 'inf' (infinity norm). Got {track_grad_norm}."
)
self.gradient_clip_val: Union[int, float] = gradient_clip_val
self.gradient_clip_val: Optional[Union[int, float]] = gradient_clip_val
self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
)
@ -504,8 +504,13 @@ class Trainer:
self._logger_connector.on_trainer_init(logger, log_every_n_steps, move_metrics_to_cpu)
# init debugging flags
self.val_check_batch: Union[int, float]
self.val_check_interval: Union[int, float]
self.num_sanity_val_steps: Union[float, int]
self.num_sanity_val_steps: Union[int, float]
self.limit_train_batches: Union[int, float]
self.limit_val_batches: Union[int, float]
self.limit_test_batches: Union[int, float]
self.limit_predict_batches: Union[int, float]
setup._init_debugging_flags(
self,
limit_train_batches,
@ -527,17 +532,19 @@ class Trainer:
self.should_stop = False
self.state = TrainerState()
self.num_training_batches = float("inf")
self.train_dataloader = None
self.num_sanity_val_batches = []
self.num_test_batches = []
self.num_val_batches = []
self.num_predict_batches = []
self.test_dataloaders = None
self.val_dataloaders = None
self.predict_dataloaders = None
self._last_train_dl_reload_epoch = None
self._last_val_dl_reload_epoch: Optional[int] = None
self.train_dataloader: Optional[Union[CombinedLoader, TRAIN_DATALOADERS]] = None
self.num_sanity_val_batches: List[Union[int, float]] = []
self.num_test_batches: List[Union[int, float]] = []
self.num_val_batches: List[Union[int, float]] = []
self.num_predict_batches: List[Union[int, float]] = []
self.test_dataloaders: Optional[List[DataLoader]] = None
self.val_dataloaders: Optional[List[DataLoader]] = None
self.predict_dataloaders: Optional[List[DataLoader]] = None
self._last_train_dl_reload_epoch = float("-inf")
self._last_val_dl_reload_epoch = float("-inf")
def fit(
self,
@ -605,13 +612,16 @@ class Trainer:
# TODO: ckpt_path only in v2.0
ckpt_path = ckpt_path or self.resume_from_checkpoint
self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
self.state.fn, ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
self.state.fn,
ckpt_path, # type: ignore[arg-type]
model_provided=True,
model_connected=self.lightning_module is not None,
)
results = self._run(model, ckpt_path=self.ckpt_path)
self._run(model, ckpt_path=self.ckpt_path)
assert self.state.stopped
self.training = False
return results
return
def validate(
self,
@ -659,7 +669,7 @@ class Trainer:
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
# --------------------
# SETUP HOOK
# --------------------
@ -751,7 +761,7 @@ class Trainer:
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
# --------------------
# SETUP HOOK
# --------------------
@ -853,7 +863,7 @@ class Trainer:
self.state.status = TrainerStatus.RUNNING
self.predicting = True
self.predict_loop.return_predictions = return_predictions
self.predict_loop.return_predictions = return_predictions # type: ignore[assignment]
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(dataloaders, LightningDataModule):
@ -1103,7 +1113,7 @@ class Trainer:
logger.log_graph(self.lightning_module)
logger.save()
def _teardown(self):
def _teardown(self) -> None:
"""This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
Callback; those are handled by :meth:`_call_teardown_hook`."""
self.strategy.teardown()
@ -1114,7 +1124,7 @@ class Trainer:
self._logger_connector.teardown()
self._signal_connector.teardown()
def _run_stage(self):
def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
self.strategy.barrier("run-stage")
self.strategy.dispatch(self)
@ -1122,9 +1132,9 @@ class Trainer:
return self._run_evaluate()
if self.predicting:
return self._run_predict()
return self._run_train()
self._run_train()
def _pre_training_routine(self):
def _pre_training_routine(self) -> None:
# wait for all to join if on distributed
self.strategy.barrier("setup_training")
@ -1147,6 +1157,7 @@ class Trainer:
self._run_sanity_check()
# enable train mode
assert self.model is not None
self.model.train()
torch.set_grad_enabled(True)
@ -1229,6 +1240,7 @@ class Trainer:
self.state.stage = stage
def _call_setup_hook(self) -> None:
assert self.state.fn is not None
fn = self.state.fn._setup_fn
self.strategy.barrier("pre_setup")
@ -1252,6 +1264,7 @@ class Trainer:
self._call_callback_hooks("on_configure_sharded_model")
def _call_teardown_hook(self) -> None:
assert self.state.fn is not None
fn = self.state.fn._setup_fn
if self.datamodule is not None:
@ -1392,7 +1405,7 @@ class Trainer:
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = "on_load_checkpoint"
callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks")
if callback_states is None:
return
@ -1420,7 +1433,7 @@ class Trainer:
def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
"""Called when loading a model checkpoint, calls every callback's `load_state_dict`."""
callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks")
if callback_states is None:
return
@ -1458,6 +1471,7 @@ class Trainer:
torch._C._log_api_usage_once("lightning.trainer." + event)
def __setup_profiler(self) -> None:
assert self.state.fn is not None
local_rank = self.local_rank if self.world_size > 1 else None
self.profiler._lightning_module = proxy(self.lightning_module)
self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
@ -1516,7 +1530,7 @@ class Trainer:
module = model or self.lightning_module or self.datamodule
orig_train_batches = self.num_training_batches = (
len(self.train_dataloader)
len(self.train_dataloader) # type: ignore[arg-type]
if has_len_all_ranks(self.train_dataloader, self.strategy, module)
else float("inf")
)
@ -1661,11 +1675,13 @@ class Trainer:
@property
def accelerator(self) -> Accelerator:
assert self.strategy.accelerator
return self.strategy.accelerator
@property
def strategy(self) -> Strategy:
return self._accelerator_connector.strategy
# TODO(lite): remove ignore after merging lite and PL strategies
return self._accelerator_connector.strategy # type: ignore[return-value]
@property
def precision_plugin(self) -> PrecisionPlugin:
@ -1702,6 +1718,7 @@ class Trainer:
if isinstance(self.strategy, ParallelStrategy)
else [self.strategy.root_device]
)
assert devices is not None
device_ids = []
for idx, device in enumerate(devices):
if isinstance(device, torch.device):
@ -1718,14 +1735,14 @@ class Trainer:
@property
def lightning_module(self) -> "pl.LightningModule":
# TODO: this is actually an optional return
return self.strategy.lightning_module
return self.strategy.lightning_module # type: ignore[return-value]
@property
def optimizers(self) -> List[Optimizer]:
return self.strategy.optimizers
@optimizers.setter
def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
def optimizers(self, new_optims: List[Optimizer]) -> None:
self.strategy.optimizers = new_optims
@property
@ -1757,7 +1774,7 @@ class Trainer:
return getattr(self.precision_plugin, "scaler", None)
@property
def model(self) -> torch.nn.Module:
def model(self) -> Optional[torch.nn.Module]:
"""The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
To access the pure LightningModule, use
@ -1783,10 +1800,10 @@ class Trainer:
@property
def log_dir(self) -> Optional[str]:
if len(self.loggers) > 0:
if isinstance(self.loggers[0], TensorBoardLogger):
dirpath = self.loggers[0].log_dir
else:
if not isinstance(self.loggers[0], TensorBoardLogger):
dirpath = self.loggers[0].save_dir
else:
dirpath = self.loggers[0].log_dir
else:
dirpath = self.default_root_dir
@ -1915,7 +1932,7 @@ class Trainer:
return {k: v.default for k, v in init_signature.parameters.items()}
@classmethod
def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs) -> Any:
def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs: Any) -> Any:
return from_argparse_args(cls, args, **kwargs)
@classmethod
@ -1927,7 +1944,7 @@ class Trainer:
return parse_env_variables(cls)
@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> Union[_ArgumentGroup, ArgumentParser]:
return add_argparse_args(cls, parent_parser, **kwargs)
"""
@ -1995,7 +2012,7 @@ class Trainer:
@property
def evaluating(self) -> bool:
return self.state.stage and self.state.stage.evaluating
return self.state.stage is not None and self.state.stage.evaluating
@property
def sanity_checking(self) -> bool:
@ -2026,11 +2043,11 @@ class Trainer:
return self.fit_loop.epoch_progress.current.completed
@property
def max_epochs(self) -> int:
def max_epochs(self) -> Optional[int]:
return self.fit_loop.max_epochs
@property
def min_epochs(self) -> int:
def min_epochs(self) -> Optional[int]:
return self.fit_loop.min_epochs
@property
@ -2051,7 +2068,7 @@ class Trainer:
return self._fit_loop
@fit_loop.setter
def fit_loop(self, loop: FitLoop):
def fit_loop(self, loop: FitLoop) -> None:
"""Attach a custom fit loop to this Trainer.
It will run with
@ -2065,7 +2082,7 @@ class Trainer:
return self._validate_loop
@validate_loop.setter
def validate_loop(self, loop: EvaluationLoop):
def validate_loop(self, loop: EvaluationLoop) -> None:
"""Attach a custom validation loop to this Trainer.
It will run with
@ -2080,7 +2097,7 @@ class Trainer:
return self._test_loop
@test_loop.setter
def test_loop(self, loop: EvaluationLoop):
def test_loop(self, loop: EvaluationLoop) -> None:
"""Attach a custom test loop to this Trainer.
It will run with
@ -2094,7 +2111,7 @@ class Trainer:
return self._predict_loop
@predict_loop.setter
def predict_loop(self, loop: PredictionLoop):
def predict_loop(self, loop: PredictionLoop) -> None:
"""Attach a custom prediction loop to this Trainer.
It will run with
@ -2146,17 +2163,17 @@ class Trainer:
self._loggers = loggers if loggers else []
@property
def callback_metrics(self) -> Dict[str, Tensor]:
def callback_metrics(self) -> Dict:
# TODO: the true typing return can include dictionaries as defined in
# `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
return self._logger_connector.callback_metrics
@property
def logged_metrics(self) -> dict:
def logged_metrics(self) -> _OUT_DICT:
return self._logger_connector.logged_metrics
@property
def progress_bar_metrics(self) -> dict:
def progress_bar_metrics(self) -> _PBAR_DICT:
return self._logger_connector.progress_bar_metrics
@property
@ -2172,7 +2189,7 @@ class Trainer:
def _should_terminate_gracefully(self) -> bool:
value = torch.tensor(int(self._terminate_gracefully), device=self.strategy.root_device)
return self.strategy.reduce(value, reduce_op="sum") > 0
return bool(self.strategy.reduce(value, reduce_op="sum") > 0)
"""
Other
@ -2216,6 +2233,7 @@ class Trainer:
if total_batches == float("inf"):
return self.max_steps
assert self.max_epochs is not None
self.accumulate_grad_batches = accumulation_scheduler.get_accumulate_grad_batches(self.current_epoch)
effective_batch_size = self.accumulate_grad_batches
max_estimated_steps = math.ceil(total_batches / effective_batch_size) * max(self.max_epochs, 1)

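A recurring pattern in the Trainer hunks above is the annotation-only statement, e.g. `self.limit_train_batches: Union[int, float]`: it declares the attribute's type for mypy without assigning a value, because the value is filled in later by a connector or by `setup._init_debugging_flags`. A small sketch of the pattern with a hypothetical class:

    from typing import List, Optional, Union


    class Runner:
        def __init__(self, limit_batches: Union[int, float]) -> None:
            # Declaration only: mypy records the type here, the value is
            # assigned later by a helper (as setup._init_debugging_flags does
            # for the real Trainer).
            self.limit_train_batches: Union[int, float]
            self.num_val_batches: List[Union[int, float]] = []
            self.check_val_every_n_epoch: Optional[int] = None
            self._init_flags(limit_batches)

        def _init_flags(self, limit_batches: Union[int, float]) -> None:
            self.limit_train_batches = limit_batches


    r = Runner(0.1)
    print(r.limit_train_batches)  # 0.1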
View File

@ -17,8 +17,6 @@ import uuid
from copy import deepcopy
from typing import Any, Dict, Optional, Tuple
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_setattr
@ -277,13 +275,15 @@ def _adjust_batch_size(
rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
if trainer.state.fn == "fit":
from pytorch_lightning.trainer.supporters import CombinedLoader
if trainer.train_dataloader is None:
trainer.reset_train_dataloader()
assert trainer.train_dataloader is not None
# TODO: should we check val_dataloaders here too?
assert isinstance(trainer.train_dataloader, CombinedLoader)
if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
new_size = min(new_size, len(trainer.train_dataloader.dataset))
# at this moment, `train_dataloader` is already a CombinedLoader. len can return a size or infinity
new_size = min(new_size, len(trainer.train_dataloader.dataset)) # type: ignore[arg-type]
else:
stage = trainer.state.stage
assert stage is not None
@ -302,11 +302,14 @@ def _adjust_batch_size(
return new_size, changed
def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer") -> bool:
def _is_valid_batch_size(
batch_size: int, dataloader: "pl.trainer.supporters.CombinedLoader", trainer: "pl.Trainer"
) -> bool:
from pytorch_lightning.utilities.data import has_len_all_ranks
module = trainer.lightning_module or trainer.datamodule
return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
has_len = has_len_all_ranks(dataloader, trainer.strategy, module)
return not has_len or batch_size <= len(dataloader) # type: ignore[arg-type]
def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

View File

@ -35,6 +35,7 @@ from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterab
from lightning_lite.utilities.data import has_len as new_has_len
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -96,14 +97,14 @@ def extract_batch_size(batch: BType) -> int:
def has_len_all_ranks(
dataloader: DataLoader,
dataloader: Union[DataLoader, CombinedLoader],
strategy: "pl.strategies.Strategy",
model: Union["pl.LightningModule", "pl.LightningDataModule"],
) -> bool:
"""Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
infinite dataloader."""
try:
local_length = len(dataloader)
local_length = len(dataloader) # type: ignore [arg-type] # we are checking with duck-typing
total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
if total_length == 0:
@ -129,7 +130,8 @@ def has_len_all_ranks(
except (TypeError, NotImplementedError):
has_len = False
if has_len and new_has_iterable_dataset(dataloader):
# we are checking using lightning_lite, which doesn't know CombinedLoader
if has_len and new_has_iterable_dataset(dataloader): # type: ignore [arg-type]
rank_zero_warn(
"Your `IterableDataset` has `__len__` defined."
" In combination with multi-process data loading (when num_workers > 1),"

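`has_len_all_ranks` now accepts a `CombinedLoader` as well and keeps working by duck typing: it calls `len(...)`, treats `TypeError`/`NotImplementedError` as "no length", and then sums the local lengths across ranks; the `# type: ignore [arg-type]` comments acknowledge that the lightning_lite helpers are annotated for plain DataLoaders only. A stripped-down, single-process sketch of the duck-typed check (illustrative name, no distributed reduce):

    from typing import Any


    def has_len(obj: Any) -> bool:
        # Mirrors the try/except structure used above: anything whose __len__
        # raises TypeError or NotImplementedError is treated as length-less
        # (e.g. an IterableDataset-backed loader).
        try:
            len(obj)
        except (TypeError, NotImplementedError):
            return False
        return True


    print(has_len([1, 2, 3]))       # True
    print(has_len(iter(range(3))))  # False (iterators have no __len__)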
View File

@ -17,9 +17,8 @@ import copy
import inspect
import pickle
import types
from argparse import Namespace
from dataclasses import fields, is_dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import Any, Dict, List, MutableMapping, Optional, Sequence, Tuple, Type, Union
from torch import nn
from typing_extensions import Literal
@ -94,18 +93,13 @@ def is_picklable(obj: object) -> bool:
return False
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
def clean_namespace(hparams: MutableMapping) -> None:
"""Removes all unpicklable entries from hparams."""
hparams_dict = hparams
if isinstance(hparams, Namespace):
hparams_dict = hparams.__dict__
del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]
del_attrs = [k for k, v in hparams.items() if not is_picklable(v)]
for k in del_attrs:
rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
del hparams_dict[k]
del hparams[k]
def parse_class_init_keys(
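`clean_namespace` now takes any `MutableMapping` rather than special-casing `argparse.Namespace`, which is why the Namespace conversion and the `hparams_dict` indirection disappear. The filtering relies on `is_picklable`, which (assumed here from the surrounding context) tries `pickle.dumps` and returns False on failure. A self-contained sketch of the same idea with illustrative names:

    import pickle
    from typing import MutableMapping


    def is_picklable(obj: object) -> bool:
        # Assumed behaviour of the helper referenced above: an object counts as
        # picklable if pickle.dumps does not raise.
        try:
            pickle.dumps(obj)
            return True
        except Exception:
            return False


    def clean_mapping(hparams: MutableMapping) -> None:
        # Same shape as the rewritten clean_namespace: collect the offending keys
        # first, then delete them, so the mapping is not mutated while iterating.
        del_attrs = [k for k, v in hparams.items() if not is_picklable(v)]
        for k in del_attrs:
            del hparams[k]


    hp: dict = {"lr": 1e-3, "logger": lambda x: x}  # lambdas cannot be pickled
    clean_mapping(hp)
    print(hp)  # {'lr': 0.001}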