starlette/tests/test__utils.py

97 lines
2.5 KiB
Python

import functools
from typing import Any
import pytest
from starlette._utils import get_route_path, is_async_callable
from starlette.types import Scope
def test_async_func() -> None:
    """A plain ``async def`` function is async-callable; a plain ``def`` is not."""

    async def coroutine_function() -> None: ...  # pragma: no cover

    def regular_function() -> None: ...  # pragma: no cover

    assert is_async_callable(coroutine_function)
    assert not is_async_callable(regular_function)
def test_async_partial() -> None:
    """A ``functools.partial`` reports the async-ness of the function it wraps."""

    async def coroutine_function(a: Any, b: Any) -> None: ...  # pragma: no cover

    def regular_function(a: Any, b: Any) -> None: ...  # pragma: no cover

    # Separate names (rather than rebinding one variable) keep each partial's
    # inferred type distinct, so no type-checker suppression is needed.
    async_partial = functools.partial(coroutine_function, 1)
    sync_partial = functools.partial(regular_function, 1)

    assert is_async_callable(async_partial)
    assert not is_async_callable(sync_partial)
def test_async_method() -> None:
    """A bound method is async-callable exactly when it is declared ``async def``."""

    class WithAsyncMethod:
        async def method(self) -> None: ...  # pragma: no cover

    class WithSyncMethod:
        def method(self) -> None: ...  # pragma: no cover

    bound_async = WithAsyncMethod().method
    bound_sync = WithSyncMethod().method

    assert is_async_callable(bound_async)
    assert not is_async_callable(bound_sync)
def test_async_object_call() -> None:
    """A callable instance is async-callable exactly when its ``__call__`` is ``async def``."""

    class AsyncCallable:
        async def __call__(self) -> None: ...  # pragma: no cover

    class SyncCallable:
        def __call__(self) -> None: ...  # pragma: no cover

    async_instance = AsyncCallable()
    sync_instance = SyncCallable()

    assert is_async_callable(async_instance)
    assert not is_async_callable(sync_instance)
def test_async_partial_object_call() -> None:
    """A partial over a callable instance reports the async-ness of its ``__call__``."""

    class AsyncCallable:
        async def __call__(
            self,
            a: Any,
            b: Any,
        ) -> None: ...  # pragma: no cover

    class SyncCallable:
        def __call__(
            self,
            a: Any,
            b: Any,
        ) -> None: ...  # pragma: no cover

    # Distinct names instead of rebinding one variable, so the two partials'
    # differing inferred types do not need a type-checker suppression.
    async_partial = functools.partial(AsyncCallable(), 1)
    sync_partial = functools.partial(SyncCallable(), 1)

    assert is_async_callable(async_partial)
    assert not is_async_callable(sync_partial)
def test_async_nested_partial() -> None:
    """A partial of a partial reports the async-ness of the innermost function.

    Both directions are checked, matching the positive/negative pairing used by
    the other ``is_async_callable`` tests: a nested partial over an ``async def``
    is async-callable, and one over a plain ``def`` is not.
    """

    async def async_func(
        a: Any,
        b: Any,
    ) -> None: ...  # pragma: no cover

    def func(
        a: Any,
        b: Any,
    ) -> None: ...  # pragma: no cover

    partial = functools.partial(async_func, b=2)
    nested_partial = functools.partial(partial, a=1)
    assert is_async_callable(nested_partial)

    # Negative case: guards against nested partials being reported as async
    # regardless of what they ultimately wrap.
    sync_partial = functools.partial(func, b=2)
    sync_nested_partial = functools.partial(sync_partial, a=1)
    assert not is_async_callable(sync_nested_partial)
@pytest.mark.parametrize(
    "scope, expected_result",
    [
        # root_path is stripped only when it is a whole leading path segment.
        ({"path": "/foo-123/bar", "root_path": "/foo"}, "/foo-123/bar"),
        ({"path": "/foo/bar", "root_path": "/foo"}, "/bar"),
        ({"path": "/foo", "root_path": "/foo"}, ""),
        # A root_path that is not a prefix of path leaves path untouched.
        ({"path": "/foo/bar", "root_path": "/bar"}, "/foo/bar"),
        # A missing or empty root_path leaves path untouched.
        ({"path": "/foo/bar"}, "/foo/bar"),
        ({"path": "/foo/bar", "root_path": ""}, "/foo/bar"),
        # A trailing slash directly after the prefix is kept as the route root.
        ({"path": "/foo/", "root_path": "/foo"}, "/"),
    ],
)
def test_get_route_path(scope: Scope, expected_result: str) -> None:
    """``get_route_path`` returns ``scope["path"]`` relative to ``scope["root_path"]``."""
    assert get_route_path(scope) == expected_result