From bb9ace43334ad50e3758d9cff08ad34216c7d4da Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 2 Apr 2021 11:55:34 -0700 Subject: [PATCH] [typing] Add typehint for broadcast in training type plugin (#6777) * Update training_type_plugin.py * Update accelerator.py * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Akihiro Nitta Co-authored-by: Akihiro Nitta --- pytorch_lightning/accelerators/accelerator.py | 2 +- .../plugins/training_type/training_type_plugin.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 048f3365e1..66bbdc7fc3 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -480,7 +480,7 @@ class Accelerator(object): ) self.setup_precision_plugin(plugin) - def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None: + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 01c23504b7..6fd02142bf 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TYPE_CHECKING, TypeVar, Union import torch from torch.nn import Module @@ -30,6 +30,8 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load if TYPE_CHECKING: from pytorch_lightning.trainer.trainer import Trainer +TBroadcast = TypeVar("T") + class TrainingTypePlugin(Plugin, ABC): """A Plugin to change the behaviour of the training, validation and test-loop.""" @@ -88,7 +90,7 @@ class TrainingTypePlugin(Plugin, ABC): """Forces all possibly joined processes to wait for each other""" @abstractmethod - def broadcast(self, obj: object, src: int = 0) -> object: + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: """Broadcasts an object to all processes""" @abstractmethod