From d2512656e3689651d13b9925d4e29dedcf8f907a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 5 Dec 2018 12:28:18 +0000 Subject: [PATCH] Add Environ, DatabaseURL --- starlette/config.py | 40 ++++++++++++++++++++++++++++++++ starlette/database/postgres.py | 8 +++---- starlette/datastructures.py | 13 ++++++++++- starlette/middleware/database.py | 10 ++++---- tests/test_config.py | 37 ++++++++++++++++++++++------- tests/test_datastructures.py | 15 +++++++++++- 6 files changed, 105 insertions(+), 18 deletions(-) diff --git a/starlette/config.py b/starlette/config.py index 0ffb32d8..09b086bb 100644 --- a/starlette/config.py +++ b/starlette/config.py @@ -1,11 +1,51 @@ import os import typing +from collections.abc import MutableMapping class undefined: pass +class EnvironError(Exception): + pass + + +class Environ(MutableMapping): + def __init__(self, environ: typing.MutableMapping = os.environ): + self._environ = environ + self._has_been_read = set() # type: typing.Set[typing.Any] + + def __getitem__(self, key: typing.Any) -> typing.Any: + self._has_been_read.add(key) + return self._environ.__getitem__(key) + + def __setitem__(self, key: typing.Any, value: typing.Any) -> None: + if key in self._has_been_read: + raise EnvironError( + "Attempting to set environ['%s'], but the value has already be read." + % key + ) + self._environ.__setitem__(key, value) + + def __delitem__(self, key: typing.Any) -> None: + if key in self._has_been_read: + raise EnvironError( + "Attempting to delete environ['%s'], but the value has already be read." + % key + ) + self._environ.__delitem__(key) + + def __iter__(self) -> typing.Iterator: + return iter(self._environ) + + def __len__(self) -> int: + return len(self._environ) + + +environ = Environ() + + class Config: def __init__( self, env_file: str = None, environ: typing.Mapping[str, str] = os.environ diff --git a/starlette/database/postgres.py b/starlette/database/postgres.py index e3da73f8..63b141e8 100644 --- a/starlette/database/postgres.py +++ b/starlette/database/postgres.py @@ -12,12 +12,12 @@ from starlette.database.core import ( DatabaseTransaction, compile, ) -from starlette.datastructures import URL +from starlette.datastructures import DatabaseURL class PostgresBackend(DatabaseBackend): - def __init__(self, database_url: typing.Union[str, URL]) -> None: - self.database_url = str(database_url) + def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None: + self.database_url = database_url self.dialect = self.get_dialect() self.pool = None @@ -34,7 +34,7 @@ class PostgresBackend(DatabaseBackend): return dialect async def startup(self) -> None: - self.pool = await asyncpg.create_pool(self.database_url) + self.pool = await asyncpg.create_pool(str(self.database_url)) async def shutdown(self) -> None: assert self.pool is not None, "DatabaseBackend is not running" diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 4ef81def..2b57281e 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -116,7 +116,7 @@ class URL: kwargs["netloc"] = netloc components = self.components._replace(**kwargs) - return URL(components.geturl()) + return self.__class__(components.geturl()) def __eq__(self, other: typing.Any) -> bool: return str(self) == str(other) @@ -131,6 +131,17 @@ class URL: return "%s(%s)" % (self.__class__.__name__, repr(url)) +class DatabaseURL(URL): + @property + def name(self) -> str: + return self.path.lstrip("/") + + def replace(self, **kwargs: typing.Any) -> "URL": + if "name" in kwargs: + kwargs["path"] = "/" + kwargs.pop("name") + return super().replace(**kwargs) + + class URLPath(str): """ A URL path string that also holds an associated protocol. diff --git a/starlette/middleware/database.py b/starlette/middleware/database.py index 4f0c3de4..fc15f2df 100644 --- a/starlette/middleware/database.py +++ b/starlette/middleware/database.py @@ -7,7 +7,7 @@ from starlette.database.core import ( DatabaseSession, DatabaseTransaction, ) -from starlette.datastructures import URL +from starlette.datastructures import DatabaseURL from starlette.types import ASGIApp, ASGIInstance, Message, Receive, Scope, Send @@ -15,7 +15,7 @@ class DatabaseMiddleware: def __init__( self, app: ASGIApp, - database_url: typing.Union[str, URL], + database_url: typing.Union[str, DatabaseURL], rollback_on_shutdown: bool, ) -> None: self.app = app @@ -24,9 +24,11 @@ class DatabaseMiddleware: self.session = None # type: typing.Optional[DatabaseSession] self.transaction = None # type: typing.Optional[DatabaseTransaction] - def get_backend(self, database_url: typing.Union[str, URL]) -> DatabaseBackend: + def get_backend( + self, database_url: typing.Union[str, DatabaseURL] + ) -> DatabaseBackend: if isinstance(database_url, str): - database_url = URL(database_url) + database_url = DatabaseURL(database_url) assert database_url.scheme == "postgresql" from starlette.database.postgres import PostgresBackend diff --git a/tests/test_config.py b/tests/test_config.py index d3686ebd..1b46a1d9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,8 +2,8 @@ import os import pytest -from starlette.config import Config -from starlette.datastructures import URL +from starlette.config import Config, Environ, EnvironError +from starlette.datastructures import DatabaseURL def test_config(tmpdir): @@ -20,16 +20,12 @@ def test_config(tmpdir): config = Config(path, environ={"DEBUG": "true"}) DEBUG = config.get("DEBUG", cast=bool) - DATABASE_URL = config.get("DATABASE_URL", cast=URL) + DATABASE_URL = config.get("DATABASE_URL", cast=DatabaseURL) REQUEST_TIMEOUT = config.get("REQUEST_TIMEOUT", cast=int, default=10) REQUEST_HOSTNAME = config.get("REQUEST_HOSTNAME") assert DEBUG is True - assert str(DATABASE_URL) == "postgres://username:password@localhost/database_name" - assert ( - repr(DATABASE_URL) - == "URL('postgres://username:********@localhost/database_name')" - ) + assert DATABASE_URL.name == "database_name" assert REQUEST_TIMEOUT == 10 assert REQUEST_HOSTNAME == "example.com" @@ -45,3 +41,28 @@ def test_config(tmpdir): os.environ["STARLETTE_EXAMPLE_TEST"] = "123" config = Config() assert config.get("STARLETTE_EXAMPLE_TEST", cast=int) == 123 + + +def test_environ(): + environ = Environ() + + # We can mutate the environ at this point. + environ["TESTING"] = "True" + environ["GONE"] = "123" + del environ["GONE"] + + # We can read the environ. + assert environ["TESTING"] == "True" + assert "GONE" not in environ + + # We cannot mutate these keys now that we've read them. + with pytest.raises(EnvironError): + environ["TESTING"] = "False" + + with pytest.raises(EnvironError): + del environ["GONE"] + + # Test coverage of abstract methods for MutableMapping. + environ = Environ() + assert list(iter(environ)) == list(iter(os.environ)) + assert len(environ) == len(os.environ) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 9b789b49..9560630e 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,4 +1,10 @@ -from starlette.datastructures import URL, Headers, MutableHeaders, QueryParams +from starlette.datastructures import ( + URL, + DatabaseURL, + Headers, + MutableHeaders, + QueryParams, +) def test_url(): @@ -38,6 +44,13 @@ def test_hidden_password(): assert repr(u) == "URL('https://username:********@example.org/path/to/somewhere')" +def test_database_url(): + u = DatabaseURL("postgresql://username:password@localhost/mydatabase") + u = u.replace(name="test_" + u.name) + assert u.name == "test_mydatabase" + assert str(u) == "postgresql://username:password@localhost/test_mydatabase" + + def test_url_from_scope(): u = URL( scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []}