Fix typing in `pl.callbacks.timer` (#10798)

Adrian Wälchli 2021-11-30 14:54:08 +01:00 committed by GitHub
parent 5d9df39b1c
commit fc86c301e7
4 changed files with 17 additions and 14 deletions

pyproject.toml (View File)

@@ -54,7 +54,6 @@ module = [
     "pytorch_lightning.callbacks.progress.tqdm_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",
-    "pytorch_lightning.callbacks.timer",
     "pytorch_lightning.callbacks.xla_stats_monitor",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.decorators",

pytorch_lightning/callbacks/timer.py (View File)

@@ -94,8 +94,8 @@ class Timer(Callback):
         self._duration = duration.total_seconds() if duration is not None else None
         self._interval = interval
         self._verbose = verbose
-        self._start_time = {stage: None for stage in RunningStage}
-        self._end_time = {stage: None for stage in RunningStage}
+        self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
+        self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
         self._offset = 0
 
     def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
@@ -124,30 +124,30 @@ class Timer(Callback):
         if self._duration is not None:
             return self._duration - self.time_elapsed(stage)
 
-    def on_train_start(self, *args, **kwargs) -> None:
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time[RunningStage.TRAINING] = time.monotonic()
 
-    def on_train_end(self, *args, **kwargs) -> None:
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._end_time[RunningStage.TRAINING] = time.monotonic()
 
-    def on_validation_start(self, *args, **kwargs) -> None:
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time[RunningStage.VALIDATING] = time.monotonic()
 
-    def on_validation_end(self, *args, **kwargs) -> None:
+    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._end_time[RunningStage.VALIDATING] = time.monotonic()
 
-    def on_test_start(self, *args, **kwargs) -> None:
+    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._start_time[RunningStage.TESTING] = time.monotonic()
 
-    def on_test_end(self, *args, **kwargs) -> None:
+    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._end_time[RunningStage.TESTING] = time.monotonic()
 
-    def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
+    def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         if self._interval != Interval.step or self._duration is None:
             return
         self._check_time_remaining(trainer)
 
-    def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
+    def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         if self._interval != Interval.epoch or self._duration is None:
             return
         self._check_time_remaining(trainer)
@@ -164,6 +164,7 @@ class Timer(Callback):
         self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0)
 
     def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
+        assert self._duration is not None
         should_stop = self.time_elapsed() >= self._duration
         should_stop = trainer.training_type_plugin.broadcast(should_stop)
         trainer.should_stop = trainer.should_stop or should_stop
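
For readers unfamiliar with why the explicit annotation on the two dicts is needed: without it, mypy infers the value type of `{stage: None for stage in RunningStage}` as `None` and then rejects the later `time.monotonic()` assignments. A minimal, self-contained sketch (the `RunningStage` enum below is a simplified stand-in, not the pytorch_lightning one):

# Sketch only: stand-in types, not the actual pytorch_lightning code.
import time
from enum import Enum
from typing import Dict, Optional


class RunningStage(str, Enum):
    TRAINING = "train"
    VALIDATING = "validate"
    TESTING = "test"


# Without the annotation, mypy infers Dict[RunningStage, None] from the
# comprehension and flags the float assignment below as incompatible.
start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
start_time[RunningStage.TRAINING] = time.monotonic()  # OK: Optional[float] accepts float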

pytorch_lightning/lite/lite.py (View File)

@@ -27,6 +27,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sequ
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin
+from pytorch_lightning.plugins.training_type.training_type_plugin import TBroadcast
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
@@ -359,7 +360,7 @@ class LightningLite(ABC):
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, torch.Tensor, self._strategy.all_gather, group=group, sync_grads=sync_grads)
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return self._strategy.broadcast(obj, src=src)
 
     def save(self, content: Dict[str, Any], filepath: Union[str, Path]) -> None:
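
The switch from `obj: object` to the `TBroadcast` TypeVar matters for callers of `broadcast`: the return type now follows the argument type, so results no longer need a cast. A rough sketch under the assumption of a single process (the real method delegates to the training type plugin):

# Sketch only: a stand-alone function instead of the real LightningLite method.
from typing import TypeVar

TBroadcast = TypeVar("TBroadcast")


def broadcast(obj: TBroadcast, src: int = 0) -> TBroadcast:
    # Single-process stand-in; a real plugin sends `obj` from rank `src`
    # to every other process and returns the received value.
    return obj


best_epoch: int = broadcast(7)            # inferred as int, no cast needed
ckpt_path: str = broadcast("best.ckpt")   # inferred as str

With the previous `object -> object` signature, both assignments above would need an explicit cast to satisfy mypy.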

pytorch_lightning/plugins/training_type/training_type_plugin.py (View File)

@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union
 
 import torch
 from torch import Tensor
@@ -32,6 +32,8 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, move_dat
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
 
+TBroadcast = TypeVar("TBroadcast")
+
 
 class TrainingTypePlugin(ABC):
     """Base class for all training type plugins that change the behaviour of the training, validation and test-
@@ -246,7 +248,7 @@ class TrainingTypePlugin(ABC):
         """
 
     @abstractmethod
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         """Broadcasts an object to all processes.
 
         Args: