Fix mypy for plugins registry (#7062)

This commit is contained in:
Kaushik B 2021-04-17 01:33:41 +05:30 committed by GitHub
parent 3fb8eada34
commit 6a7b4cf5d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 4 deletions

View File

@ -16,7 +16,7 @@ import os
from collections import UserDict
from inspect import getmembers, isclass
from pathlib import Path
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -75,7 +75,7 @@ class _TrainingTypePluginsRegistry(UserDict):
" HINT: Use `override=True`."
)
data = {}
data: Dict[str, Any] = {}
data["description"] = description if description is not None else ""
data["init_params"] = init_params
@ -90,7 +90,7 @@ class _TrainingTypePluginsRegistry(UserDict):
return do_register
def get(self, name: str) -> Any:
def get(self, name: str, default: Optional[Any] = None) -> Any:
"""
Calls the registered plugin with the required parameters
and returns the plugin object
@ -102,6 +102,9 @@ class _TrainingTypePluginsRegistry(UserDict):
data = self[name]
return data["plugin"](**data["init_params"])
if default is not None:
return default
err_msg = "'{}' not found in registry. Available names: {}"
available_names = ", ".join(sorted(self.keys())) or "none"
raise KeyError(err_msg.format(name, available_names))

View File

@ -544,7 +544,7 @@ class DeepSpeedPlugin(DDPPlugin):
return current_global_step
@classmethod
def register_plugins(cls, plugin_registry):
def register_plugins(cls, plugin_registry: Dict) -> None:
plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin")
plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2)
plugin_registry.register(