diff --git a/pyproject.toml b/pyproject.toml index f9b4cc905a..660a38fb42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,7 +136,6 @@ module = [ "lightning.app.components.training", "lightning.app.core.api", "lightning.app.core.app", - "lightning.app.core.queues", "lightning.app.frontend.panel.app_state_comm", "lightning.app.frontend.panel.app_state_watcher", "lightning.app.frontend.panel.panel_frontend", diff --git a/src/lightning/app/core/queues.py b/src/lightning/app/core/queues.py index d99b9ee5cc..d9f0d510d1 100644 --- a/src/lightning/app/core/queues.py +++ b/src/lightning/app/core/queues.py @@ -20,7 +20,7 @@ import warnings from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, Tuple from urllib.parse import urljoin import requests @@ -170,11 +170,11 @@ class BaseQueue(ABC): self.default_timeout = default_timeout @abstractmethod - def put(self, item): + def put(self, item: Any) -> None: pass @abstractmethod - def get(self, timeout: Optional[int] = None): + def get(self, timeout: Optional[int] = None) -> Any: """Returns the left most element of the queue. Parameters @@ -195,16 +195,16 @@ class BaseQueue(ABC): class MultiProcessQueue(BaseQueue): - def __init__(self, name: str, default_timeout: float): + def __init__(self, name: str, default_timeout: float) -> None: self.name = name self.default_timeout = default_timeout context = multiprocessing.get_context("spawn") self.queue = context.Queue() - def put(self, item): + def put(self, item: Any) -> None: self.queue.put(item) - def get(self, timeout: Optional[int] = None): + def get(self, timeout: Optional[float] = None) -> Any: if timeout == 0: timeout = self.default_timeout return self.queue.get(timeout=timeout, block=(timeout is None)) @@ -278,7 +278,7 @@ class RedisQueue(BaseQueue): if is_work: item._backend = backend - def get(self, timeout: Optional[int] = None): + def get(self, timeout: Optional[float] = None) -> Any: """Returns the left most element of the redis queue. Parameters @@ -330,7 +330,7 @@ class RedisQueue(BaseQueue): except redis.exceptions.ConnectionError: return False - def to_dict(self): + def to_dict(self) -> dict: return { "type": "redis", "name": self.name, @@ -341,12 +341,12 @@ class RedisQueue(BaseQueue): } @classmethod - def from_dict(cls, state): + def from_dict(cls, state: dict) -> "RedisQueue": return cls(**state) class HTTPQueue(BaseQueue): - def __init__(self, name: str, default_timeout: float): + def __init__(self, name: str, default_timeout: float) -> None: """ Parameters ---------- @@ -378,6 +378,7 @@ class HTTPQueue(BaseQueue): return True except (ConnectionError, ConnectTimeout, ReadTimeout): return False + return False def get(self, timeout: Optional[int] = None) -> Any: if not self.app_id: @@ -410,7 +411,7 @@ class HTTPQueue(BaseQueue): time.sleep(0.05) pass - def _get(self): + def _get(self) -> Any: try: resp = self.client.post(f"v1/{self.app_id}/{self._name_suffix}", query_params={"action": "pop"}) if resp.status_code == 204: @@ -436,7 +437,7 @@ class HTTPQueue(BaseQueue): if resp.status_code != 201: raise RuntimeError(f"Failed to push to queue: {self._name_suffix}") - def length(self): + def length(self) -> int: if not self.app_id: raise ValueError(f"App ID couldn't be extracted from the queue name: {self.name}") @@ -444,7 +445,7 @@ class HTTPQueue(BaseQueue): return int(val.text) @staticmethod - def _split_app_id_and_queue_name(queue_name): + def _split_app_id_and_queue_name(queue_name: str) -> Tuple[str, str]: """This splits the app id and the queue name into two parts. This can be brittle, as if the queue name creation logic changes, the response values from here wouldn't be @@ -455,7 +456,7 @@ class HTTPQueue(BaseQueue): app_id, queue_name = queue_name.split("_", 1) return app_id, queue_name - def to_dict(self): + def to_dict(self) -> dict: return { "type": "http", "name": self.name, @@ -463,7 +464,7 @@ class HTTPQueue(BaseQueue): } @classmethod - def from_dict(cls, state): + def from_dict(cls, state: dict) -> "HTTPQueue": return cls(**state)