websockets: server handshake scheme

Also refactor settings and resolution interfaces
This commit is contained in:
Aldo Cortesi 2015-04-22 15:49:17 +12:00
parent 65f04bf4d1
commit 99cb0808ab
10 changed files with 243 additions and 80 deletions

View File

@ -1,6 +1,7 @@
import logging
import pprint
import cStringIO
import copy
from flask import Flask, jsonify, render_template, request, abort, make_response
import version, language, utils
from netlib import http_uastrings
@ -10,6 +11,7 @@ logging.basicConfig(level="DEBUG")
def make_app(noapi):
app = Flask(__name__)
# app.debug = True
if not noapi:
@app.route('/api/info')
@ -144,20 +146,17 @@ def make_app(noapi):
c = app.config["pathod"].check_policy(
safe,
app.config["pathod"].request_settings
app.config["pathod"].settings
)
if c:
args["error"] = c
return render(template, False, **args)
if is_request:
language.serve(
safe,
s,
app.config["pathod"].request_settings,
request_host = "example.com"
)
set = copy.copy(app.config["pathod"].settings)
set.request_host = "example.com"
language.serve(safe, s, set)
else:
language.serve(safe, s, app.config["pathod"].request_settings)
language.serve(safe, s, app.config["pathod"].settings)
args["output"] = utils.escape_unprintables(s.getvalue())
return render(template, False, **args)

View File

@ -15,6 +15,20 @@ BLOCKSIZE = 1024
TRUNCATE = 1024
class Settings:
def __init__(
self,
staticdir = None,
unconstrained_file_access = False,
request_host = None,
websocket_key = None
):
self.staticdir = staticdir
self.unconstrained_file_access = unconstrained_file_access
self.request_host = request_host
self.websocket_key = websocket_key
def quote(s):
quotechar = s[0]
s = s[1:-1]
@ -22,7 +36,11 @@ def quote(s):
return quotechar + s + quotechar
class FileAccessDenied(Exception):
class RenderError(Exception):
pass
class FileAccessDenied(RenderError):
pass
@ -97,7 +115,7 @@ def write_values(fp, vals, actions, sofar=0, blocksize=BLOCKSIZE):
return True
def serve(msg, fp, settings, **kwargs):
def serve(msg, fp, settings):
"""
fp: The file pointer to write to.
@ -107,7 +125,7 @@ def serve(msg, fp, settings, **kwargs):
Calling this function may modify the object.
"""
msg = msg.resolve(settings, **kwargs)
msg = msg.resolve(settings)
started = time.time()
vals = msg.values(settings)
@ -351,15 +369,16 @@ class ValueFile(_Token):
return self
def get_generator(self, settings):
uf = settings.get("unconstrained_file_access")
sd = settings.get("staticdir")
if not sd:
if not settings.staticdir:
raise FileAccessDenied("File access disabled.")
sd = os.path.normpath(os.path.abspath(sd))
sd = os.path.normpath(os.path.abspath(settings.staticdir))
s = os.path.expanduser(self.path)
s = os.path.normpath(os.path.abspath(os.path.join(sd, s)))
if not uf and not s.startswith(sd):
s = os.path.normpath(
os.path.abspath(os.path.join(settings.staticdir, s))
)
uf = settings.unconstrained_file_access
if not uf and not s.startswith(settings.staticdir):
raise FileAccessDenied(
"File access outside of configured directory"
)
@ -594,9 +613,28 @@ class Path(_Component):
return Path(self.value.freeze(settings))
class WS(_Component):
def __init__(self, value):
self.value = value
@classmethod
def expr(klass):
spec = pp.Literal("ws")
spec = spec.setParseAction(lambda x: klass(*x))
return spec
def values(self, settings):
return "ws"
def spec(self):
return "ws"
def freeze(self, settings):
return self
class Method(_Component):
methods = [
"ws",
"get",
"head",
"post",
@ -797,29 +835,35 @@ class _Message(object):
def __init__(self, tokens):
self.tokens = tokens
def _get_tokens(self, klass):
def toks(self, klass):
"""
Fetch all tokens that are instances of klass
"""
return [i for i in self.tokens if isinstance(i, klass)]
def _get_token(self, klass):
l = self._get_tokens(klass)
def tok(self, klass):
"""
Fetch first token that is an instance of klass
"""
l = self.toks(klass)
if l:
return l[0]
@property
def raw(self):
return bool(self._get_token(Raw))
return bool(self.tok(Raw))
@property
def actions(self):
return self._get_tokens(_Action)
return self.toks(_Action)
@property
def body(self):
return self._get_token(Body)
return self.tok(Body)
@property
def headers(self):
return self._get_tokens(_Header)
return self.toks(_Header)
def length(self, settings):
"""
@ -883,8 +927,8 @@ class _Message(object):
vals.append(self.body.value.get_generator(settings))
return vals
def freeze(self, settings, **kwargs):
r = self.resolve(settings, **kwargs)
def freeze(self, settings):
r = self.resolve(settings)
return self.__class__([i.freeze(settings) for i in r.tokens])
def __repr__(self):
@ -908,17 +952,26 @@ class Response(_Message):
)
logattrs = ["code", "reason", "version", "body"]
@property
def ws(self):
return self.tok(WS)
@property
def code(self):
return self._get_token(Code)
return self.tok(Code)
@property
def reason(self):
return self._get_token(Reason)
return self.tok(Reason)
def preamble(self, settings):
l = [self.version, " "]
l.extend(self.code.values(settings))
if self.code:
l.extend(self.code.values(settings))
code = int(self.code.code)
elif self.ws:
l.extend(Code(101).values(settings))
code = 101
l.append(" ")
if self.reason:
l.extend(self.reason.values(settings))
@ -926,7 +979,7 @@ class Response(_Message):
l.append(
LiteralGenerator(
http_status.RESPONSES.get(
int(self.code.code),
code,
"Unknown code"
)
)
@ -935,6 +988,22 @@ class Response(_Message):
def resolve(self, settings):
tokens = self.tokens[:]
if self.ws:
if not settings.websocket_key:
raise RenderError(
"No websocket key - have we seen a client handshake?"
)
if not self.code:
tokens.insert(
1,
Code(101)
)
hdrs = websockets.server_handshake_headers(settings.websocket_key)
for i in hdrs.lst:
if not utils.get_header(i[0], self.headers):
tokens.append(
Header(ValueLiteral(i[0]), ValueLiteral(i[1]))
)
if not self.raw:
if not utils.get_header("Content-Length", self.headers):
if not self.body:
@ -958,7 +1027,12 @@ class Response(_Message):
atom = pp.MatchFirst(parts)
resp = pp.And(
[
Code.expr(),
pp.MatchFirst(
[
WS.expr() + pp.Optional(Sep + Code.expr()),
Code.expr(),
]
),
pp.ZeroOrMore(Sep + atom)
]
)
@ -982,17 +1056,21 @@ class Request(_Message):
)
logattrs = ["method", "path", "body"]
@property
def ws(self):
return self.tok(WS)
@property
def method(self):
return self._get_token(Method)
return self.tok(Method)
@property
def path(self):
return self._get_token(Path)
return self.tok(Path)
@property
def pathodspec(self):
return self._get_token(PathodSpec)
return self.tok(PathodSpec)
def preamble(self, settings):
v = self.method.values(settings)
@ -1004,10 +1082,14 @@ class Request(_Message):
v.append(self.version)
return v
def resolve(self, settings, **kwargs):
def resolve(self, settings):
tokens = self.tokens[:]
if self.method.string().lower() == "ws":
tokens[0] = Method("get")
if self.ws:
if not self.method:
tokens.insert(
1,
Method("get")
)
for i in websockets.client_handshake_headers().lst:
if not utils.get_header(i[0], self.headers):
tokens.append(
@ -1023,13 +1105,12 @@ class Request(_Message):
ValueLiteral(str(length)),
)
)
request_host = kwargs.get("request_host")
if request_host:
if settings.request_host:
if not utils.get_header("Host", self.headers):
tokens.append(
Header(
ValueLiteral("Host"),
ValueLiteral(request_host)
ValueLiteral(settings.request_host)
)
)
intermediate = self.__class__(tokens)
@ -1043,7 +1124,12 @@ class Request(_Message):
atom = pp.MatchFirst(parts)
resp = pp.And(
[
Method.expr(),
pp.MatchFirst(
[
WS.expr() + pp.Optional(Sep + Method.expr()),
Method.expr(),
]
),
Sep,
Path.expr(),
pp.ZeroOrMore(Sep + atom)

View File

@ -108,9 +108,10 @@ class Pathoc(tcp.TCPClient):
ignorecodes: Sequence of return codes to ignore
"""
tcp.TCPClient.__init__(self, address)
self.settings = dict(
self.settings = language.Settings(
staticdir = os.getcwd(),
unconstrained_file_access = True,
request_host = self.address.host
)
self.ssl, self.sni = ssl, sni
self.clientcert = clientcert
@ -201,15 +202,14 @@ class Pathoc(tcp.TCPClient):
if self.showresp:
self.rfile.start_log()
try:
req = language.serve(
r,
self.wfile,
self.settings,
request_host = self.address.host
)
req = language.serve(r, self.wfile, self.settings)
self.wfile.flush()
resp = list(
http.read_response(self.rfile, r.method.string(), None)
http.read_response(
self.rfile,
req["method"],
None
)
)
resp.append(self.sslinfo)
resp = Response(*resp)
@ -290,7 +290,7 @@ def main(args): # pragma: nocover
)
if args.explain or args.memo:
playlist = [
i.freeze(p.settings, request_host=p.address.host) for i in playlist
i.freeze(p.settings) for i in playlist
]
if args.memo:
newlist = []

View File

@ -66,10 +66,10 @@ class PathodHandler(tcp.BaseHandler):
self.sni = connection.get_servername()
def serve_crafted(self, crafted):
c = self.server.check_policy(crafted, self.server.request_settings)
c = self.server.check_policy(crafted, self.server.settings)
if c:
err = language.make_error_response(c)
language.serve(err, self.wfile, self.server.request_settings)
language.serve(err, self.wfile, self.server.settings)
log = dict(
type="error",
msg=c
@ -77,12 +77,12 @@ class PathodHandler(tcp.BaseHandler):
return False, log
if self.server.explain and not isinstance(crafted, language.PathodErrorResponse):
crafted = crafted.freeze(self.server.request_settings)
crafted = crafted.freeze(self.server.settings)
self.info(">> Spec: %s" % crafted.spec())
response_log = language.serve(
crafted,
self.wfile,
self.server.request_settings
self.server.settings
)
if response_log["disconnect"]:
return False, response_log
@ -199,7 +199,7 @@ class PathodHandler(tcp.BaseHandler):
return again, retlog
elif self.server.noweb:
crafted = language.make_error_response("Access Denied")
language.serve(crafted, self.wfile, self.server.request_settings)
language.serve(crafted, self.wfile, self.server.settings)
return False, dict(
type="error",
msg="Access denied: web interface disabled"
@ -323,6 +323,10 @@ class Pathod(tcp.TCPServer):
self.logid = 0
self.anchors = anchors
self.settings = language.Settings(
staticdir = self.staticdir
)
def check_policy(self, req, settings):
"""
A policy check that verifies the request size is withing limits.
@ -337,12 +341,6 @@ class Pathod(tcp.TCPServer):
return "Pauses have been disabled."
return False
@property
def request_settings(self):
return dict(
staticdir=self.staticdir
)
def handle_client_connection(self, request, client_address):
h = PathodHandler(request, client_address, self)
try:

View File

@ -11,6 +11,7 @@
<ul class="nav nav-tabs">
<li class="active"><a href="#specifying_responses" data-toggle="tab">Responses</a></li>
<li><a href="#specifying_requests" data-toggle="tab">Requests</a></li>
<li><a href="#websockets" data-toggle="tab">Websockets</a></li>
</ul>
<div class="tab-content">
@ -199,6 +200,43 @@
</table>
</div>
<div class="tab-pane" id="websockets">
<p>Requests and responses can be decorated with the <b>ws</b> prefix to
create a websockets client or server handshake. Since the websocket
specifier implies a request method (GET) and a response code (102),
these can optionally be omitted. All other request and response
features can be applied, and websocket-specific headers can be
over-ridden explicitly.</p>
<h2>Request</h2>
<pre class="example">ws:[method:]path:[colon-separated list of features]</pre></p>
<p>This will generate a wsocket client handshake with a GET method:</p>
<pre class="example">ws:/</pre></p>
<p>This will do the same, but using the (invalid) PUT method:</p>
<pre class="example">ws:put:/</pre></p>
<h2>Response</h2>
<pre class="example">ws[:code:][colon-separated list of features]</pre></p>
<p>This will generate a simple protocol acceptance with a 101 response
code:</p>
<pre class="example">ws</pre></p>
<p>This will do the same, but using the (invalid) 202 code:</p>
<pre class="example">ws:202</pre></p>
</div>
</div>

View File

@ -1,5 +1,5 @@
<form style="margin-bottom: 0" class="form-inline" method="GET" action="/request_preview">
<input
<input
style="width: 18em"
id="spec"
name="spec"
@ -46,6 +46,10 @@
<td><a href="/request_preview?spec=get:/:b@100,ascii:ir,@1">get:/:b@100,ascii:ir,@1</a></td>
<td>100 ASCII bytes as the body, and randomly inject a random byte</td>
</tr>
<tr>
<td><a href="/request_preview?spec=ws:/">ws:/</a></td>
<td>Initiate a websocket handshake.</td>
</tr>
</tbody>
</table>
</div>

View File

@ -1,5 +1,5 @@
<form style="margin-bottom: 0" class="form-inline" method="GET" action="/response_preview">
<input
<input
style="width: 18em"
id="spec"
name="spec"
@ -68,6 +68,12 @@
</td>
<td>100 ASCII bytes as the body, randomly generated 100k header name, with the value 'foo'.</td>
</tr>
<tr>
<td>
<a href="/response_preview?spec=ws">ws</a>
</td>
<td>A websocket connection acceptance handshake.</td>
</tr>
</tbody>
</table>
</div>

View File

@ -101,7 +101,7 @@ class TestValueGenerate:
def test_freeze(self):
v = language.ValueGenerate(100, "b", "ascii")
f = v.freeze({})
f = v.freeze(language.Settings())
assert len(f.val) == 100
@ -121,16 +121,26 @@ class TestValueFile:
with open(p, "wb") as f:
f.write("x" * 10000)
assert v.get_generator(dict(staticdir=t))
assert v.get_generator(language.Settings(staticdir=t))
v = language.Value.parseString("<path2")[0]
tutils.raises(
language.FileAccessDenied, v.get_generator, dict(staticdir=t)
language.FileAccessDenied,
v.get_generator,
language.Settings(staticdir=t)
)
tutils.raises(
"access disabled",
v.get_generator,
language.Settings()
)
tutils.raises("access disabled", v.get_generator, dict())
v = language.Value.parseString("</outside")[0]
tutils.raises("outside", v.get_generator, dict(staticdir=t))
tutils.raises(
"outside",
v.get_generator,
language.Settings(staticdir=t)
)
def test_spec(self):
v = language.Value.parseString("<'one two'")[0]
@ -556,7 +566,12 @@ class TestRequest:
def test_render(self):
s = cStringIO.StringIO()
r = parse_request("GET:'/foo'")
assert language.serve(r, s, {}, request_host = "foo.com")
assert language.serve(
r,
s,
language.Settings(request_host = "foo.com")
)
def test_multiline(self):
l = """
@ -593,17 +608,28 @@ class TestRequest:
rt("get:/foo:da")
def test_freeze(self):
r = parse_request("GET:/:b@100").freeze({})
r = parse_request("GET:/:b@100").freeze(language.Settings())
assert len(r.spec()) > 100
def test_path_generator(self):
r = parse_request("GET:@100").freeze({})
r = parse_request("GET:@100").freeze(language.Settings())
assert len(r.spec()) > 100
def test_websocket(self):
r = parse_request('ws:"/foo"')
res = r.resolve({})
assert utils.get_header("upgrade", res.headers)
r = parse_request('ws:/path/')
res = r.resolve(language.Settings())
assert res.method.string().lower() == "get"
assert res.tok(language.Path).value.val == "/path/"
assert res.tok(language.Method).value.val.lower() == "get"
assert utils.get_header("Upgrade", res.headers).value.val == "websocket"
r = parse_request('ws:put:/path/')
res = r.resolve(language.Settings())
assert r.method.string().lower() == "put"
assert res.tok(language.Path).value.val == "/path/"
assert res.tok(language.Method).value.val.lower() == "put"
assert utils.get_header("Upgrade", res.headers).value.val == "websocket"
class TestWriteValues:
@ -725,13 +751,13 @@ class TestResponse:
r = language.parse_response("400:b'foo':r")
language.serve(r, s, {})
v = s.getvalue()
assert not "Content-Length" in v
assert "Content-Length" not in v
def test_length(self):
def testlen(x):
s = cStringIO.StringIO()
language.serve(x, s, {})
assert x.length({}) == len(s.getvalue())
language.serve(x, s, language.Settings())
assert x.length(language.Settings()) == len(s.getvalue())
testlen(language.parse_response("400:m'msg':r"))
testlen(language.parse_response("400:m'msg':h'foo'='bar':r"))
testlen(language.parse_response("400:m'msg':h'foo'='bar':b@100b:r"))
@ -797,6 +823,12 @@ class TestResponse:
rt("400")
rt("400:da")
def test_websockets(self):
r = language.parse_response("ws")
tutils.raises("no websocket key", r.resolve, language.Settings())
res = r.resolve(language.Settings(websocket_key="foo"))
assert res.code.string() == "101"
def test_read_file():
tutils.raises(language.FileAccessDenied, language.read_file, {}, "=/foo")

View File

@ -74,7 +74,7 @@ class _TestDaemon:
for i in requests:
r = language.parse_requests(i)[0]
if explain:
r = r.freeze({})
r = r.freeze(language.Settings())
try:
c.request(r)
except (http.HttpError, tcp.NetLibError), v:

View File

@ -13,7 +13,7 @@ class TestPathod(object):
p.clear_log()
assert len(p.get_log()) == 0
for i in range(p.LOGBUF + 1):
for _ in range(p.LOGBUF + 1):
p.add_log(dict(s="foo"))
assert len(p.get_log()) <= p.LOGBUF