Fix mypy for plugins registry (#7062)
This commit is contained in:
parent
3fb8eada34
commit
6a7b4cf5d3
|
@ -16,7 +16,7 @@ import os
|
||||||
from collections import UserDict
|
from collections import UserDict
|
||||||
from inspect import getmembers, isclass
|
from inspect import getmembers, isclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
|
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
|
@ -75,7 +75,7 @@ class _TrainingTypePluginsRegistry(UserDict):
|
||||||
" HINT: Use `override=True`."
|
" HINT: Use `override=True`."
|
||||||
)
|
)
|
||||||
|
|
||||||
data = {}
|
data: Dict[str, Any] = {}
|
||||||
data["description"] = description if description is not None else ""
|
data["description"] = description if description is not None else ""
|
||||||
|
|
||||||
data["init_params"] = init_params
|
data["init_params"] = init_params
|
||||||
|
@ -90,7 +90,7 @@ class _TrainingTypePluginsRegistry(UserDict):
|
||||||
|
|
||||||
return do_register
|
return do_register
|
||||||
|
|
||||||
def get(self, name: str) -> Any:
|
def get(self, name: str, default: Optional[Any] = None) -> Any:
|
||||||
"""
|
"""
|
||||||
Calls the registered plugin with the required parameters
|
Calls the registered plugin with the required parameters
|
||||||
and returns the plugin object
|
and returns the plugin object
|
||||||
|
@ -102,6 +102,9 @@ class _TrainingTypePluginsRegistry(UserDict):
|
||||||
data = self[name]
|
data = self[name]
|
||||||
return data["plugin"](**data["init_params"])
|
return data["plugin"](**data["init_params"])
|
||||||
|
|
||||||
|
if default is not None:
|
||||||
|
return default
|
||||||
|
|
||||||
err_msg = "'{}' not found in registry. Available names: {}"
|
err_msg = "'{}' not found in registry. Available names: {}"
|
||||||
available_names = ", ".join(sorted(self.keys())) or "none"
|
available_names = ", ".join(sorted(self.keys())) or "none"
|
||||||
raise KeyError(err_msg.format(name, available_names))
|
raise KeyError(err_msg.format(name, available_names))
|
||||||
|
|
|
@ -544,7 +544,7 @@ class DeepSpeedPlugin(DDPPlugin):
|
||||||
return current_global_step
|
return current_global_step
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_plugins(cls, plugin_registry):
|
def register_plugins(cls, plugin_registry: Dict) -> None:
|
||||||
plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin")
|
plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin")
|
||||||
plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2)
|
plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2)
|
||||||
plugin_registry.register(
|
plugin_registry.register(
|
||||||
|
|
Loading…
Reference in New Issue