From fe2b926009e39884496dd7bb8578890155accb95 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 14 Dec 2018 16:22:31 +0000 Subject: [PATCH] Add `CommaSeparatedStrings` datatype (#274) * Add CommaSeparatedStrings datatype * Version 0.9.9 --- docs/config.md | 8 +++++--- starlette/__init__.py | 2 +- starlette/datastructures.py | 29 +++++++++++++++++++++++++++++ tests/test_datastructures.py | 25 +++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 4 deletions(-) diff --git a/docs/config.md b/docs/config.md index 6dd0c2dd..28e246d7 100644 --- a/docs/config.md +++ b/docs/config.md @@ -9,7 +9,7 @@ that is not committed to source control. ```python from starlette.applications import Starlette from starlette.config import Config -from starlette.datastructures import DatabaseURL, Secret +from starlette.datastructures import CommaSeparatedStrings, DatabaseURL, Secret # Config will be read from environment variables and/or ".env" files. config = Config(".env") @@ -17,6 +17,7 @@ config = Config(".env") DEBUG = config('DEBUG', cast=bool, default=False) DATABASE_URL = config('DATABASE_URL', cast=DatabaseURL) SECRET_KEY = config('SECRET_KEY', cast=Secret) +ALLOWED_HOSTS = config('ALLOWED_HOSTS', cast=CommaSeparatedStrings) app = Starlette() app.debug = DEBUG @@ -31,6 +32,7 @@ app.debug = DEBUG DEBUG=True DATABASE_URL=postgresql://localhost/myproject SECRET_KEY=43n080musdfjt54t-09sdgr +ALLOWED_HOSTS="127.0.0.1", "localhost" ``` ## Configuration precedence @@ -45,8 +47,8 @@ If none of those match, then `config(...)` will raise an error. ## Secrets -For sensitive keys, the `Secret` class is useful, since it prevents the value from -leaking out into tracebacks or logging. +For sensitive keys, the `Secret` class is useful, since it helps minimize +occasions where the value it holds could leak out into tracebacks or logging. To get the value of a `Secret` instance, you must explicitly cast it to a string. You should only do this at the point at which the value is used. diff --git a/starlette/__init__.py b/starlette/__init__.py index a25765c3..88081a72 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1 +1 @@ -__version__ = "0.9.8" +__version__ = "0.9.9" diff --git a/starlette/datastructures.py b/starlette/datastructures.py index c6fe1e7c..70f5614b 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -1,5 +1,7 @@ import typing from collections import namedtuple +from collections.abc import Sequence +from shlex import shlex from urllib.parse import ParseResult, parse_qsl, urlencode, urlparse from starlette.types import Scope @@ -195,6 +197,33 @@ class Secret: return self._value +class CommaSeparatedStrings(Sequence): + def __init__(self, value: typing.Union[str, typing.Sequence[str]]): + if isinstance(value, str): + splitter = shlex(value, posix=True) + splitter.whitespace = "," + splitter.whitespace_split = True + self._items = [item.strip() for item in splitter] + else: + self._items = list(value) + + def __len__(self) -> int: + return len(self._items) + + def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any: + return self._items[index] + + def __iter__(self) -> typing.Iterator[str]: + return iter(self._items) + + def __repr__(self) -> str: + list_repr = repr([item for item in self]) + return "%s(%s)" % (self.__class__.__name__, list_repr) + + def __str__(self) -> str: + return ", ".join([repr(item) for item in self]) + + class QueryParams(typing.Mapping[str, str]): """ An immutable multidict. diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 666aabd0..4f873539 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,5 +1,6 @@ from starlette.datastructures import ( URL, + CommaSeparatedStrings, DatabaseURL, Headers, MutableHeaders, @@ -59,6 +60,30 @@ def test_database_url(): assert u.driver == "asyncpg" +def test_csv(): + csv = CommaSeparatedStrings('"localhost", "127.0.0.1", 0.0.0.0') + assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"] + assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])" + assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'" + assert csv[0] == "localhost" + assert len(csv) == 3 + + csv = CommaSeparatedStrings("'localhost', '127.0.0.1', 0.0.0.0") + assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"] + assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])" + assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'" + + csv = CommaSeparatedStrings("localhost, 127.0.0.1, 0.0.0.0") + assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"] + assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])" + assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'" + + csv = CommaSeparatedStrings(["localhost", "127.0.0.1", "0.0.0.0"]) + assert list(csv) == ["localhost", "127.0.0.1", "0.0.0.0"] + assert repr(csv) == "CommaSeparatedStrings(['localhost', '127.0.0.1', '0.0.0.0'])" + assert str(csv) == "'localhost', '127.0.0.1', '0.0.0.0'" + + def test_url_from_scope(): u = URL( scope={"path": "/path/to/somewhere", "query_string": b"abc=123", "headers": []}