Fix typing in `pl.callbacks.timer` (#10798)
This commit is contained in:
parent
5d9df39b1c
commit
fc86c301e7
|
@ -54,7 +54,6 @@ module = [
|
|||
"pytorch_lightning.callbacks.progress.tqdm_progress",
|
||||
"pytorch_lightning.callbacks.quantization",
|
||||
"pytorch_lightning.callbacks.stochastic_weight_avg",
|
||||
"pytorch_lightning.callbacks.timer",
|
||||
"pytorch_lightning.callbacks.xla_stats_monitor",
|
||||
"pytorch_lightning.core.datamodule",
|
||||
"pytorch_lightning.core.decorators",
|
||||
|
|
|
@ -94,8 +94,8 @@ class Timer(Callback):
|
|||
self._duration = duration.total_seconds() if duration is not None else None
|
||||
self._interval = interval
|
||||
self._verbose = verbose
|
||||
self._start_time = {stage: None for stage in RunningStage}
|
||||
self._end_time = {stage: None for stage in RunningStage}
|
||||
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
|
||||
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
|
||||
self._offset = 0
|
||||
|
||||
def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
|
||||
|
@ -124,30 +124,30 @@ class Timer(Callback):
|
|||
if self._duration is not None:
|
||||
return self._duration - self.time_elapsed(stage)
|
||||
|
||||
def on_train_start(self, *args, **kwargs) -> None:
|
||||
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._start_time[RunningStage.TRAINING] = time.monotonic()
|
||||
|
||||
def on_train_end(self, *args, **kwargs) -> None:
|
||||
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._end_time[RunningStage.TRAINING] = time.monotonic()
|
||||
|
||||
def on_validation_start(self, *args, **kwargs) -> None:
|
||||
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._start_time[RunningStage.VALIDATING] = time.monotonic()
|
||||
|
||||
def on_validation_end(self, *args, **kwargs) -> None:
|
||||
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._end_time[RunningStage.VALIDATING] = time.monotonic()
|
||||
|
||||
def on_test_start(self, *args, **kwargs) -> None:
|
||||
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._start_time[RunningStage.TESTING] = time.monotonic()
|
||||
|
||||
def on_test_end(self, *args, **kwargs) -> None:
|
||||
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._end_time[RunningStage.TESTING] = time.monotonic()
|
||||
|
||||
def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
|
||||
def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
|
||||
if self._interval != Interval.step or self._duration is None:
|
||||
return
|
||||
self._check_time_remaining(trainer)
|
||||
|
||||
def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
|
||||
def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
|
||||
if self._interval != Interval.epoch or self._duration is None:
|
||||
return
|
||||
self._check_time_remaining(trainer)
|
||||
|
@ -164,6 +164,7 @@ class Timer(Callback):
|
|||
self._offset = time_elapsed.get(RunningStage.TRAINING.value, 0)
|
||||
|
||||
def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
|
||||
assert self._duration is not None
|
||||
should_stop = self.time_elapsed() >= self._duration
|
||||
should_stop = trainer.training_type_plugin.broadcast(should_stop)
|
||||
trainer.should_stop = trainer.should_stop or should_stop
|
||||
|
|
|
@ -27,6 +27,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sequ
|
|||
from pytorch_lightning.accelerators.accelerator import Accelerator
|
||||
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
|
||||
from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin
|
||||
from pytorch_lightning.plugins.training_type.training_type_plugin import TBroadcast
|
||||
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
|
||||
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
|
||||
|
@ -359,7 +360,7 @@ class LightningLite(ABC):
|
|||
data = convert_to_tensors(data, device=self.device)
|
||||
return apply_to_collection(data, torch.Tensor, self._strategy.all_gather, group=group, sync_grads=sync_grads)
|
||||
|
||||
def broadcast(self, obj: object, src: int = 0) -> object:
|
||||
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
|
||||
return self._strategy.broadcast(obj, src=src)
|
||||
|
||||
def save(self, content: Dict[str, Any], filepath: Union[str, Path]) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import contextlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -32,6 +32,8 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, move_dat
|
|||
from pytorch_lightning.utilities.distributed import ReduceOp
|
||||
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
|
||||
|
||||
TBroadcast = TypeVar("TBroadcast")
|
||||
|
||||
|
||||
class TrainingTypePlugin(ABC):
|
||||
"""Base class for all training type plugins that change the behaviour of the training, validation and test-
|
||||
|
@ -246,7 +248,7 @@ class TrainingTypePlugin(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def broadcast(self, obj: object, src: int = 0) -> object:
|
||||
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
|
||||
"""Broadcasts an object to all processes.
|
||||
|
||||
Args:
|
||||
|
|
Loading…
Reference in New Issue