From f1acd57a54aaeaf92cbb4e3a5d8768c4e25ccc83 Mon Sep 17 00:00:00 2001 From: Muspi Merol Date: Fri, 14 Jun 2024 20:59:51 +0800 Subject: [PATCH] Pyfetch abort on cancel (#4846) Co-authored-by: Hood Chatham --- docs/project/changelog.md | 3 + docs/requirements-doc.txt | 2 +- pyodide-build/pyproject.toml | 2 +- requirements.txt | 2 +- src/js/abortSignalAny.ts | 76 +++++++++++++++++++ src/js/api.ts | 1 + src/js/types.ts | 1 + src/py/js.pyi | 16 ++++ src/py/pyodide/http.py | 134 +++++++++++++++++++++++++++------ src/tests/test_pyodide_http.py | 53 +++++++++++++ 10 files changed, 265 insertions(+), 25 deletions(-) create mode 100644 src/js/abortSignalAny.ts diff --git a/docs/project/changelog.md b/docs/project/changelog.md index 3b081569c..d4f6feea9 100644 --- a/docs/project/changelog.md +++ b/docs/project/changelog.md @@ -18,6 +18,9 @@ myst: - {{ Fix }} Don't leak the values in a dictionary when applying `to_js` to it. {pr}`4853` +- {{ Enhancement }} Added the ability to abort a `pyfetch` request or a `FetchResponse`, + either manually or automatically when the awaiting task is cancelled. 
+ {pr}`4846` ### Packages diff --git a/docs/requirements-doc.txt b/docs/requirements-doc.txt index f2ced5814..b46017e72 100644 --- a/docs/requirements-doc.txt +++ b/docs/requirements-doc.txt @@ -22,4 +22,4 @@ pydantic micropip==0.2.2 jinja2>=3.0 ruamel.yaml -sphinx-js @ git+https://github.com/pyodide/sphinx-js-fork@ebe90e74c8b7c1811d2c5cf41bc5094d136b98cc +sphinx-js @ git+https://github.com/pyodide/sphinx-js-fork@14958086d51939ae4078751abec004e1f3fea1fe diff --git a/pyodide-build/pyproject.toml b/pyodide-build/pyproject.toml index f1dd9dbda..c0e912001 100644 --- a/pyodide-build/pyproject.toml +++ b/pyodide-build/pyproject.toml @@ -62,7 +62,7 @@ xbuildenv = "pyodide_build.cli.xbuildenv:app" test = [ # (FIXME: 2024/01/28) The latest pytest-asyncio 0.23.3 is not compatible with pytest 8.0.0 "pytest<8.0.0", - "pytest-pyodide==0.57.0", + "pytest-pyodide==0.58.1", "pytest-httpserver", "packaging", ] diff --git a/requirements.txt b/requirements.txt index 534e58237..83145f435 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ pytest-asyncio pytest-cov pytest-httpserver pytest-benchmark -pytest-pyodide==0.57.0 +pytest-pyodide==0.58.1 diff --git a/src/js/abortSignalAny.ts b/src/js/abortSignalAny.ts new file mode 100644 index 000000000..46f21f9b8 --- /dev/null +++ b/src/js/abortSignalAny.ts @@ -0,0 +1,76 @@ +/** + * Polyfill for the static method `AbortSignal.any` which is not yet implemented + * in all browsers. This function creates a new `AbortSignal` that is aborted + * when any of the provided signals are aborted, which is used in `pyfetch`. + * + * @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal/any_static#browser_compatibility + * + * deno: 1.39 (Released 2023-12) + * nodejs: 20.3.0 (Released 2023-06) + * chrome: 116 (Released 2023-08) + * safari: 17.4 (Released 2024-03) + * firefox: 124 (Released 2024-03) + * + * We may consider dropping this polyfill after EOL of Node.js 18 (April 2025). 
+ */ + +// @ts-ignore +if (AbortSignal.any) { + // @ts-ignore + API.abortSignalAny = AbortSignal.any; +} else { + const registry = new FinalizationRegistry( + (callback: () => any) => void callback(), + ); + interface _AbortSignal extends AbortSignal { + /** @private */ + __controller?: AbortController; + } + + API.abortSignalAny = function (signals: AbortSignal[]) { + const controller = new AbortController(); + for (const signal of signals) { + if (signal.aborted) { + controller.abort(signal.reason); + return controller.signal; + } + } + const controllerRef = new WeakRef(controller); + const eventListenerPairs: [WeakRef, () => void][] = []; + let followingCount = signals.length; + + signals.forEach((signal) => { + const signalRef = new WeakRef(signal); + function abort() { + controllerRef.deref()?.abort(signalRef.deref()?.reason); + } + signal.addEventListener("abort", abort); + eventListenerPairs.push([signalRef, abort]); + registry.register(signal, () => !--followingCount && clear(), signal); + }); + + function clear() { + eventListenerPairs.forEach(([signalRef, abort]) => { + const signal = signalRef.deref(); + if (signal) { + signal.removeEventListener("abort", abort); + registry.unregister(signal); + } + const controller = controllerRef.deref(); + if (controller) { + registry.unregister(controller.signal); + delete (controller.signal as _AbortSignal).__controller; + } + }); + } + + const { signal }: { signal: _AbortSignal } = controller; + + registry.register(signal, clear, signal); + signal.addEventListener("abort", clear); + + signal.__controller = controller; // keep a strong reference + + return signal; + }; +} diff --git a/src/js/api.ts b/src/js/api.ts index 4c6e2dfdb..0f49dbe5d 100644 --- a/src/js/api.ts +++ b/src/js/api.ts @@ -10,6 +10,7 @@ import { scheduleCallback } from "./scheduler"; import { TypedArray } from "./types"; import { IN_NODE, detectEnvironment } from "./environments"; import "./literal-map.js"; +import "./abortSignalAny"; import { 
makeGlobalsProxy, SnapshotConfig, diff --git a/src/js/types.ts b/src/js/types.ts index ef3c878ab..a4d9df539 100644 --- a/src/js/types.ts +++ b/src/js/types.ts @@ -430,5 +430,6 @@ export interface API { saveSnapshot(): Uint8Array; finalizeBootstrap: (fromSnapshot?: SnapshotConfig) => PyodideInterface; syncUpSnapshotLoad3(conf: SnapshotConfig): void; + abortSignalAny: (signals: AbortSignal[]) => AbortSignal; version: string; } diff --git a/src/py/js.pyi b/src/py/js.pyi index 5882613d1..5c6949670 100644 --- a/src/py/js.pyi +++ b/src/py/js.pyi @@ -123,3 +123,19 @@ class Map: def new(a: Iterable[Any]) -> Map: ... async def sleep(ms: int | float) -> None: ... + +class AbortSignal(_JsObject): + @staticmethod + def any(iterable: Iterable[AbortSignal]) -> AbortSignal: ... + @staticmethod + def timeout(ms: int) -> AbortSignal: ... + aborted: bool + reason: JsException + def throwIfAborted(self): ... + def onabort(self): ... + +class AbortController(_JsObject): + @staticmethod + def new() -> AbortController: ... + signal: AbortSignal + def abort(self, reason: JsException | None = None) -> None: ... 
diff --git a/src/py/pyodide/http.py b/src/py/pyodide/http.py index aa9c31020..0ed1cea53 100644 --- a/src/py/pyodide/http.py +++ b/src/py/pyodide/http.py @@ -1,15 +1,18 @@ import json +from asyncio import CancelledError +from collections.abc import Awaitable, Callable +from functools import wraps from io import StringIO -from typing import IO, Any +from typing import IO, Any, ParamSpec, TypeVar from ._package_loader import unpack_buffer from .ffi import IN_BROWSER, JsBuffer, JsException, JsFetchResponse, to_js if IN_BROWSER: - from js import Object - try: + from js import AbortController, AbortSignal, Object from js import fetch as _jsfetch + from pyodide_js._api import abortSignalAny except ImportError: pass try: @@ -24,6 +27,24 @@ __all__ = [ ] +class HttpStatusError(OSError): + def __init__(self, status: int, status_text: str, url: str) -> None: + if 400 <= status < 500: + super().__init__(f"{status} Client Error: {status_text} for url: {url}") + elif 500 <= status < 600: + super().__init__(f"{status} Server Error: {status_text} for url: {url}") + + +class BodyUsedError(OSError): + def __init__(self, *args: Any) -> None: + super().__init__("Response body is already used") + + +class AbortError(OSError): + def __init__(self, reason: JsException) -> None: + super().__init__(reason.message) + + def open_url(url: str) -> StringIO: """Fetches a given URL synchronously. 
@@ -59,6 +80,37 @@ def open_url(url: str) -> StringIO: return StringIO(req.response) +P = ParamSpec("P") +T = TypeVar("T") + + +def _abort_on_cancel(method: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: + @wraps(method) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + try: + return await method(*args, **kwargs) + except JsException as e: + self: FetchResponse = kwargs.get("self") or args[0] # type:ignore[assignment] + raise AbortError(s.reason if (s := self.abort_signal) else e) from None + except CancelledError as e: + self: FetchResponse = kwargs.get("self") or args[0] # type:ignore[no-redef] + if self.abort_controller: + self.abort_controller.abort( + _construct_abort_reason( + "\n".join(map(str, e.args)) if e.args else None + ) + ) + raise + + return wrapper + + +def _construct_abort_reason(reason: Any) -> JsException | None: + if reason is None: + return None + return JsException("AbortError", reason) + + class FetchResponse: """A wrapper for a Javascript fetch :js:data:`Response`. 
@@ -70,9 +122,17 @@ class FetchResponse: A :py:class:`~pyodide.ffi.JsProxy` of the fetch response """ - def __init__(self, url: str, js_response: JsFetchResponse): + def __init__( + self, + url: str, + js_response: JsFetchResponse, + abort_controller: "AbortController | None" = None, + abort_signal: "AbortSignal | None" = None, + ): self._url = url self.js_response = js_response + self.abort_controller = abort_controller + self.abort_signal = abort_signal @property def body_used(self) -> bool: @@ -139,24 +199,15 @@ class FetchResponse: return self.js_response.url def _raise_if_failed(self) -> None: + if (signal := self.abort_signal) and signal.aborted: + raise AbortError(signal.reason) if self.js_response.bodyUsed: - raise OSError("Response body is already used") + raise BodyUsedError def raise_for_status(self) -> None: """Raise an :py:exc:`OSError` if the status of the response is an error (4xx or 5xx)""" - http_error_msg = "" - if 400 <= self.status < 500: - http_error_msg = ( - f"{self.status} Client Error: {self.status_text} for url: {self.url}" - ) - - if 500 <= self.status < 600: - http_error_msg = ( - f"{self.status} Server Error: {self.status_text} for url: {self.url}" - ) - - if http_error_msg: - raise OSError(http_error_msg) + if 400 <= self.status < 600: + raise HttpStatusError(self.status, self.status_text, self.url) def clone(self) -> "FetchResponse": """Return an identical copy of the :py:class:`FetchResponse`. @@ -165,9 +216,15 @@ class FetchResponse: objects. See :js:meth:`Response.clone`. """ if self.js_response.bodyUsed: - raise OSError("Response body is already used") - return FetchResponse(self._url, self.js_response.clone()) + raise BodyUsedError + return FetchResponse( + self._url, + self.js_response.clone(), + self.abort_controller, + self.abort_signal, + ) + @_abort_on_cancel async def buffer(self) -> JsBuffer: """Return the response body as a Javascript :js:class:`ArrayBuffer`. 
@@ -176,11 +233,13 @@ class FetchResponse: self._raise_if_failed() return await self.js_response.arrayBuffer() + @_abort_on_cancel async def text(self) -> str: """Return the response body as a string""" self._raise_if_failed() return await self.js_response.text() + @_abort_on_cancel async def string(self) -> str: """Return the response body as a string @@ -193,6 +252,7 @@ class FetchResponse: """ return await self.text() + @_abort_on_cancel async def json(self, **kwargs: Any) -> Any: """Treat the response body as a JSON string and use :py:func:`json.loads` to parse it into a Python object. @@ -202,11 +262,13 @@ class FetchResponse: self._raise_if_failed() return json.loads(await self.string(), **kwargs) + @_abort_on_cancel async def memoryview(self) -> memoryview: """Return the response body as a :py:class:`memoryview` object""" self._raise_if_failed() return (await self.buffer()).to_memoryview() + @_abort_on_cancel async def _into_file(self, f: IO[bytes] | IO[str]) -> None: """Write the data into an empty file with no copy. @@ -216,6 +278,7 @@ class FetchResponse: buf = await self.buffer() buf._into_file(f) + @_abort_on_cancel async def _create_file(self, path: str) -> None: """Uses the data to back a new file without copying it. @@ -238,11 +301,13 @@ class FetchResponse: with open(path, "x") as f: await self._into_file(f) + @_abort_on_cancel async def bytes(self) -> bytes: """Return the response body as a bytes object""" self._raise_if_failed() return (await self.buffer()).to_bytes() + @_abort_on_cancel async def unpack_archive( self, *, extract_dir: str | None = None, format: str | None = None ) -> None: @@ -270,6 +335,16 @@ class FetchResponse: filename = self._url.rsplit("/", -1)[-1] unpack_buffer(buf, filename=filename, format=format, extract_dir=extract_dir) + def abort(self, reason: Any = None) -> None: + """Abort the fetch request. + + In case ``abort_controller`` is not set, a :py:exc:`ValueError` is raised. 
+ """ + if self.abort_controller is None: + raise ValueError("abort_controller is not set") + + self.abort_controller.abort(_construct_abort_reason(reason)) + async def pyfetch(url: str, **kwargs: Any) -> FetchResponse: r"""Fetch the url and return the response. @@ -301,9 +376,24 @@ async def pyfetch(url: str, **kwargs: Any) -> FetchResponse: {'info': {'arch': 'wasm32', 'platform': 'emscripten_3_1_32', 'version': '0.23.4', 'python': '3.11.2'}, ... # long output truncated """ + + controller = AbortController.new() + if "signal" in kwargs: + signal = abortSignalAny(to_js([kwargs["signal"], controller.signal])) + else: + signal = controller.signal + kwargs["signal"] = signal try: return FetchResponse( - url, await _jsfetch(url, to_js(kwargs, dict_converter=Object.fromEntries)) + url, + await _jsfetch(url, to_js(kwargs, dict_converter=Object.fromEntries)), + controller, + signal, ) + except CancelledError as e: + controller.abort( + _construct_abort_reason("\n".join(map(str, e.args))) if e.args else None + ) + raise except JsException as e: - raise OSError(e.message) from None + raise AbortError(e) from None diff --git a/src/tests/test_pyodide_http.py b/src/tests/test_pyodide_http.py index 99faf2959..5a08c5046 100644 --- a/src/tests/test_pyodide_http.py +++ b/src/tests/test_pyodide_http.py @@ -196,3 +196,56 @@ def test_pyfetch_cors_error(selenium, httpserver): data = await pyodide.http.pyfetch('{request_url}') """ ) + + +@run_in_pyodide +async def test_pyfetch_manually_abort(selenium): + import pytest + + from pyodide.http import AbortError, pyfetch + + resp = await pyfetch("/") + resp.abort("reason") + with pytest.raises(AbortError, match="reason"): + await resp.text() + + +@run_in_pyodide +async def test_pyfetch_abort_on_cancel(selenium): + from asyncio import CancelledError, ensure_future + + import pytest + + from pyodide.http import pyfetch + + future = ensure_future(pyfetch("/")) + future.cancel() + with pytest.raises(CancelledError): + await future + + 
+@run_in_pyodide +async def test_pyfetch_abort_cloned_response(selenium): + import pytest + + from pyodide.http import AbortError, pyfetch + + resp = await pyfetch("/") + clone = resp.clone() + clone.abort() + with pytest.raises(AbortError): + await clone.text() + + +@run_in_pyodide +async def test_pyfetch_custom_abort_signal(selenium): + import pytest + + from js import AbortController + from pyodide.http import AbortError, pyfetch + + controller = AbortController.new() + controller.abort() + f = pyfetch("/", signal=controller.signal) + with pytest.raises(AbortError): + await f