Pyfetch abort on cancel (#4846)

Co-authored-by: Hood Chatham <roberthoodchatham@gmail.com>
This commit is contained in:
Muspi Merol 2024-06-14 20:59:51 +08:00 committed by GitHub
parent eb1d208638
commit f1acd57a54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 265 additions and 25 deletions

View File

@ -18,6 +18,9 @@ myst:
- {{ Fix }} Don't leak the values in a dictionary when applying `to_js` to it.
{pr}`4853`
- {{ Enhancement }} Added implementation to abort `pyfetch` and `FetchResponse`
manually or automatically.
{pr}`4846`
### Packages

View File

@ -22,4 +22,4 @@ pydantic
micropip==0.2.2
jinja2>=3.0
ruamel.yaml
sphinx-js @ git+https://github.com/pyodide/sphinx-js-fork@ebe90e74c8b7c1811d2c5cf41bc5094d136b98cc
sphinx-js @ git+https://github.com/pyodide/sphinx-js-fork@14958086d51939ae4078751abec004e1f3fea1fe

View File

@ -62,7 +62,7 @@ xbuildenv = "pyodide_build.cli.xbuildenv:app"
test = [
# (FIXME: 2024/01/28) The latest pytest-asyncio 0.23.3 is not compatible with pytest 8.0.0
"pytest<8.0.0",
"pytest-pyodide==0.57.0",
"pytest-pyodide==0.58.1",
"pytest-httpserver",
"packaging",
]

View File

@ -12,4 +12,4 @@ pytest-asyncio
pytest-cov
pytest-httpserver
pytest-benchmark
pytest-pyodide==0.57.0
pytest-pyodide==0.58.1

76
src/js/abortSignalAny.ts Normal file
View File

@ -0,0 +1,76 @@
/**
* Polyfill for the static method `AbortSignal.any` which is not yet implemented
* in all browsers. This function creates a new `AbortSignal` that is aborted
* when any of the provided signals are aborted, which is used in `pyfetch`.
*
* @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal/any_static#browser_compatibility
*
* deno: 1.39 (Released 2023-12)
* nodejs: 20.3.0 (Released 2023-06)
* chrome: 100 (Released 2023-08)
* safari: 17.4 (Released 2024-03)
* firefox: 124 (Released 2024-03)
*
* We may consider dropping this polyfill after EOL of Node.js 18 (April 2025).
*/
// @ts-ignore
if (AbortSignal.any) {
// @ts-ignore
API.abortSignalAny = AbortSignal.any;
} else {
const registry = new FinalizationRegistry(
(callback: () => any) => void callback(),
);
interface _AbortSignal extends AbortSignal {
/** @private */
__controller?: AbortController;
}
API.abortSignalAny = function (signals: AbortSignal[]) {
const controller = new AbortController();
for (const signal of signals) {
if (signal.aborted) {
controller.abort(signal.reason);
return controller.signal;
}
}
const controllerRef = new WeakRef(controller);
const eventListenerPairs: [WeakRef<AbortSignal>, () => void][] = [];
let followingCount = signals.length;
signals.forEach((signal) => {
const signalRef = new WeakRef(signal);
function abort() {
controllerRef.deref()?.abort(signalRef.deref()?.reason);
}
signal.addEventListener("abort", abort);
eventListenerPairs.push([signalRef, abort]);
registry.register(signal, () => !--followingCount && clear(), signal);
});
function clear() {
eventListenerPairs.forEach(([signalRef, abort]) => {
const signal = signalRef.deref();
if (signal) {
signal.removeEventListener("abort", abort);
registry.unregister(signal);
}
const controller = controllerRef.deref();
if (controller) {
registry.unregister(controller.signal);
delete (controller.signal as _AbortSignal).__controller;
}
});
}
const { signal }: { signal: _AbortSignal } = controller;
registry.register(signal, clear, signal);
signal.addEventListener("abort", clear);
signal.__controller = controller; // keep a strong reference
return signal;
};
}

View File

@ -10,6 +10,7 @@ import { scheduleCallback } from "./scheduler";
import { TypedArray } from "./types";
import { IN_NODE, detectEnvironment } from "./environments";
import "./literal-map.js";
import "./abortSignalAny";
import {
makeGlobalsProxy,
SnapshotConfig,

View File

@ -430,5 +430,6 @@ export interface API {
saveSnapshot(): Uint8Array;
finalizeBootstrap: (fromSnapshot?: SnapshotConfig) => PyodideInterface;
syncUpSnapshotLoad3(conf: SnapshotConfig): void;
abortSignalAny: (signals: AbortSignal[]) => AbortSignal;
version: string;
}

View File

@ -123,3 +123,19 @@ class Map:
def new(a: Iterable[Any]) -> Map: ...
async def sleep(ms: int | float) -> None: ...
class AbortSignal(_JsObject):
@staticmethod
def any(iterable: Iterable[AbortSignal]) -> AbortSignal: ...
@staticmethod
def timeout(ms: int) -> AbortSignal: ...
aborted: bool
reason: JsException
def throwIfAborted(self): ...
def onabort(self): ...
class AbortController(_JsObject):
@staticmethod
def new() -> AbortController: ...
signal: AbortSignal
def abort(self, reason: JsException | None = None) -> None: ...

View File

@ -1,15 +1,18 @@
import json
from asyncio import CancelledError
from collections.abc import Awaitable, Callable
from functools import wraps
from io import StringIO
from typing import IO, Any
from typing import IO, Any, ParamSpec, TypeVar
from ._package_loader import unpack_buffer
from .ffi import IN_BROWSER, JsBuffer, JsException, JsFetchResponse, to_js
if IN_BROWSER:
from js import Object
try:
from js import AbortController, AbortSignal, Object
from js import fetch as _jsfetch
from pyodide_js._api import abortSignalAny
except ImportError:
pass
try:
@ -24,6 +27,24 @@ __all__ = [
]
class HttpStatusError(OSError):
def __init__(self, status: int, status_text: str, url: str) -> None:
if 400 <= status < 500:
super().__init__(f"{status} Client Error: {status_text} for url: {url}")
elif 500 <= status < 600:
super().__init__(f"{status} Server Error: {status_text} for url: {url}")
class BodyUsedError(OSError):
def __init__(self, *args: Any) -> None:
super().__init__("Response body is already used")
class AbortError(OSError):
def __init__(self, reason: JsException) -> None:
super().__init__(reason.message)
def open_url(url: str) -> StringIO:
"""Fetches a given URL synchronously.
@ -59,6 +80,37 @@ def open_url(url: str) -> StringIO:
return StringIO(req.response)
P = ParamSpec("P")
T = TypeVar("T")
def _abort_on_cancel(method: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
@wraps(method)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
return await method(*args, **kwargs)
except JsException as e:
self: FetchResponse = kwargs.get("self") or args[0] # type:ignore[assignment]
raise AbortError(s.reason if (s := self.abort_signal) else e) from None
except CancelledError as e:
self: FetchResponse = kwargs.get("self") or args[0] # type:ignore[no-redef]
if self.abort_controller:
self.abort_controller.abort(
_construct_abort_reason(
"\n".join(map(str, e.args)) if e.args else None
)
)
raise
return wrapper
def _construct_abort_reason(reason: Any) -> JsException | None:
if reason is None:
return None
return JsException("AbortError", reason)
class FetchResponse:
"""A wrapper for a Javascript fetch :js:data:`Response`.
@ -70,9 +122,17 @@ class FetchResponse:
A :py:class:`~pyodide.ffi.JsProxy` of the fetch response
"""
def __init__(self, url: str, js_response: JsFetchResponse):
def __init__(
self,
url: str,
js_response: JsFetchResponse,
abort_controller: "AbortController | None" = None,
abort_signal: "AbortSignal | None" = None,
):
self._url = url
self.js_response = js_response
self.abort_controller = abort_controller
self.abort_signal = abort_signal
@property
def body_used(self) -> bool:
@ -139,24 +199,15 @@ class FetchResponse:
return self.js_response.url
def _raise_if_failed(self) -> None:
if (signal := self.abort_signal) and signal.aborted:
raise AbortError(signal.reason)
if self.js_response.bodyUsed:
raise OSError("Response body is already used")
raise BodyUsedError
def raise_for_status(self) -> None:
"""Raise an :py:exc:`OSError` if the status of the response is an error (4xx or 5xx)"""
http_error_msg = ""
if 400 <= self.status < 500:
http_error_msg = (
f"{self.status} Client Error: {self.status_text} for url: {self.url}"
)
if 500 <= self.status < 600:
http_error_msg = (
f"{self.status} Server Error: {self.status_text} for url: {self.url}"
)
if http_error_msg:
raise OSError(http_error_msg)
if 400 <= self.status < 600:
raise HttpStatusError(self.status, self.status_text, self.url)
def clone(self) -> "FetchResponse":
"""Return an identical copy of the :py:class:`FetchResponse`.
@ -165,9 +216,15 @@ class FetchResponse:
objects. See :js:meth:`Response.clone`.
"""
if self.js_response.bodyUsed:
raise OSError("Response body is already used")
return FetchResponse(self._url, self.js_response.clone())
raise BodyUsedError
return FetchResponse(
self._url,
self.js_response.clone(),
self.abort_controller,
self.abort_signal,
)
@_abort_on_cancel
async def buffer(self) -> JsBuffer:
"""Return the response body as a Javascript :js:class:`ArrayBuffer`.
@ -176,11 +233,13 @@ class FetchResponse:
self._raise_if_failed()
return await self.js_response.arrayBuffer()
@_abort_on_cancel
async def text(self) -> str:
"""Return the response body as a string"""
self._raise_if_failed()
return await self.js_response.text()
@_abort_on_cancel
async def string(self) -> str:
"""Return the response body as a string
@ -193,6 +252,7 @@ class FetchResponse:
"""
return await self.text()
@_abort_on_cancel
async def json(self, **kwargs: Any) -> Any:
"""Treat the response body as a JSON string and use
:py:func:`json.loads` to parse it into a Python object.
@ -202,11 +262,13 @@ class FetchResponse:
self._raise_if_failed()
return json.loads(await self.string(), **kwargs)
@_abort_on_cancel
async def memoryview(self) -> memoryview:
"""Return the response body as a :py:class:`memoryview` object"""
self._raise_if_failed()
return (await self.buffer()).to_memoryview()
@_abort_on_cancel
async def _into_file(self, f: IO[bytes] | IO[str]) -> None:
"""Write the data into an empty file with no copy.
@ -216,6 +278,7 @@ class FetchResponse:
buf = await self.buffer()
buf._into_file(f)
@_abort_on_cancel
async def _create_file(self, path: str) -> None:
"""Uses the data to back a new file without copying it.
@ -238,11 +301,13 @@ class FetchResponse:
with open(path, "x") as f:
await self._into_file(f)
@_abort_on_cancel
async def bytes(self) -> bytes:
"""Return the response body as a bytes object"""
self._raise_if_failed()
return (await self.buffer()).to_bytes()
@_abort_on_cancel
async def unpack_archive(
self, *, extract_dir: str | None = None, format: str | None = None
) -> None:
@ -270,6 +335,16 @@ class FetchResponse:
filename = self._url.rsplit("/", -1)[-1]
unpack_buffer(buf, filename=filename, format=format, extract_dir=extract_dir)
def abort(self, reason: Any = None) -> None:
"""Abort the fetch request.
In case ``abort_controller`` is not set, a :py:exc:`ValueError` is raised.
"""
if self.abort_controller is None:
raise ValueError("abort_controller is not set")
self.abort_controller.abort(_construct_abort_reason(reason))
async def pyfetch(url: str, **kwargs: Any) -> FetchResponse:
r"""Fetch the url and return the response.
@ -301,9 +376,24 @@ async def pyfetch(url: str, **kwargs: Any) -> FetchResponse:
{'info': {'arch': 'wasm32', 'platform': 'emscripten_3_1_32',
'version': '0.23.4', 'python': '3.11.2'}, ... # long output truncated
"""
controller = AbortController.new()
if "signal" in kwargs:
signal = abortSignalAny(to_js([kwargs["signal"], controller.signal]))
else:
signal = controller.signal
kwargs["signal"] = signal
try:
return FetchResponse(
url, await _jsfetch(url, to_js(kwargs, dict_converter=Object.fromEntries))
url,
await _jsfetch(url, to_js(kwargs, dict_converter=Object.fromEntries)),
controller,
signal,
)
except CancelledError as e:
controller.abort(
_construct_abort_reason("\n".join(map(str, e.args))) if e.args else None
)
raise
except JsException as e:
raise OSError(e.message) from None
raise AbortError(e) from None

View File

@ -196,3 +196,56 @@ def test_pyfetch_cors_error(selenium, httpserver):
data = await pyodide.http.pyfetch('{request_url}')
"""
)
@run_in_pyodide
async def test_pyfetch_manually_abort(selenium):
import pytest
from pyodide.http import AbortError, pyfetch
resp = await pyfetch("/")
resp.abort("reason")
with pytest.raises(AbortError, match="reason"):
await resp.text()
@run_in_pyodide
async def test_pyfetch_abort_on_cancel(selenium):
from asyncio import CancelledError, ensure_future
import pytest
from pyodide.http import pyfetch
future = ensure_future(pyfetch("/"))
future.cancel()
with pytest.raises(CancelledError):
await future
@run_in_pyodide
async def test_pyfetch_abort_cloned_response(selenium):
import pytest
from pyodide.http import AbortError, pyfetch
resp = await pyfetch("/")
clone = resp.clone()
clone.abort()
with pytest.raises(AbortError):
await clone.text()
@run_in_pyodide
async def test_pyfetch_custom_abort_signal(selenium):
import pytest
from js import AbortController
from pyodide.http import AbortError, pyfetch
controller = AbortController.new()
controller.abort()
f = pyfetch("/", signal=controller.signal)
with pytest.raises(AbortError):
await f