diff --git a/pyproject.toml b/pyproject.toml
index f1adbc1935..f91142ff77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,7 +57,6 @@ warn_no_return = "False"
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.trainer.trainer",
     "lightning_app.api.http_methods",
     "lightning_app.api.request_types",
     "lightning_app.cli.app-template.app",
diff --git a/src/lightning_lite/strategies/registry.py b/src/lightning_lite/strategies/registry.py
index 4b35e82e3a..e6ac3ee96f 100644
--- a/src/lightning_lite/strategies/registry.py
+++ b/src/lightning_lite/strategies/registry.py
@@ -82,7 +82,7 @@ class _StrategyRegistry(dict):
 
         return do_register
 
-    def get(self, name: str, default: Optional[Any] = None) -> Any:
+    def get(self, name: str, default: Optional[Strategy] = None) -> Strategy:  # type: ignore[override]
         """Calls the registered strategy with the required parameters and returns the strategy object.
 
         Args:
diff --git a/src/pytorch_lightning/callbacks/progress/base.py b/src/pytorch_lightning/callbacks/progress/base.py
index 4fd4597c99..7dc555ee76 100644
--- a/src/pytorch_lightning/callbacks/progress/base.py
+++ b/src/pytorch_lightning/callbacks/progress/base.py
@@ -222,7 +222,9 @@ class ProgressBarBase(Callback):
         if not trainer.is_global_zero:
             self.disable()
 
-    def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
+    def get_metrics(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
+    ) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
         r"""
         Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics.
         Implement this to override the items displayed in the progress bar.
diff --git a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
index 5f36096fa1..9f727072fb 100644
--- a/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
+++ b/src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -160,6 +160,7 @@ class StochasticWeightAveraging(Callback):
         if len(trainer.lr_scheduler_configs) > 1:
             raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
 
+        assert trainer.max_epochs is not None
         if isinstance(self._swa_epoch_start, float):
             self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)
@@ -189,6 +190,7 @@
             for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                 group["initial_lr"] = lr
 
+            assert trainer.max_epochs is not None
             self._swa_scheduler = cast(
                 _LRScheduler,
                 SWALR(
diff --git a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py
index bbd269b704..60125e1174 100644
--- a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -56,7 +56,7 @@ class EvaluationLoop(DataLoaderLoop):
         self._results = _ResultCollection(training=False)
         self._outputs: List[EPOCH_OUTPUT] = []
         self._logged_outputs: List[_OUT_DICT] = []
-        self._max_batches: List[int] = []
+        self._max_batches: List[Union[int, float]] = []
         self._has_run: bool = False
         self._data_fetcher: Optional[AbstractDataFetcher] = None
@@ -213,7 +213,7 @@ class EvaluationLoop(DataLoaderLoop):
             self._results.cpu()
         self.epoch_loop.teardown()
 
-    def _get_max_batches(self) -> List[int]:
+    def _get_max_batches(self) -> List[Union[int, float]]:
         """Returns the max number of batches for each dataloader."""
         if self.trainer.testing:
             max_batches = self.trainer.num_test_batches
diff --git a/src/pytorch_lightning/loops/dataloader/prediction_loop.py b/src/pytorch_lightning/loops/dataloader/prediction_loop.py
index dcd91ef058..5d25e13ddb 100644
--- a/src/pytorch_lightning/loops/dataloader/prediction_loop.py
+++ b/src/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Sequence
+from typing import Any, List, Optional, Sequence, Union
 
 from torch.utils.data import DataLoader
 
@@ -52,7 +52,7 @@ class PredictionLoop(DataLoaderLoop):
         return length
 
     @property
-    def max_batches(self) -> List[int]:
+    def max_batches(self) -> List[Union[int, float]]:
         """The max number of batches this loop will run for each dataloader."""
         return self.trainer.num_predict_batches
diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 0877bb7347..1cf2a5a6de 100644
--- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -292,6 +292,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             # TODO: fault-tolerance requires a minimum number of batches so probably should be > 0
             and self.batch_progress.current.ready  # did start
         ):
+            assert isinstance(trainer.train_dataloader, CombinedLoader)
             loader: CombinedLoader = trainer.train_dataloader
             state = loader.state_dict(has_completed=self._has_completed())
             if state:
diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py
index 48a5d1ef12..25b1bb8bc7 100644
--- a/src/pytorch_lightning/loops/fit_loop.py
+++ b/src/pytorch_lightning/loops/fit_loop.py
@@ -23,7 +23,7 @@ from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _
 from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import Progress
-from pytorch_lightning.trainer.supporters import TensorRunningAccum
+from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import (
     AbstractDataFetcher,
@@ -48,7 +48,7 @@ class FitLoop(Loop[None]):
 
     def __init__(
         self,
-        min_epochs: int = 0,
+        min_epochs: Optional[int] = 0,
         max_epochs: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -233,6 +233,7 @@ class FitLoop(Loop[None]):
         self._outputs = []
 
         if self.trainer.train_dataloader is not None:
+            assert isinstance(self.trainer.train_dataloader, CombinedLoader)
             _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed)
 
         # changing gradient according accumulation_scheduler
diff --git a/src/pytorch_lightning/loops/utilities.py b/src/pytorch_lightning/loops/utilities.py
index 3dcc2f6531..f60e00333e 100644
--- a/src/pytorch_lightning/loops/utilities.py
+++ b/src/pytorch_lightning/loops/utilities.py
@@ -14,7 +14,7 @@
 from collections import OrderedDict
 from contextlib import contextmanager
 from functools import lru_cache
-from typing import Any, Generator, List, Optional, Sequence, Tuple
+from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -29,6 +29,7 @@ from pytorch_lightning.loops import Loop
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.trainer.progress import BaseProgress
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -69,7 +70,11 @@ def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: in
 
 
 def _parse_loop_limits(
-    min_steps: Optional[int], max_steps: int, min_epochs: Optional[int], max_epochs: int, trainer: "pl.Trainer"
+    min_steps: Optional[int],
+    max_steps: int,
+    min_epochs: Optional[int],
+    max_epochs: Optional[int],
+    trainer: "pl.Trainer",
 ) -> Tuple[int, int]:
     """This utility computes the default values for the minimum and maximum number of steps and epochs given the
     values the user has selected.
@@ -216,13 +221,14 @@ def _reset_progress(loop: Loop) -> None:
             _reset_progress(v)
 
 
-def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
+def _set_sampler_epoch(dataloader: Union[DataLoader, CombinedLoader], epoch: int) -> None:
     """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
 
     Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
    :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning of every
     epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
""" + for sampler_name in ("sampler", "batch_sampler"): sampler = getattr(dataloader, sampler_name, None) if sampler is not None and callable(getattr(sampler, "set_epoch", None)): diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 10ab5c06b2..c61d183974 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -101,6 +101,7 @@ class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module): # `require_backward_grad_sync` will be reset in the # ddp_strategy `post_training_step` hook if not pl_module.automatic_optimization: + assert trainer.model is not None trainer.model.require_backward_grad_sync = False # type: ignore[assignment] return output if trainer.testing: diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py index c74d1144f3..8ba07c53d7 100644 --- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -109,7 +109,7 @@ class AcceleratorConnector: sync_batchnorm: bool = False, benchmark: Optional[bool] = None, replace_sampler_ddp: bool = True, - deterministic: Union[bool, _LITERAL_WARN] = False, + deterministic: Optional[Union[bool, _LITERAL_WARN]] = False, auto_select_gpus: bool = False, num_processes: Optional[int] = None, # deprecated tpu_cores: Optional[Union[List[int], str, int]] = None, # deprecated @@ -663,7 +663,8 @@ class AcceleratorConnector: if isinstance(self._strategy_flag, str): self.strategy = StrategyRegistry.get(self._strategy_flag) elif isinstance(self._strategy_flag, Strategy): - self.strategy = self._strategy_flag + # TODO(lite): remove ignore after merging lite and PL strategies + self.strategy = self._strategy_flag # type: ignore[assignment] else: raise RuntimeError(f"{self.strategy} is not valid type: {self.strategy}") @@ -687,9 +688,7 @@ class AcceleratorConnector: ) return TPUBf16PrecisionPlugin() if isinstance(self.strategy, DeepSpeedStrategy): - return DeepSpeedPrecisionPlugin( - self._precision_flag, self._amp_type_flag, self._amp_level_flag # type: ignore - ) + return DeepSpeedPrecisionPlugin(self._precision_flag, self._amp_type_flag, self._amp_level_flag) if self._precision_flag == 32: return PrecisionPlugin() diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index dce9a4fbaa..47aae74f15 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -57,23 +57,17 @@ class DataConnector: def _should_reload_train_dl(self) -> bool: """Check if train dataloader should be reloaded.""" n_epochs = self.trainer.reload_dataloaders_every_n_epochs - return n_epochs and ( - self.trainer._last_train_dl_reload_epoch is None - or self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs - ) + return n_epochs and self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs @property def _should_reload_val_dl(self) -> bool: """Check if validation dataloader should be reloaded.""" n_epochs = self.trainer.reload_dataloaders_every_n_epochs - return n_epochs and ( - self.trainer._last_val_dl_reload_epoch is None - or self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs - ) + return n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= 
n_epochs def on_trainer_init( self, - val_check_interval: Union[int, float], + val_check_interval: Optional[Union[int, float]], reload_dataloaders_every_n_epochs: int, check_val_every_n_epoch: Optional[int], ) -> None: @@ -347,7 +341,7 @@ class DataConnector: def _reset_eval_dataloader( self, mode: RunningStage, model: Optional["pl.LightningModule"] = None - ) -> Tuple[List[Union[int, float]], List[DataLoader]]: + ) -> Tuple[List[Union[float, int]], List[DataLoader]]: """Generic method to reset a dataloader for evaluation. Args: @@ -387,7 +381,7 @@ class DataConnector: dataloaders, dtype=DataLoader, function=_auto_add_worker_init_fn, rank=self.trainer.global_rank ) - loader_num_batches = [] + loader_num_batches: List[Union[int, float]] = [] # determine number of batches module = model or self.trainer.lightning_module or self.datamodule @@ -398,6 +392,7 @@ class DataConnector: ) if orig_num_batches == 0: + assert isinstance(orig_num_batches, int) loader_num_batches.append(orig_num_batches) continue diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py index 454143416f..c11756ed4b 100644 --- a/src/pytorch_lightning/trainer/supporters.py +++ b/src/pytorch_lightning/trainer/supporters.py @@ -28,7 +28,6 @@ from pytorch_lightning.utilities.auto_restart import ( MergedIteratorState, patch_dataloader_iterator, ) -from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -457,6 +456,8 @@ class CombinedLoader: Returns: the wrapped loaders """ + from pytorch_lightning.utilities.data import get_len + all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) length = _nested_calc_num_data(all_lengths, max) @@ -473,6 +474,8 @@ class CombinedLoader: def _apply_cycle_iterator_length(self) -> None: """When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to all dataloaders.""" + from pytorch_lightning.utilities.data import get_len + if self.mode != "max_size_cycle": return @@ -509,6 +512,8 @@ class CombinedLoader: Returns: length: the minimum length of loaders """ + from pytorch_lightning.utilities.data import get_len + all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) if isinstance(all_lengths, (int, float)): diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 578b714eba..9e788929f7 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -25,7 +25,7 @@ import logging import math import os import warnings -from argparse import ArgumentParser, Namespace +from argparse import _ArgumentGroup, ArgumentParser, Namespace from contextlib import contextmanager from copy import deepcopy from datetime import timedelta @@ -73,7 +73,7 @@ from pytorch_lightning.trainer.connectors.callback_connector import CallbackConn from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector -from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _PBAR_DICT, _ResultCollection from 
pytorch_lightning.trainer.connectors.signal_connector import SignalConnector from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus from pytorch_lightning.trainer.supporters import CombinedLoader @@ -457,7 +457,7 @@ class Trainer: self._call_callback_hooks("on_init_start") # init data flags - self.check_val_every_n_epoch: int + self.check_val_every_n_epoch: Optional[int] self._data_connector.on_trainer_init( val_check_interval, reload_dataloaders_every_n_epochs, @@ -484,7 +484,7 @@ class Trainer: f"`track_grad_norm` must be a positive number or 'inf' (infinity norm). Got {track_grad_norm}." ) - self.gradient_clip_val: Union[int, float] = gradient_clip_val + self.gradient_clip_val: Optional[Union[int, float]] = gradient_clip_val self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = ( GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None ) @@ -504,8 +504,13 @@ class Trainer: self._logger_connector.on_trainer_init(logger, log_every_n_steps, move_metrics_to_cpu) # init debugging flags + self.val_check_batch: Union[int, float] self.val_check_interval: Union[int, float] - self.num_sanity_val_steps: Union[float, int] + self.num_sanity_val_steps: Union[int, float] + self.limit_train_batches: Union[int, float] + self.limit_val_batches: Union[int, float] + self.limit_test_batches: Union[int, float] + self.limit_predict_batches: Union[int, float] setup._init_debugging_flags( self, limit_train_batches, @@ -527,17 +532,19 @@ class Trainer: self.should_stop = False self.state = TrainerState() self.num_training_batches = float("inf") - self.train_dataloader = None - self.num_sanity_val_batches = [] - self.num_test_batches = [] - self.num_val_batches = [] - self.num_predict_batches = [] - self.test_dataloaders = None - self.val_dataloaders = None - self.predict_dataloaders = None - self._last_train_dl_reload_epoch = None - self._last_val_dl_reload_epoch: Optional[int] = None + self.train_dataloader: Optional[Union[CombinedLoader, TRAIN_DATALOADERS]] = None + + self.num_sanity_val_batches: List[Union[int, float]] = [] + self.num_test_batches: List[Union[int, float]] = [] + self.num_val_batches: List[Union[int, float]] = [] + self.num_predict_batches: List[Union[int, float]] = [] + + self.test_dataloaders: Optional[List[DataLoader]] = None + self.val_dataloaders: Optional[List[DataLoader]] = None + self.predict_dataloaders: Optional[List[DataLoader]] = None + self._last_train_dl_reload_epoch = float("-inf") + self._last_val_dl_reload_epoch = float("-inf") def fit( self, @@ -605,13 +612,16 @@ class Trainer: # TODO: ckpt_path only in v2.0 ckpt_path = ckpt_path or self.resume_from_checkpoint self._ckpt_path = self._checkpoint_connector._set_ckpt_path( - self.state.fn, ckpt_path, model_provided=True, model_connected=self.lightning_module is not None + self.state.fn, + ckpt_path, # type: ignore[arg-type] + model_provided=True, + model_connected=self.lightning_module is not None, ) - results = self._run(model, ckpt_path=self.ckpt_path) + self._run(model, ckpt_path=self.ckpt_path) assert self.state.stopped self.training = False - return results + return def validate( self, @@ -659,7 +669,7 @@ class Trainer: ckpt_path: Optional[str] = None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, - ) -> _EVALUATE_OUTPUT: + ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: # -------------------- # SETUP HOOK # -------------------- @@ -751,7 +761,7 @@ class Trainer: ckpt_path: Optional[str] = 
None, verbose: bool = True, datamodule: Optional[LightningDataModule] = None, - ) -> _EVALUATE_OUTPUT: + ) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: # -------------------- # SETUP HOOK # -------------------- @@ -853,7 +863,7 @@ class Trainer: self.state.status = TrainerStatus.RUNNING self.predicting = True - self.predict_loop.return_predictions = return_predictions + self.predict_loop.return_predictions = return_predictions # type: ignore[assignment] # if a datamodule comes in as the second arg, then fix it for the user if isinstance(dataloaders, LightningDataModule): @@ -1103,7 +1113,7 @@ class Trainer: logger.log_graph(self.lightning_module) logger.save() - def _teardown(self): + def _teardown(self) -> None: """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and Callback; those are handled by :meth:`_call_teardown_hook`.""" self.strategy.teardown() @@ -1114,7 +1124,7 @@ class Trainer: self._logger_connector.teardown() self._signal_connector.teardown() - def _run_stage(self): + def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]: self.strategy.barrier("run-stage") self.strategy.dispatch(self) @@ -1122,9 +1132,9 @@ class Trainer: return self._run_evaluate() if self.predicting: return self._run_predict() - return self._run_train() + self._run_train() - def _pre_training_routine(self): + def _pre_training_routine(self) -> None: # wait for all to join if on distributed self.strategy.barrier("setup_training") @@ -1147,6 +1157,7 @@ class Trainer: self._run_sanity_check() # enable train mode + assert self.model is not None self.model.train() torch.set_grad_enabled(True) @@ -1229,6 +1240,7 @@ class Trainer: self.state.stage = stage def _call_setup_hook(self) -> None: + assert self.state.fn is not None fn = self.state.fn._setup_fn self.strategy.barrier("pre_setup") @@ -1252,6 +1264,7 @@ class Trainer: self._call_callback_hooks("on_configure_sharded_model") def _call_teardown_hook(self) -> None: + assert self.state.fn is not None fn = self.state.fn._setup_fn if self.datamodule is not None: @@ -1392,7 +1405,7 @@ class Trainer: prev_fx_name = pl_module._current_fx_name pl_module._current_fx_name = "on_load_checkpoint" - callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") + callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") if callback_states is None: return @@ -1420,7 +1433,7 @@ class Trainer: def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None: """Called when loading a model checkpoint, calls every callback's `load_state_dict`.""" - callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks") + callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks") if callback_states is None: return @@ -1458,6 +1471,7 @@ class Trainer: torch._C._log_api_usage_once("lightning.trainer." 
+ event) def __setup_profiler(self) -> None: + assert self.state.fn is not None local_rank = self.local_rank if self.world_size > 1 else None self.profiler._lightning_module = proxy(self.lightning_module) self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir) @@ -1516,7 +1530,7 @@ class Trainer: module = model or self.lightning_module or self.datamodule orig_train_batches = self.num_training_batches = ( - len(self.train_dataloader) + len(self.train_dataloader) # type: ignore[arg-type] if has_len_all_ranks(self.train_dataloader, self.strategy, module) else float("inf") ) @@ -1661,11 +1675,13 @@ class Trainer: @property def accelerator(self) -> Accelerator: + assert self.strategy.accelerator return self.strategy.accelerator @property def strategy(self) -> Strategy: - return self._accelerator_connector.strategy + # TODO(lite): remove ignore after merging lite and PL strategies + return self._accelerator_connector.strategy # type: ignore[return-value] @property def precision_plugin(self) -> PrecisionPlugin: @@ -1702,6 +1718,7 @@ class Trainer: if isinstance(self.strategy, ParallelStrategy) else [self.strategy.root_device] ) + assert devices is not None device_ids = [] for idx, device in enumerate(devices): if isinstance(device, torch.device): @@ -1718,14 +1735,14 @@ class Trainer: @property def lightning_module(self) -> "pl.LightningModule": # TODO: this is actually an optional return - return self.strategy.lightning_module + return self.strategy.lightning_module # type: ignore[return-value] @property def optimizers(self) -> List[Optimizer]: return self.strategy.optimizers @optimizers.setter - def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None: + def optimizers(self, new_optims: List[Optimizer]) -> None: self.strategy.optimizers = new_optims @property @@ -1757,7 +1774,7 @@ class Trainer: return getattr(self.precision_plugin, "scaler", None) @property - def model(self) -> torch.nn.Module: + def model(self) -> Optional[torch.nn.Module]: """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. 
         To access the pure LightningModule, use
@@ -1783,10 +1800,10 @@
     @property
     def log_dir(self) -> Optional[str]:
         if len(self.loggers) > 0:
-            if isinstance(self.loggers[0], TensorBoardLogger):
-                dirpath = self.loggers[0].log_dir
-            else:
+            if not isinstance(self.loggers[0], TensorBoardLogger):
                 dirpath = self.loggers[0].save_dir
+            else:
+                dirpath = self.loggers[0].log_dir
         else:
             dirpath = self.default_root_dir
@@ -1915,7 +1932,7 @@
         return {k: v.default for k, v in init_signature.parameters.items()}
 
     @classmethod
-    def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs) -> Any:
+    def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs: Any) -> Any:
         return from_argparse_args(cls, args, **kwargs)
 
     @classmethod
@@ -1927,7 +1944,7 @@
         return parse_env_variables(cls)
 
     @classmethod
-    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
+    def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs: Any) -> Union[_ArgumentGroup, ArgumentParser]:
         return add_argparse_args(cls, parent_parser, **kwargs)
 
     """
@@ -1995,7 +2012,7 @@
     @property
     def evaluating(self) -> bool:
-        return self.state.stage and self.state.stage.evaluating
+        return self.state.stage is not None and self.state.stage.evaluating
 
     @property
     def sanity_checking(self) -> bool:
@@ -2026,11 +2043,11 @@
         return self.fit_loop.epoch_progress.current.completed
 
     @property
-    def max_epochs(self) -> int:
+    def max_epochs(self) -> Optional[int]:
         return self.fit_loop.max_epochs
 
     @property
-    def min_epochs(self) -> int:
+    def min_epochs(self) -> Optional[int]:
         return self.fit_loop.min_epochs
 
     @property
@@ -2051,7 +2068,7 @@
         return self._fit_loop
 
     @fit_loop.setter
-    def fit_loop(self, loop: FitLoop):
+    def fit_loop(self, loop: FitLoop) -> None:
         """Attach a custom fit loop to this Trainer.
 
         It will run with
@@ -2065,7 +2082,7 @@
         return self._validate_loop
 
     @validate_loop.setter
-    def validate_loop(self, loop: EvaluationLoop):
+    def validate_loop(self, loop: EvaluationLoop) -> None:
         """Attach a custom validation loop to this Trainer.
 
         It will run with
@@ -2080,7 +2097,7 @@
         return self._test_loop
 
     @test_loop.setter
-    def test_loop(self, loop: EvaluationLoop):
+    def test_loop(self, loop: EvaluationLoop) -> None:
         """Attach a custom test loop to this Trainer.
 
         It will run with
@@ -2094,7 +2111,7 @@
         return self._predict_loop
 
     @predict_loop.setter
-    def predict_loop(self, loop: PredictionLoop):
+    def predict_loop(self, loop: PredictionLoop) -> None:
         """Attach a custom prediction loop to this Trainer.
 
         It will run with
@@ -2146,17 +2163,17 @@
         self._loggers = loggers if loggers else []
 
     @property
-    def callback_metrics(self) -> Dict[str, Tensor]:
+    def callback_metrics(self) -> Dict:
         # TODO: the true typing return can include dictionaries as defined in
         # `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
         return self._logger_connector.callback_metrics
 
     @property
-    def logged_metrics(self) -> dict:
+    def logged_metrics(self) -> _OUT_DICT:
         return self._logger_connector.logged_metrics
 
     @property
-    def progress_bar_metrics(self) -> dict:
+    def progress_bar_metrics(self) -> _PBAR_DICT:
         return self._logger_connector.progress_bar_metrics
 
     @property
@@ -2172,7 +2189,7 @@
     def _should_terminate_gracefully(self) -> bool:
         value = torch.tensor(int(self._terminate_gracefully), device=self.strategy.root_device)
-        return self.strategy.reduce(value, reduce_op="sum") > 0
+        return bool(self.strategy.reduce(value, reduce_op="sum") > 0)
 
     """
     Other
@@ -2216,6 +2233,7 @@
         if total_batches == float("inf"):
             return self.max_steps
 
+        assert self.max_epochs is not None
         self.accumulate_grad_batches = accumulation_scheduler.get_accumulate_grad_batches(self.current_epoch)
         effective_batch_size = self.accumulate_grad_batches
         max_estimated_steps = math.ceil(total_batches / effective_batch_size) * max(self.max_epochs, 1)
diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py
index 781c7ee119..82ea298a6b 100644
--- a/src/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -17,8 +17,6 @@ import uuid
 from copy import deepcopy
 from typing import Any, Dict, Optional, Tuple
 
-from torch.utils.data import DataLoader
-
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
 from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_setattr
@@ -277,13 +275,15 @@
         rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
 
     if trainer.state.fn == "fit":
+        from pytorch_lightning.trainer.supporters import CombinedLoader
+
         if trainer.train_dataloader is None:
             trainer.reset_train_dataloader()
-        assert trainer.train_dataloader is not None
-        # TODO: should we check val_dataloaders here too?
+        assert isinstance(trainer.train_dataloader, CombinedLoader)
        if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
-            new_size = min(new_size, len(trainer.train_dataloader.dataset))
+            # at this moment, `train_dataloader` is already a CombinedLoader. len can return a size or infinity
+            new_size = min(new_size, len(trainer.train_dataloader.dataset))  # type: ignore[arg-type]
     else:
         stage = trainer.state.stage
         assert stage is not None
@@ -302,11 +302,14 @@
     return new_size, changed
 
 
-def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer") -> bool:
+def _is_valid_batch_size(
+    batch_size: int, dataloader: "pl.trainer.supporters.CombinedLoader", trainer: "pl.Trainer"
+) -> bool:
     from pytorch_lightning.utilities.data import has_len_all_ranks
 
     module = trainer.lightning_module or trainer.datamodule
-    return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
+    has_len = has_len_all_ranks(dataloader, trainer.strategy, module)
+    return not has_len or batch_size <= len(dataloader)  # type: ignore[arg-type]
 
 
 def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py
index 17f8b9f101..e2efb4f562 100644
--- a/src/pytorch_lightning/utilities/data.py
+++ b/src/pytorch_lightning/utilities/data.py
@@ -35,6 +35,7 @@ from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterab
 from lightning_lite.utilities.data import has_len as new_has_len
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -96,14 +97,14 @@ def extract_batch_size(batch: BType) -> int:
 
 
 def has_len_all_ranks(
-    dataloader: DataLoader,
+    dataloader: Union[DataLoader, CombinedLoader],
     strategy: "pl.strategies.Strategy",
     model: Union["pl.LightningModule", "pl.LightningDataModule"],
 ) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
     infinite dataloader."""
     try:
-        local_length = len(dataloader)
+        local_length = len(dataloader)  # type: ignore [arg-type] # we are checking with duck-typing
         total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
 
         if total_length == 0:
@@ -129,7 +130,8 @@ def has_len_all_ranks(
     except (TypeError, NotImplementedError):
         has_len = False
 
-    if has_len and new_has_iterable_dataset(dataloader):
+    # we are checking using lightning_lite, which doesn't know CombinedLoader
+    if has_len and new_has_iterable_dataset(dataloader):  # type: ignore [arg-type]
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
" In combination with multi-process data loading (when num_workers > 1)," diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py index 22dfb53882..87d43791e5 100644 --- a/src/pytorch_lightning/utilities/parsing.py +++ b/src/pytorch_lightning/utilities/parsing.py @@ -17,9 +17,8 @@ import copy import inspect import pickle import types -from argparse import Namespace from dataclasses import fields, is_dataclass -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any, Dict, List, MutableMapping, Optional, Sequence, Tuple, Type, Union from torch import nn from typing_extensions import Literal @@ -94,18 +93,13 @@ def is_picklable(obj: object) -> bool: return False -def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None: +def clean_namespace(hparams: MutableMapping) -> None: """Removes all unpicklable entries from hparams.""" - - hparams_dict = hparams - if isinstance(hparams, Namespace): - hparams_dict = hparams.__dict__ - - del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)] + del_attrs = [k for k, v in hparams.items() if not is_picklable(v)] for k in del_attrs: rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled") - del hparams_dict[k] + del hparams[k] def parse_class_init_keys(