From fc86c301e78df2a51b2b26aa96c1d2b41b3465f1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 30 Nov 2021 14:54:08 +0100
Subject: [PATCH] Fix typing in `pl.callbacks.timer` (#10798)

---
 pyproject.toml                                     |  1 -
 pytorch_lightning/callbacks/timer.py               | 21 ++++++++++---------
 pytorch_lightning/lite/lite.py                     |  3 ++-
 .../training_type/training_type_plugin.py          |  6 ++++--
 4 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 54586addf0..675c6611e3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,7 +54,6 @@ module = [
     "pytorch_lightning.callbacks.progress.tqdm_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",
-    "pytorch_lightning.callbacks.timer",
     "pytorch_lightning.callbacks.xla_stats_monitor",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.decorators",
diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py
index efeedb30c4..810439b15b 100644
--- a/pytorch_lightning/callbacks/timer.py
+++ b/pytorch_lightning/callbacks/timer.py
@@ -94,8 +94,8 @@ class Timer(Callback):
         self._duration = duration.total_seconds() if duration is not None else None
         self._interval = interval
         self._verbose = verbose
-        self._start_time = {stage: None for stage in RunningStage}
-        self._end_time = {stage: None for stage in RunningStage}
+        self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
+        self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
         self._offset = 0
 
     def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
@@ -124,30 +124,30 @@ class Timer(Callback):
         if self._duration is not None:
             return self._duration - self.time_elapsed(stage)
 
-    def on_train_start(self, *args, **kwargs) -> None:
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time[RunningStage.TRAINING] = time.monotonic()
 
-    def on_train_end(self, *args, **kwargs) -> None:
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._end_time[RunningStage.TRAINING] = time.monotonic()
 
-    def on_validation_start(self, *args, **kwargs) -> None:
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time[RunningStage.VALIDATING] = time.monotonic()
 
-    def on_validation_end(self, *args, **kwargs) -> None:
+    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._end_time[RunningStage.VALIDATING] = time.monotonic()
 
-    def on_test_start(self, *args, **kwargs) -> None:
+    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time[RunningStage.TESTING] = time.monotonic()
 
-    def on_test_end(self, *args, **kwargs) -> None:
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._end_time[RunningStage.TESTING] = time.monotonic()
 
-    def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
+    def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         if self._interval != Interval.step or self._duration is None:
             return
         self._check_time_remaining(trainer)
 
-    def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
+    def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         if self._interval != Interval.epoch or self._duration is None:
             return
         self._check_time_remaining(trainer)
@@ -164,6 +164,7 @@ class Timer(Callback):
         self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0)
 
     def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
+        assert self._duration is not None
         should_stop = self.time_elapsed() >= self._duration
         should_stop = trainer.training_type_plugin.broadcast(should_stop)
         trainer.should_stop = trainer.should_stop or should_stop
diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index fede7f5df7..e4d54f7ed6 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -27,6 +27,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sequ
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin
+from pytorch_lightning.plugins.training_type.training_type_plugin import TBroadcast
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
@@ -359,7 +360,7 @@ class LightningLite(ABC):
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, torch.Tensor, self._strategy.all_gather, group=group, sync_grads=sync_grads)
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return self._strategy.broadcast(obj, src=src)
 
     def save(self, content: Dict[str, Any], filepath: Union[str, Path]) -> None:
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 9709467ebd..05d444a849 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union
 
 import torch
 from torch import Tensor
@@ -32,6 +32,8 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, move_dat
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
 
+TBroadcast = TypeVar("TBroadcast")
+
 
 class TrainingTypePlugin(ABC):
     """Base class for all training type plugins that change the behaviour of the training, validation and test-
@@ -246,7 +248,7 @@ class TrainingTypePlugin(ABC):
         """
 
     @abstractmethod
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         """Broadcasts an object to all processes.
 
         Args:
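
--
Note (illustration only, not part of the patch): a minimal sketch of what
the `TBroadcast` TypeVar buys over the old `obj: object` signature. With a
TypeVar, a static type checker infers that `broadcast` returns the same type
it was given, so a call site like
`should_stop = trainer.training_type_plugin.broadcast(should_stop)` keeps
its `bool` type without a cast. The `_EchoPlugin` class below is a
hypothetical single-process stand-in used only to demonstrate the typing
pattern.

    from typing import TypeVar

    TBroadcast = TypeVar("TBroadcast")

    class _EchoPlugin:
        def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
            # A real plugin would send `obj` from rank `src` to every
            # process and return the received copy; here we just echo it.
            return obj

    plugin = _EchoPlugin()
    should_stop: bool = plugin.broadcast(True)  # checker infers bool
    ckpt_path: str = plugin.broadcast("best.ckpt")  # checker infers str

With the previous `-> object` annotation, both assignments above would be
rejected by a type checker; the TypeVar makes the round-trip
type-preserving.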