Add Environ, DatabaseURL

This commit is contained in:
Tom Christie 2018-12-05 12:28:18 +00:00
parent 3d797e24a7
commit d2512656e3
6 changed files with 105 additions and 18 deletions

View File

@ -1,11 +1,51 @@
import os
import typing
from collections.abc import MutableMapping
class undefined:
pass
class EnvironError(Exception):
pass
class Environ(MutableMapping):
def __init__(self, environ: typing.MutableMapping = os.environ):
self._environ = environ
self._has_been_read = set() # type: typing.Set[typing.Any]
def __getitem__(self, key: typing.Any) -> typing.Any:
self._has_been_read.add(key)
return self._environ.__getitem__(key)
def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
if key in self._has_been_read:
raise EnvironError(
"Attempting to set environ['%s'], but the value has already be read."
% key
)
self._environ.__setitem__(key, value)
def __delitem__(self, key: typing.Any) -> None:
if key in self._has_been_read:
raise EnvironError(
"Attempting to delete environ['%s'], but the value has already be read."
% key
)
self._environ.__delitem__(key)
def __iter__(self) -> typing.Iterator:
return iter(self._environ)
def __len__(self) -> int:
return len(self._environ)
environ = Environ()
class Config:
def __init__(
self, env_file: str = None, environ: typing.Mapping[str, str] = os.environ

View File

@ -12,12 +12,12 @@ from starlette.database.core import (
DatabaseTransaction,
compile,
)
from starlette.datastructures import URL
from starlette.datastructures import DatabaseURL
class PostgresBackend(DatabaseBackend):
def __init__(self, database_url: typing.Union[str, URL]) -> None:
self.database_url = str(database_url)
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
self.database_url = database_url
self.dialect = self.get_dialect()
self.pool = None
@ -34,7 +34,7 @@ class PostgresBackend(DatabaseBackend):
return dialect
async def startup(self) -> None:
self.pool = await asyncpg.create_pool(self.database_url)
self.pool = await asyncpg.create_pool(str(self.database_url))
async def shutdown(self) -> None:
assert self.pool is not None, "DatabaseBackend is not running"

View File

@ -116,7 +116,7 @@ class URL:
kwargs["netloc"] = netloc
components = self.components._replace(**kwargs)
return URL(components.geturl())
return self.__class__(components.geturl())
def __eq__(self, other: typing.Any) -> bool:
return str(self) == str(other)
@ -131,6 +131,17 @@ class URL:
return "%s(%s)" % (self.__class__.__name__, repr(url))
class DatabaseURL(URL):
@property
def name(self) -> str:
return self.path.lstrip("/")
def replace(self, **kwargs: typing.Any) -> "URL":
if "name" in kwargs:
kwargs["path"] = "/" + kwargs.pop("name")
return super().replace(**kwargs)
class URLPath(str):
"""
A URL path string that also holds an associated protocol.

View File

@ -7,7 +7,7 @@ from starlette.database.core import (
DatabaseSession,
DatabaseTransaction,
)
from starlette.datastructures import URL
from starlette.datastructures import DatabaseURL
from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send
@ -15,7 +15,7 @@ class DatabaseMiddleware:
def __init__(
self,
app: ASGIApp,
database_url: typing.Union[str, URL],
database_url: typing.Union[str, DatabaseURL],
rollback_on_shutdown: bool,
) -> None:
self.app = app
@ -24,9 +24,11 @@ class DatabaseMiddleware:
self.session = None # type: typing.Optional[DatabaseSession]
self.transaction = None # type: typing.Optional[DatabaseTransaction]
def get_backend(self, database_url: typing.Union[str, URL]) -> DatabaseBackend:
def get_backend(
self, database_url: typing.Union[str, DatabaseURL]
) -> DatabaseBackend:
if isinstance(database_url, str):
database_url = URL(database_url)
database_url = DatabaseURL(database_url)
assert database_url.scheme == "postgresql"
from starlette.database.postgres import PostgresBackend

View File

@ -2,8 +2,8 @@ import os
import pytest
from starlette.config import Config
from starlette.datastructures import URL
from starlette.config import Config, Environ, EnvironError
from starlette.datastructures import DatabaseURL
def test_config(tmpdir):
@ -20,16 +20,12 @@ def test_config(tmpdir):
config = Config(path, environ={"DEBUG": "true"})
DEBUG = config.get("DEBUG", cast=bool)
DATABASE_URL = config.get("DATABASE_URL", cast=URL)
DATABASE_URL = config.get("DATABASE_URL", cast=DatabaseURL)
REQUEST_TIMEOUT = config.get("REQUEST_TIMEOUT", cast=int, default=10)
REQUEST_HOSTNAME = config.get("REQUEST_HOSTNAME")
assert DEBUG is True
assert str(DATABASE_URL) == "postgres://username:password@localhost/database_name"
assert (
repr(DATABASE_URL)
== "URL('postgres://username:********@localhost/database_name')"
)
assert DATABASE_URL.name == "database_name"
assert REQUEST_TIMEOUT == 10
assert REQUEST_HOSTNAME == "example.com"
@ -45,3 +41,28 @@ def test_config(tmpdir):
os.environ["STARLETTE_EXAMPLE_TEST"] = "123"
config = Config()
assert config.get("STARLETTE_EXAMPLE_TEST", cast=int) == 123
def test_environ():
environ = Environ()
# We can mutate the environ at this point.
environ["TESTING"] = "True"
environ["GONE"] = "123"
del environ["GONE"]
# We can read the environ.
assert environ["TESTING"] == "True"
assert "GONE" not in environ
# We cannot mutate these keys now that we've read them.
with pytest.raises(EnvironError):
environ["TESTING"] = "False"
with pytest.raises(EnvironError):
del environ["GONE"]
# Test coverage of abstract methods for MutableMapping.
environ = Environ()
assert list(iter(environ)) == list(iter(os.environ))
assert len(environ) == len(os.environ)

View File

@ -1,4 +1,10 @@
from starlette.datastructures import URL, Headers, MutableHeaders, QueryParams
from starlette.datastructures import (
URL,
DatabaseURL,
Headers,
MutableHeaders,
QueryParams,
)
def test_url():
@ -38,6 +44,13 @@ def test_hidden_password():
assert repr(u) == "URL('https://username:********@example.org/path/to/somewhere')"
def test_database_url():
u = DatabaseURL("postgresql://username:password@localhost/mydatabase")
u = u.replace(name="test_" + u.name)
assert u.name == "test_mydatabase"
assert str(u) == "postgresql://username:password@localhost/test_mydatabase"
def test_url_from_scope():
u = URL(
scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []}