Fix mypy typing errors in pytorch_lightning/trainer/trainer.py (#14204)
Co-authored-by: otaj <ota@lightning.ai>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent 021c2f1447
commit a9142d637a
@@ -57,7 +57,6 @@ warn_no_return = "False"
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.trainer.trainer",
     "lightning_app.api.http_methods",
     "lightning_app.api.request_types",
     "lightning_app.cli.app-template.app",
@@ -82,7 +82,7 @@ class _StrategyRegistry(dict):
         return do_register
 
-    def get(self, name: str, default: Optional[Any] = None) -> Any:
+    def get(self, name: str, default: Optional[Strategy] = None) -> Strategy:  # type: ignore[override]
         """Calls the registered strategy with the required parameters and returns the strategy object.
 
         Args:
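A minimal sketch (illustrative; Plugin and PluginRegistry are made-up names, not from the diff) of why narrowing dict.get requires the # type: ignore[override] seen above: the narrowed signature is not compatible with dict.get, which may return None.

    from typing import Dict, Optional

    class Plugin:
        ...

    class PluginRegistry(Dict[str, Plugin]):
        def get(self, name: str, default: Optional[Plugin] = None) -> Plugin:  # type: ignore[override]
            # unlike dict.get, never return None: raise for unknown names instead
            if name in self:
                return self[name]
            if default is not None:
                return default
            raise KeyError(name)

    registry = PluginRegistry()
    registry["a"] = Plugin()
    assert isinstance(registry.get("a"), Plugin)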
@@ -222,7 +222,9 @@ class ProgressBarBase(Callback):
         if not trainer.is_global_zero:
             self.disable()
 
-    def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
+    def get_metrics(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+    ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
         r"""
         Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
         Implement this to override the items displayed in the progress bar.
@@ -160,6 +160,7 @@ class StochasticWeightAveraging(Callback):
         if len(trainer.lr_scheduler_configs) > 1:
             raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
 
+        assert trainer.max_epochs is not None
         if isinstance(self._swa_epoch_start, float):
             self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
 
@@ -189,6 +190,7 @@ class StochasticWeightAveraging(Callback):
             for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                 group["initial_lr"] = lr
 
+            assert trainer.max_epochs is not None
             self._swa_scheduler = cast(
                 _LRScheduler,
                 SWALR(
@@ -56,7 +56,7 @@ class EvaluationLoop(DataLoaderLoop):
         self._results = _ResultCollection(training=False)
         self._outputs: List[EPOCH_OUTPUT] = []
         self._logged_outputs: List[_OUT_DICT] = []
-        self._max_batches: List[int] = []
+        self._max_batches: List[Union[int, float]] = []
         self._has_run: bool = False
         self._data_fetcher: Optional[AbstractDataFetcher] = None
@@ -213,7 +213,7 @@ class EvaluationLoop(DataLoaderLoop):
         self._results.cpu()
         self.epoch_loop.teardown()
 
-    def _get_max_batches(self) -> List[int]:
+    def _get_max_batches(self) -> List[Union[int, float]]:
         """Returns the max number of batches for each dataloader."""
         if self.trainer.testing:
             max_batches = self.trainer.num_test_batches
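A minimal sketch (illustrative, not from the diff) of why the per-dataloader batch counts are typed List[Union[int, float]]: a concrete limit is an int, while an infinite or unsized dataloader contributes float("inf").

    from typing import List, Union

    max_batches: List[Union[int, float]] = []
    max_batches.append(10)            # e.g. a fixed limit such as limit_test_batches=10
    max_batches.append(float("inf"))  # e.g. an IterableDataset without __len__
    print(max_batches)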
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, Optional, Sequence, Union
 
 from torch.utils.data import DataLoader
 
@@ -52,7 +52,7 @@ class PredictionLoop(DataLoaderLoop):
         return length
 
     @property
-    def max_batches(self) -> List[int]:
+    def max_batches(self) -> List[Union[int, float]]:
         """The max number of batches this loop will run for each dataloader."""
         return self.trainer.num_predict_batches
@@ -292,6 +292,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             # TODO: fault-tolerance requires a minimum number of batches so probably should be > 0
             and self.batch_progress.current.ready  # did start
         ):
+            assert isinstance(trainer.train_dataloader, CombinedLoader)
             loader: CombinedLoader = trainer.train_dataloader
             state = loader.state_dict(has_completed=self._has_completed())
             if state:
 
@@ -23,7 +23,7 @@ from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _
 from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import Progress
-from pytorch_lightning.trainer.supporters import TensorRunningAccum
+from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import (
     AbstractDataFetcher,
@@ -48,7 +48,7 @@ class FitLoop(Loop[None]):
 
     def __init__(
         self,
-        min_epochs: int = 0,
+        min_epochs: Optional[int] = 0,
         max_epochs: Optional[int] = None,
     ) -> None:
         super().__init__()
 
@@ -233,6 +233,7 @@ class FitLoop(Loop[None]):
         self._outputs = []
 
         if self.trainer.train_dataloader is not None:
+            assert isinstance(self.trainer.train_dataloader, CombinedLoader)
             _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed)
 
         # changing gradient according accumulation_scheduler
@@ -14,7 +14,7 @@
 from collections import OrderedDict
 from contextlib import contextmanager
 from functools import lru_cache
-from typing import Any, Generator, List, Optional, Sequence, Tuple
+from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
 
@@ -29,6 +29,7 @@ from pytorch_lightning.loops import Loop
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.trainer.progress import BaseProgress
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -69,7 +70,11 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: in
 
 
 def _parse_loop_limits(
-    min_steps: Optional[int], max_steps: int, min_epochs: Optional[int], max_epochs: int, trainer: "pl.Trainer"
+    min_steps: Optional[int],
+    max_steps: int,
+    min_epochs: Optional[int],
+    max_epochs: Optional[int],
+    trainer: "pl.Trainer",
 ) -> Tuple[int, int]:
     """This utility computes the default values for the minimum and maximum number of steps and epochs given the
     values the user has selected.
@@ -216,13 +221,14 @@ def _reset_progress(loop: Loop) -> None:
             _reset_progress(v)
 
 
-def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
+def _set_sampler_epoch(dataloader: Union[DataLoader, CombinedLoader], epoch: int) -> None:
     """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
 
     Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
     :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
     of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
     """
    for sampler_name in ("sampler", "batch_sampler"):
        sampler = getattr(dataloader, sampler_name, None)
        if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
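A minimal sketch (illustrative; the _Toy* classes are made-up stand-ins, not from the diff) of the duck typing `_set_sampler_epoch` relies on: any sampler-like attribute exposing a callable `set_epoch` is accepted, which is why a CombinedLoader can be passed as well as a plain DataLoader.

    class _ToySampler:
        def __init__(self) -> None:
            self.epoch = -1

        def set_epoch(self, epoch: int) -> None:  # same hook DistributedSampler provides
            self.epoch = epoch

    class _ToyLoader:
        sampler = _ToySampler()
        batch_sampler = None

    loader = _ToyLoader()
    for sampler_name in ("sampler", "batch_sampler"):
        sampler = getattr(loader, sampler_name, None)
        if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
            sampler.set_epoch(3)

    assert loader.sampler.epoch == 3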
@@ -101,6 +101,7 @@ class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
             # `require_backward_grad_sync` will be reset in the
             # ddp_strategy `post_training_step` hook
             if not pl_module.automatic_optimization:
+                assert trainer.model is not None
                 trainer.model.require_backward_grad_sync = False  # type: ignore[assignment]
             return output
         if trainer.testing:
@@ -109,7 +109,7 @@ class AcceleratorConnector:
         sync_batchnorm: bool = False,
         benchmark: Optional[bool] = None,
         replace_sampler_ddp: bool = True,
-        deterministic: Union[bool, _LITERAL_WARN] = False,
+        deterministic: Optional[Union[bool, _LITERAL_WARN]] = False,
         auto_select_gpus: bool = False,
         num_processes: Optional[int] = None,  # deprecated
         tpu_cores: Optional[Union[List[int], str, int]] = None,  # deprecated
 
@@ -663,7 +663,8 @@ class AcceleratorConnector:
         if isinstance(self._strategy_flag, str):
             self.strategy = StrategyRegistry.get(self._strategy_flag)
         elif isinstance(self._strategy_flag, Strategy):
-            self.strategy = self._strategy_flag
+            # TODO(lite): remove ignore after merging lite and PL strategies
+            self.strategy = self._strategy_flag  # type: ignore[assignment]
         else:
             raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}")
@@ -687,9 +688,7 @@ class AcceleratorConnector:
             )
             return TPUBf16PrecisionPlugin()
         if isinstance(self.strategy, DeepSpeedStrategy):
-            return DeepSpeedPrecisionPlugin(
-                self._precision_flag, self._amp_type_flag, self._amp_level_flag  # type: ignore
-            )
+            return DeepSpeedPrecisionPlugin(self._precision_flag, self._amp_type_flag, self._amp_level_flag)
 
         if self._precision_flag == 32:
             return PrecisionPlugin()
@@ -57,23 +57,17 @@ class DataConnector:
     def _should_reload_train_dl(self) -> bool:
         """Check if train dataloader should be reloaded."""
         n_epochs = self.trainer.reload_dataloaders_every_n_epochs
-        return n_epochs and (
-            self.trainer._last_train_dl_reload_epoch is None
-            or self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs
-        )
+        return n_epochs and self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs
 
     @property
     def _should_reload_val_dl(self) -> bool:
         """Check if validation dataloader should be reloaded."""
         n_epochs = self.trainer.reload_dataloaders_every_n_epochs
-        return n_epochs and (
-            self.trainer._last_val_dl_reload_epoch is None
-            or self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs
-        )
+        return n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs
 
     def on_trainer_init(
         self,
-        val_check_interval: Union[int, float],
+        val_check_interval: Optional[Union[int, float]],
         reload_dataloaders_every_n_epochs: int,
         check_val_every_n_epoch: Optional[int],
     ) -> None:
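A minimal sketch (illustrative, not from the diff) of why the `is None` branch can go away: the Trainer now initializes the reload-epoch trackers to float("-inf") (see the __init__ hunk further down), so the "never reloaded yet" case falls out of the comparison without an Optional check.

    last_reload_epoch = float("-inf")  # sentinel instead of Optional[int] = None
    current_epoch = 0
    n_epochs = 2
    should_reload = bool(n_epochs) and current_epoch - last_reload_epoch >= n_epochs
    assert should_reload  # anything minus -inf is always >= n_epochs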
@@ -347,7 +341,7 @@ class DataConnector:
 
     def _reset_eval_dataloader(
         self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
-    ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
+    ) -> Tuple[List[Union[float, int]], List[DataLoader]]:
         """Generic method to reset a dataloader for evaluation.
 
         Args:
 
@@ -387,7 +381,7 @@ class DataConnector:
             dataloaders, dtype=DataLoader, function=_auto_add_worker_init_fn, rank=self.trainer.global_rank
         )
 
-        loader_num_batches = []
+        loader_num_batches: List[Union[int, float]] = []
 
         # determine number of batches
         module = model or self.trainer.lightning_module or self.datamodule
 
@@ -398,6 +392,7 @@ class DataConnector:
             )
 
             if orig_num_batches == 0:
+                assert isinstance(orig_num_batches, int)
                 loader_num_batches.append(orig_num_batches)
                 continue
 
@@ -28,7 +28,6 @@ from pytorch_lightning.utilities.auto_restart import (
     MergedIteratorState,
     patch_dataloader_iterator,
 )
-from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
@@ -457,6 +456,8 @@ class CombinedLoader:
         Returns:
             the wrapped loaders
         """
+        from pytorch_lightning.utilities.data import get_len
+
         all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
 
         length = _nested_calc_num_data(all_lengths, max)
@@ -473,6 +474,8 @@ class CombinedLoader:
     def _apply_cycle_iterator_length(self) -> None:
         """When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to
         all dataloaders."""
+        from pytorch_lightning.utilities.data import get_len
+
         if self.mode != "max_size_cycle":
             return
 
@@ -509,6 +512,8 @@ class CombinedLoader:
         Returns:
             length: the minimum length of loaders
         """
+        from pytorch_lightning.utilities.data import get_len
+
         all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
 
         if isinstance(all_lengths, (int, float)):
@@ -25,7 +25,7 @@ import logging
 import math
 import os
 import warnings
-from argparse import ArgumentParser, Namespace
+from argparse import _ArgumentGroup, ArgumentParser, Namespace
 from contextlib import contextmanager
 from copy import deepcopy
 from datetime import timedelta
 
@@ -73,7 +73,7 @@ from pytorch_lightning.trainer.connectors.callback_connector import CallbackConn
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
-from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
+from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _PBAR_DICT, _ResultCollection
 from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from pytorch_lightning.trainer.supporters import CombinedLoader
@@ -457,7 +457,7 @@ class Trainer:
         self._call_callback_hooks("on_init_start")
 
         # init data flags
-        self.check_val_every_n_epoch: int
+        self.check_val_every_n_epoch: Optional[int]
         self._data_connector.on_trainer_init(
             val_check_interval,
             reload_dataloaders_every_n_epochs,
 
@@ -484,7 +484,7 @@ class Trainer:
                 f"`track_grad_norm` must be a positive number or 'inf' (infinity norm). Got {track_grad_norm}."
             )
 
-        self.gradient_clip_val: Union[int, float] = gradient_clip_val
+        self.gradient_clip_val: Optional[Union[int, float]] = gradient_clip_val
         self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
             GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
         )
@@ -504,8 +504,13 @@ class Trainer:
         self._logger_connector.on_trainer_init(logger, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
-        self.num_sanity_val_steps: Union[float, int]
+        self.val_check_batch: Union[int, float]
+        self.val_check_interval: Union[int, float]
+        self.num_sanity_val_steps: Union[int, float]
+        self.limit_train_batches: Union[int, float]
+        self.limit_val_batches: Union[int, float]
+        self.limit_test_batches: Union[int, float]
+        self.limit_predict_batches: Union[int, float]
         setup._init_debugging_flags(
             self,
             limit_train_batches,
@@ -527,17 +532,19 @@ class Trainer:
         self.should_stop = False
         self.state = TrainerState()
         self.num_training_batches = float("inf")
-        self.train_dataloader = None
-
-        self.num_sanity_val_batches = []
-        self.num_test_batches = []
-        self.num_val_batches = []
-        self.num_predict_batches = []
-        self.test_dataloaders = None
-        self.val_dataloaders = None
-        self.predict_dataloaders = None
-        self._last_train_dl_reload_epoch = None
-        self._last_val_dl_reload_epoch: Optional[int] = None
+        self.train_dataloader: Optional[Union[CombinedLoader, TRAIN_DATALOADERS]] = None
+
+        self.num_sanity_val_batches: List[Union[int, float]] = []
+        self.num_test_batches: List[Union[int, float]] = []
+        self.num_val_batches: List[Union[int, float]] = []
+        self.num_predict_batches: List[Union[int, float]] = []
+
+        self.test_dataloaders: Optional[List[DataLoader]] = None
+        self.val_dataloaders: Optional[List[DataLoader]] = None
+        self.predict_dataloaders: Optional[List[DataLoader]] = None
+        self._last_train_dl_reload_epoch = float("-inf")
+        self._last_val_dl_reload_epoch = float("-inf")
 
     def fit(
         self,
@@ -605,13 +612,16 @@ class Trainer:
         # TODO: ckpt_path only in v2.0
         ckpt_path = ckpt_path or self.resume_from_checkpoint
         self._ckpt_path = self._checkpoint_connector._set_ckpt_path(
-            self.state.fn, ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
+            self.state.fn,
+            ckpt_path,  # type: ignore[arg-type]
+            model_provided=True,
+            model_connected=self.lightning_module is not None,
         )
-        results = self._run(model, ckpt_path=self.ckpt_path)
+        self._run(model, ckpt_path=self.ckpt_path)
 
         assert self.state.stopped
         self.training = False
-        return results
+        return
 
     def validate(
         self,
@@ -659,7 +669,7 @@ class Trainer:
         ckpt_path: Optional[str] = None,
         verbose: bool = True,
         datamodule: Optional[LightningDataModule] = None,
-    ) -> _EVALUATE_OUTPUT:
+    ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
         # --------------------
         # SETUP HOOK
         # --------------------
 
@@ -751,7 +761,7 @@ class Trainer:
         ckpt_path: Optional[str] = None,
         verbose: bool = True,
         datamodule: Optional[LightningDataModule] = None,
-    ) -> _EVALUATE_OUTPUT:
+    ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
         # --------------------
         # SETUP HOOK
         # --------------------
@@ -853,7 +863,7 @@ class Trainer:
         self.state.status = TrainerStatus.RUNNING
         self.predicting = True
 
-        self.predict_loop.return_predictions = return_predictions
+        self.predict_loop.return_predictions = return_predictions  # type: ignore[assignment]
 
         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(dataloaders, LightningDataModule):
 
@@ -1103,7 +1113,7 @@ class Trainer:
             logger.log_graph(self.lightning_module)
             logger.save()
 
-    def _teardown(self):
+    def _teardown(self) -> None:
         """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
         Callback; those are handled by :meth:`_call_teardown_hook`."""
         self.strategy.teardown()
@@ -1114,7 +1124,7 @@ class Trainer:
         self._logger_connector.teardown()
         self._signal_connector.teardown()
 
-    def _run_stage(self):
+    def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
         self.strategy.barrier("run-stage")
         self.strategy.dispatch(self)
 
@@ -1122,9 +1132,9 @@ class Trainer:
             return self._run_evaluate()
         if self.predicting:
             return self._run_predict()
-        return self._run_train()
+        self._run_train()
 
-    def _pre_training_routine(self):
+    def _pre_training_routine(self) -> None:
         # wait for all to join if on distributed
         self.strategy.barrier("setup_training")
 
@@ -1147,6 +1157,7 @@ class Trainer:
             self._run_sanity_check()
 
         # enable train mode
+        assert self.model is not None
         self.model.train()
         torch.set_grad_enabled(True)
 
@@ -1229,6 +1240,7 @@ class Trainer:
         self.state.stage = stage
 
     def _call_setup_hook(self) -> None:
+        assert self.state.fn is not None
         fn = self.state.fn._setup_fn
 
         self.strategy.barrier("pre_setup")
@@ -1252,6 +1264,7 @@ class Trainer:
         self._call_callback_hooks("on_configure_sharded_model")
 
     def _call_teardown_hook(self) -> None:
+        assert self.state.fn is not None
         fn = self.state.fn._setup_fn
 
         if self.datamodule is not None:
 
@@ -1392,7 +1405,7 @@ class Trainer:
         prev_fx_name = pl_module._current_fx_name
         pl_module._current_fx_name = "on_load_checkpoint"
 
-        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
+        callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks")
 
         if callback_states is None:
             return
@@ -1420,7 +1433,7 @@ class Trainer:
 
     def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint, calls every callback's `load_state_dict`."""
-        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
+        callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks")
 
         if callback_states is None:
             return
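A minimal sketch (illustrative, not from the diff) of the typing rule behind these two changes: dict.get without a default returns an Optional value, so the annotation must allow None until the explicit guard narrows it.

    from typing import Any, Dict, Optional

    checkpoint: Dict[str, Any] = {"epoch": 3}
    callback_states: Optional[Dict[str, Any]] = checkpoint.get("callbacks")
    if callback_states is None:
        # after this guard, mypy narrows the type to Dict[str, Any]
        callback_states = {}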
@@ -1458,6 +1471,7 @@ class Trainer:
         torch._C._log_api_usage_once("lightning.trainer." + event)
 
     def __setup_profiler(self) -> None:
+        assert self.state.fn is not None
         local_rank = self.local_rank if self.world_size > 1 else None
         self.profiler._lightning_module = proxy(self.lightning_module)
         self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
 
@@ -1516,7 +1530,7 @@ class Trainer:
 
         module = model or self.lightning_module or self.datamodule
         orig_train_batches = self.num_training_batches = (
-            len(self.train_dataloader)
+            len(self.train_dataloader)  # type: ignore[arg-type]
             if has_len_all_ranks(self.train_dataloader, self.strategy, module)
             else float("inf")
        )
@@ -1661,11 +1675,13 @@ class Trainer:
 
     @property
     def accelerator(self) -> Accelerator:
+        assert self.strategy.accelerator
         return self.strategy.accelerator
 
     @property
     def strategy(self) -> Strategy:
-        return self._accelerator_connector.strategy
+        # TODO(lite): remove ignore after merging lite and PL strategies
+        return self._accelerator_connector.strategy  # type: ignore[return-value]
 
     @property
     def precision_plugin(self) -> PrecisionPlugin:
 
@@ -1702,6 +1718,7 @@ class Trainer:
             if isinstance(self.strategy, ParallelStrategy)
             else [self.strategy.root_device]
         )
+        assert devices is not None
         device_ids = []
         for idx, device in enumerate(devices):
             if isinstance(device, torch.device):
@@ -1718,14 +1735,14 @@ class Trainer:
     @property
     def lightning_module(self) -> "pl.LightningModule":
         # TODO: this is actually an optional return
-        return self.strategy.lightning_module
+        return self.strategy.lightning_module  # type: ignore[return-value]
 
     @property
     def optimizers(self) -> List[Optimizer]:
         return self.strategy.optimizers
 
     @optimizers.setter
-    def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
+    def optimizers(self, new_optims: List[Optimizer]) -> None:
         self.strategy.optimizers = new_optims
 
     @property
 
@@ -1757,7 +1774,7 @@ class Trainer:
         return getattr(self.precision_plugin, "scaler", None)
 
     @property
-    def model(self) -> torch.nn.Module:
+    def model(self) -> Optional[torch.nn.Module]:
         """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.
 
         To access the pure LightningModule, use
@@ -1783,10 +1800,10 @@ class Trainer:
     @property
     def log_dir(self) -> Optional[str]:
         if len(self.loggers) > 0:
-            if isinstance(self.loggers[0], TensorBoardLogger):
-                dirpath = self.loggers[0].log_dir
-            else:
-                dirpath = self.loggers[0].save_dir
+            if not isinstance(self.loggers[0], TensorBoardLogger):
+                dirpath = self.loggers[0].save_dir
+            else:
+                dirpath = self.loggers[0].log_dir
         else:
             dirpath = self.default_root_dir
 
@@ -1915,7 +1932,7 @@ class Trainer:
         return {k: v.default for k, v in init_signature.parameters.items()}
 
     @classmethod
-    def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs) -> Any:
+    def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs: Any) -> Any:
         return from_argparse_args(cls, args, **kwargs)
 
     @classmethod
@@ -1927,7 +1944,7 @@ class Trainer:
         return parse_env_variables(cls)
 
     @classmethod
-    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
+    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> Union[_ArgumentGroup, ArgumentParser]:
         return add_argparse_args(cls, parent_parser, **kwargs)
 
     """
 
@@ -1995,7 +2012,7 @@ class Trainer:
 
     @property
     def evaluating(self) -> bool:
-        return self.state.stage and self.state.stage.evaluating
+        return self.state.stage is not None and self.state.stage.evaluating
 
     @property
     def sanity_checking(self) -> bool:
@@ -2026,11 +2043,11 @@ class Trainer:
         return self.fit_loop.epoch_progress.current.completed
 
     @property
-    def max_epochs(self) -> int:
+    def max_epochs(self) -> Optional[int]:
         return self.fit_loop.max_epochs
 
     @property
-    def min_epochs(self) -> int:
+    def min_epochs(self) -> Optional[int]:
         return self.fit_loop.min_epochs
 
     @property
 
@@ -2051,7 +2068,7 @@ class Trainer:
         return self._fit_loop
 
     @fit_loop.setter
-    def fit_loop(self, loop: FitLoop):
+    def fit_loop(self, loop: FitLoop) -> None:
         """Attach a custom fit loop to this Trainer.
 
         It will run with
@@ -2065,7 +2082,7 @@ class Trainer:
         return self._validate_loop
 
     @validate_loop.setter
-    def validate_loop(self, loop: EvaluationLoop):
+    def validate_loop(self, loop: EvaluationLoop) -> None:
         """Attach a custom validation loop to this Trainer.
 
         It will run with
 
@@ -2080,7 +2097,7 @@ class Trainer:
         return self._test_loop
 
     @test_loop.setter
-    def test_loop(self, loop: EvaluationLoop):
+    def test_loop(self, loop: EvaluationLoop) -> None:
         """Attach a custom test loop to this Trainer.
 
         It will run with
 
@@ -2094,7 +2111,7 @@ class Trainer:
         return self._predict_loop
 
     @predict_loop.setter
-    def predict_loop(self, loop: PredictionLoop):
+    def predict_loop(self, loop: PredictionLoop) -> None:
         """Attach a custom prediction loop to this Trainer.
 
         It will run with
@@ -2146,17 +2163,17 @@ class Trainer:
         self._loggers = loggers if loggers else []
 
     @property
-    def callback_metrics(self) -> Dict[str, Tensor]:
+    def callback_metrics(self) -> Dict:
+        # TODO: the true typing return can include dictionaries as defined in
+        # `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
         return self._logger_connector.callback_metrics
 
     @property
-    def logged_metrics(self) -> dict:
+    def logged_metrics(self) -> _OUT_DICT:
         return self._logger_connector.logged_metrics
 
     @property
-    def progress_bar_metrics(self) -> dict:
+    def progress_bar_metrics(self) -> _PBAR_DICT:
         return self._logger_connector.progress_bar_metrics
 
     @property
@@ -2172,7 +2189,7 @@ class Trainer:
 
     def _should_terminate_gracefully(self) -> bool:
         value = torch.tensor(int(self._terminate_gracefully), device=self.strategy.root_device)
-        return self.strategy.reduce(value, reduce_op="sum") > 0
+        return bool(self.strategy.reduce(value, reduce_op="sum") > 0)
 
     """
     Other
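A minimal sketch (illustrative, not from the diff) of why the explicit bool() is needed: comparing a tensor to a number yields a torch.Tensor, not a Python bool, so returning it would not match the declared "-> bool".

    import torch

    value = torch.tensor(1)
    comparison = value > 0
    print(type(comparison))        # <class 'torch.Tensor'>
    print(type(bool(comparison)))  # <class 'bool'>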
@@ -2216,6 +2233,7 @@ class Trainer:
         if total_batches == float("inf"):
             return self.max_steps
 
+        assert self.max_epochs is not None
         self.accumulate_grad_batches = accumulation_scheduler.get_accumulate_grad_batches(self.current_epoch)
         effective_batch_size = self.accumulate_grad_batches
         max_estimated_steps = math.ceil(total_batches / effective_batch_size) * max(self.max_epochs, 1)
 
@@ -17,8 +17,6 @@ import uuid
 from copy import deepcopy
 from typing import Any, Dict, Optional, Tuple
 
-from torch.utils.data import DataLoader
-
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
 from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_setattr
@@ -277,13 +275,15 @@ def _adjust_batch_size(
     rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
 
     if trainer.state.fn == "fit":
+        from pytorch_lightning.trainer.supporters import CombinedLoader
+
         if trainer.train_dataloader is None:
             trainer.reset_train_dataloader()
-        assert trainer.train_dataloader is not None
         # TODO: should we check val_dataloaders here too?
+        assert isinstance(trainer.train_dataloader, CombinedLoader)
         if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
-            new_size = min(new_size, len(trainer.train_dataloader.dataset))
+            # at this moment, `train_dataloader` is already a CombinedLoader. len can return a size or infinity
+            new_size = min(new_size, len(trainer.train_dataloader.dataset))  # type: ignore[arg-type]
     else:
         stage = trainer.state.stage
         assert stage is not None
@@ -302,11 +302,14 @@ def _adjust_batch_size(
     return new_size, changed
 
 
-def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer") -> bool:
+def _is_valid_batch_size(
+    batch_size: int, dataloader: "pl.trainer.supporters.CombinedLoader", trainer: "pl.Trainer"
+) -> bool:
+    from pytorch_lightning.utilities.data import has_len_all_ranks
+
     module = trainer.lightning_module or trainer.datamodule
-    return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
+    has_len = has_len_all_ranks(dataloader, trainer.strategy, module)
+    return not has_len or batch_size <= len(dataloader)  # type: ignore[arg-type]
 
 
 def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@@ -35,6 +35,7 @@ from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterab
 from lightning_lite.utilities.data import has_len as new_has_len
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -96,14 +97,14 @@ def extract_batch_size(batch: BType) -> int:
 
 
 def has_len_all_ranks(
-    dataloader: DataLoader,
+    dataloader: Union[DataLoader, CombinedLoader],
     strategy: "pl.strategies.Strategy",
     model: Union["pl.LightningModule", "pl.LightningDataModule"],
 ) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
     infinite dataloader."""
     try:
-        local_length = len(dataloader)
+        local_length = len(dataloader)  # type: ignore [arg-type]  # we are checking with duck-typing
         total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
 
         if total_length == 0:
@@ -129,7 +130,8 @@ def has_len_all_ranks(
     except (TypeError, NotImplementedError):
         has_len = False
 
-    if has_len and new_has_iterable_dataset(dataloader):
+    # we are checking using lightning_lite, which doesn't know CombinedLoader
+    if has_len and new_has_iterable_dataset(dataloader):  # type: ignore [arg-type]
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
             " In combination with multi-process data loading (when num_workers > 1),"
 
@@ -17,9 +17,8 @@ import copy
 import inspect
 import pickle
-import types
 from argparse import Namespace
 from dataclasses import fields, is_dataclass
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Dict, List, MutableMapping, Optional, Sequence, Tuple, Type, Union
 
 from torch import nn
 from typing_extensions import Literal
@@ -94,18 +93,13 @@ def is_picklable(obj: object) -> bool:
         return False
 
 
-def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
+def clean_namespace(hparams: MutableMapping) -> None:
     """Removes all unpicklable entries from hparams."""
-
-    hparams_dict = hparams
-    if isinstance(hparams, Namespace):
-        hparams_dict = hparams.__dict__
-
-    del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]
+    del_attrs = [k for k, v in hparams.items() if not is_picklable(v)]
 
     for k in del_attrs:
         rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
-        del hparams_dict[k]
+        del hparams[k]
 
 
 def parse_class_init_keys(
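A minimal usage sketch (illustrative, not from the diff) of the narrowed signature: a Namespace is no longer accepted directly, so a caller would pass its underlying mapping, for example via vars().

    from argparse import Namespace

    from pytorch_lightning.utilities.parsing import clean_namespace

    hparams = Namespace(lr=0.01, fn=lambda x: x)  # a lambda cannot be pickled
    clean_namespace(vars(hparams))  # vars() exposes the Namespace as a MutableMapping
    assert not hasattr(hparams, "fn")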