[typing] Add typehint for broadcast in training type plugin (#6777)
* Update training_type_plugin.py * Update accelerator.py * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Akihiro Nitta <nitta@akihironitta.com> Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
This commit is contained in:
parent
f8a379830d
commit
bb9ace4333
|
@ -480,7 +480,7 @@ class Accelerator(object):
|
|||
)
|
||||
self.setup_precision_plugin(plugin)
|
||||
|
||||
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
|
||||
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
|
||||
"""Save model/training states as a checkpoint file through state-dump and file-write.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import contextlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
@ -30,6 +30,8 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load
|
|||
if TYPE_CHECKING:
|
||||
from pytorch_lightning.trainer.trainer import Trainer
|
||||
|
||||
TBroadcast = TypeVar("T")
|
||||
|
||||
|
||||
class TrainingTypePlugin(Plugin, ABC):
|
||||
"""A Plugin to change the behaviour of the training, validation and test-loop."""
|
||||
|
@ -88,7 +90,7 @@ class TrainingTypePlugin(Plugin, ABC):
|
|||
"""Forces all possibly joined processes to wait for each other"""
|
||||
|
||||
@abstractmethod
|
||||
def broadcast(self, obj: object, src: int = 0) -> object:
|
||||
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
|
||||
"""Broadcasts an object to all processes"""
|
||||
|
||||
@abstractmethod
|
||||
|
|
Loading…
Reference in New Issue