mirror of https://github.com/encode/starlette.git
Add Environ, DatabaseURL
This commit is contained in:
parent
3d797e24a7
commit
d2512656e3
|
@ -1,11 +1,51 @@
|
|||
import os
|
||||
import typing
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
|
||||
class undefined:
|
||||
pass
|
||||
|
||||
|
||||
class EnvironError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Environ(MutableMapping):
|
||||
def __init__(self, environ: typing.MutableMapping = os.environ):
|
||||
self._environ = environ
|
||||
self._has_been_read = set() # type: typing.Set[typing.Any]
|
||||
|
||||
def __getitem__(self, key: typing.Any) -> typing.Any:
|
||||
self._has_been_read.add(key)
|
||||
return self._environ.__getitem__(key)
|
||||
|
||||
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(
|
||||
"Attempting to set environ['%s'], but the value has already be read."
|
||||
% key
|
||||
)
|
||||
self._environ.__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key: typing.Any) -> None:
|
||||
if key in self._has_been_read:
|
||||
raise EnvironError(
|
||||
"Attempting to delete environ['%s'], but the value has already be read."
|
||||
% key
|
||||
)
|
||||
self._environ.__delitem__(key)
|
||||
|
||||
def __iter__(self) -> typing.Iterator:
|
||||
return iter(self._environ)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._environ)
|
||||
|
||||
|
||||
environ = Environ()
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(
|
||||
self, env_file: str = None, environ: typing.Mapping[str, str] = os.environ
|
||||
|
|
|
@ -12,12 +12,12 @@ from starlette.database.core import (
|
|||
DatabaseTransaction,
|
||||
compile,
|
||||
)
|
||||
from starlette.datastructures import URL
|
||||
from starlette.datastructures import DatabaseURL
|
||||
|
||||
|
||||
class PostgresBackend(DatabaseBackend):
|
||||
def __init__(self, database_url: typing.Union[str, URL]) -> None:
|
||||
self.database_url = str(database_url)
|
||||
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
|
||||
self.database_url = database_url
|
||||
self.dialect = self.get_dialect()
|
||||
self.pool = None
|
||||
|
||||
|
@ -34,7 +34,7 @@ class PostgresBackend(DatabaseBackend):
|
|||
return dialect
|
||||
|
||||
async def startup(self) -> None:
|
||||
self.pool = await asyncpg.create_pool(self.database_url)
|
||||
self.pool = await asyncpg.create_pool(str(self.database_url))
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
assert self.pool is not None, "DatabaseBackend is not running"
|
||||
|
|
|
@ -116,7 +116,7 @@ class URL:
|
|||
kwargs["netloc"] = netloc
|
||||
|
||||
components = self.components._replace(**kwargs)
|
||||
return URL(components.geturl())
|
||||
return self.__class__(components.geturl())
|
||||
|
||||
def __eq__(self, other: typing.Any) -> bool:
|
||||
return str(self) == str(other)
|
||||
|
@ -131,6 +131,17 @@ class URL:
|
|||
return "%s(%s)" % (self.__class__.__name__, repr(url))
|
||||
|
||||
|
||||
class DatabaseURL(URL):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.path.lstrip("/")
|
||||
|
||||
def replace(self, **kwargs: typing.Any) -> "URL":
|
||||
if "name" in kwargs:
|
||||
kwargs["path"] = "/" + kwargs.pop("name")
|
||||
return super().replace(**kwargs)
|
||||
|
||||
|
||||
class URLPath(str):
|
||||
"""
|
||||
A URL path string that also holds an associated protocol.
|
||||
|
|
|
@ -7,7 +7,7 @@ from starlette.database.core import (
|
|||
DatabaseSession,
|
||||
DatabaseTransaction,
|
||||
)
|
||||
from starlette.datastructures import URL
|
||||
from starlette.datastructures import DatabaseURL
|
||||
from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send
|
||||
|
||||
|
||||
|
@ -15,7 +15,7 @@ class DatabaseMiddleware:
|
|||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
database_url: typing.Union[str, URL],
|
||||
database_url: typing.Union[str, DatabaseURL],
|
||||
rollback_on_shutdown: bool,
|
||||
) -> None:
|
||||
self.app = app
|
||||
|
@ -24,9 +24,11 @@ class DatabaseMiddleware:
|
|||
self.session = None # type: typing.Optional[DatabaseSession]
|
||||
self.transaction = None # type: typing.Optional[DatabaseTransaction]
|
||||
|
||||
def get_backend(self, database_url: typing.Union[str, URL]) -> DatabaseBackend:
|
||||
def get_backend(
|
||||
self, database_url: typing.Union[str, DatabaseURL]
|
||||
) -> DatabaseBackend:
|
||||
if isinstance(database_url, str):
|
||||
database_url = URL(database_url)
|
||||
database_url = DatabaseURL(database_url)
|
||||
assert database_url.scheme == "postgresql"
|
||||
from starlette.database.postgres import PostgresBackend
|
||||
|
||||
|
|
|
@ -2,8 +2,8 @@ import os
|
|||
|
||||
import pytest
|
||||
|
||||
from starlette.config import Config
|
||||
from starlette.datastructures import URL
|
||||
from starlette.config import Config, Environ, EnvironError
|
||||
from starlette.datastructures import DatabaseURL
|
||||
|
||||
|
||||
def test_config(tmpdir):
|
||||
|
@ -20,16 +20,12 @@ def test_config(tmpdir):
|
|||
config = Config(path, environ={"DEBUG": "true"})
|
||||
|
||||
DEBUG = config.get("DEBUG", cast=bool)
|
||||
DATABASE_URL = config.get("DATABASE_URL", cast=URL)
|
||||
DATABASE_URL = config.get("DATABASE_URL", cast=DatabaseURL)
|
||||
REQUEST_TIMEOUT = config.get("REQUEST_TIMEOUT", cast=int, default=10)
|
||||
REQUEST_HOSTNAME = config.get("REQUEST_HOSTNAME")
|
||||
|
||||
assert DEBUG is True
|
||||
assert str(DATABASE_URL) == "postgres://username:password@localhost/database_name"
|
||||
assert (
|
||||
repr(DATABASE_URL)
|
||||
== "URL('postgres://username:********@localhost/database_name')"
|
||||
)
|
||||
assert DATABASE_URL.name == "database_name"
|
||||
assert REQUEST_TIMEOUT == 10
|
||||
assert REQUEST_HOSTNAME == "example.com"
|
||||
|
||||
|
@ -45,3 +41,28 @@ def test_config(tmpdir):
|
|||
os.environ["STARLETTE_EXAMPLE_TEST"] = "123"
|
||||
config = Config()
|
||||
assert config.get("STARLETTE_EXAMPLE_TEST", cast=int) == 123
|
||||
|
||||
|
||||
def test_environ():
|
||||
environ = Environ()
|
||||
|
||||
# We can mutate the environ at this point.
|
||||
environ["TESTING"] = "True"
|
||||
environ["GONE"] = "123"
|
||||
del environ["GONE"]
|
||||
|
||||
# We can read the environ.
|
||||
assert environ["TESTING"] == "True"
|
||||
assert "GONE" not in environ
|
||||
|
||||
# We cannot mutate these keys now that we've read them.
|
||||
with pytest.raises(EnvironError):
|
||||
environ["TESTING"] = "False"
|
||||
|
||||
with pytest.raises(EnvironError):
|
||||
del environ["GONE"]
|
||||
|
||||
# Test coverage of abstract methods for MutableMapping.
|
||||
environ = Environ()
|
||||
assert list(iter(environ)) == list(iter(os.environ))
|
||||
assert len(environ) == len(os.environ)
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
from starlette.datastructures import URL, Headers, MutableHeaders, QueryParams
|
||||
from starlette.datastructures import (
|
||||
URL,
|
||||
DatabaseURL,
|
||||
Headers,
|
||||
MutableHeaders,
|
||||
QueryParams,
|
||||
)
|
||||
|
||||
|
||||
def test_url():
|
||||
|
@ -38,6 +44,13 @@ def test_hidden_password():
|
|||
assert repr(u) == "URL('https://username:********@example.org/path/to/somewhere')"
|
||||
|
||||
|
||||
def test_database_url():
|
||||
u = DatabaseURL("postgresql://username:password@localhost/mydatabase")
|
||||
u = u.replace(name="test_" + u.name)
|
||||
assert u.name == "test_mydatabase"
|
||||
assert str(u) == "postgresql://username:password@localhost/test_mydatabase"
|
||||
|
||||
|
||||
def test_url_from_scope():
|
||||
u = URL(
|
||||
scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []}
|
||||
|
|
Loading…
Reference in New Issue