Fix mypy for plugins registry (#7062)
This commit is contained in:
parent
3fb8eada34
commit
6a7b4cf5d3
|
@ -16,7 +16,7 @@ import os
|
|||
from collections import UserDict
|
||||
from inspect import getmembers, isclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -75,7 +75,7 @@ class _TrainingTypePluginsRegistry(UserDict):
|
|||
" HINT: Use `override=True`."
|
||||
)
|
||||
|
||||
data = {}
|
||||
data: Dict[str, Any] = {}
|
||||
data["description"] = description if description is not None else ""
|
||||
|
||||
data["init_params"] = init_params
|
||||
|
@ -90,7 +90,7 @@ class _TrainingTypePluginsRegistry(UserDict):
|
|||
|
||||
return do_register
|
||||
|
||||
def get(self, name: str) -> Any:
|
||||
def get(self, name: str, default: Optional[Any] = None) -> Any:
|
||||
"""
|
||||
Calls the registered plugin with the required parameters
|
||||
and returns the plugin object
|
||||
|
@ -102,6 +102,9 @@ class _TrainingTypePluginsRegistry(UserDict):
|
|||
data = self[name]
|
||||
return data["plugin"](**data["init_params"])
|
||||
|
||||
if default is not None:
|
||||
return default
|
||||
|
||||
err_msg = "'{}' not found in registry. Available names: {}"
|
||||
available_names = ", ".join(sorted(self.keys())) or "none"
|
||||
raise KeyError(err_msg.format(name, available_names))
|
||||
|
|
|
@ -544,7 +544,7 @@ class DeepSpeedPlugin(DDPPlugin):
|
|||
return current_global_step
|
||||
|
||||
@classmethod
|
||||
def register_plugins(cls, plugin_registry):
|
||||
def register_plugins(cls, plugin_registry: Dict) -> None:
|
||||
plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin")
|
||||
plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2)
|
||||
plugin_registry.register(
|
||||
|
|
Loading…
Reference in New Issue