From 6a7b4cf5d349aa64938149f2d2629cb3fdd364af Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Sat, 17 Apr 2021 01:33:41 +0530 Subject: [PATCH] Fix mypy for plugins registry (#7062) --- pytorch_lightning/plugins/plugins_registry.py | 9 ++++++--- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 59dd7d8db6..755c4f5be1 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -16,7 +16,7 @@ import os from collections import UserDict from inspect import getmembers, isclass from pathlib import Path -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -75,7 +75,7 @@ class _TrainingTypePluginsRegistry(UserDict): " HINT: Use `override=True`." ) - data = {} + data: Dict[str, Any] = {} data["description"] = description if description is not None else "" data["init_params"] = init_params @@ -90,7 +90,7 @@ class _TrainingTypePluginsRegistry(UserDict): return do_register - def get(self, name: str) -> Any: + def get(self, name: str, default: Optional[Any] = None) -> Any: """ Calls the registered plugin with the required parameters and returns the plugin object @@ -102,6 +102,9 @@ class _TrainingTypePluginsRegistry(UserDict): data = self[name] return data["plugin"](**data["init_params"]) + if default is not None: + return default + err_msg = "'{}' not found in registry. Available names: {}" available_names = ", ".join(sorted(self.keys())) or "none" raise KeyError(err_msg.format(name, available_names)) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 9c67a1ccb5..34a9f50408 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -544,7 +544,7 @@ class DeepSpeedPlugin(DDPPlugin): return current_global_step @classmethod - def register_plugins(cls, plugin_registry): + def register_plugins(cls, plugin_registry: Dict) -> None: plugin_registry.register("deepspeed", cls, description="Default DeepSpeed Plugin") plugin_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2) plugin_registry.register(