diff --git a/libpathod/app.py b/libpathod/app.py index 858c1d865..c3ce9991a 100644 --- a/libpathod/app.py +++ b/libpathod/app.py @@ -1,6 +1,7 @@ import logging import pprint import cStringIO +import copy from flask import Flask, jsonify, render_template, request, abort, make_response import version, language, utils from netlib import http_uastrings @@ -10,6 +11,7 @@ logging.basicConfig(level="DEBUG") def make_app(noapi): app = Flask(__name__) + # app.debug = True if not noapi: @app.route('/api/info') @@ -144,20 +146,17 @@ def make_app(noapi): c = app.config["pathod"].check_policy( safe, - app.config["pathod"].request_settings + app.config["pathod"].settings ) if c: args["error"] = c return render(template, False, **args) if is_request: - language.serve( - safe, - s, - app.config["pathod"].request_settings, - request_host = "example.com" - ) + set = copy.copy(app.config["pathod"].settings) + set.request_host = "example.com" + language.serve(safe, s, set) else: - language.serve(safe, s, app.config["pathod"].request_settings) + language.serve(safe, s, app.config["pathod"].settings) args["output"] = utils.escape_unprintables(s.getvalue()) return render(template, False, **args) diff --git a/libpathod/language.py b/libpathod/language.py index 29d2ade89..28976a294 100644 --- a/libpathod/language.py +++ b/libpathod/language.py @@ -15,6 +15,20 @@ BLOCKSIZE = 1024 TRUNCATE = 1024 +class Settings: + def __init__( + self, + staticdir = None, + unconstrained_file_access = False, + request_host = None, + websocket_key = None + ): + self.staticdir = staticdir + self.unconstrained_file_access = unconstrained_file_access + self.request_host = request_host + self.websocket_key = websocket_key + + def quote(s): quotechar = s[0] s = s[1:-1] @@ -22,7 +36,11 @@ def quote(s): return quotechar + s + quotechar -class FileAccessDenied(Exception): +class RenderError(Exception): + pass + + +class FileAccessDenied(RenderError): pass @@ -97,7 +115,7 @@ def write_values(fp, vals, actions, sofar=0, blocksize=BLOCKSIZE): return True -def serve(msg, fp, settings, **kwargs): +def serve(msg, fp, settings): """ fp: The file pointer to write to. @@ -107,7 +125,7 @@ def serve(msg, fp, settings, **kwargs): Calling this function may modify the object. """ - msg = msg.resolve(settings, **kwargs) + msg = msg.resolve(settings) started = time.time() vals = msg.values(settings) @@ -351,15 +369,16 @@ class ValueFile(_Token): return self def get_generator(self, settings): - uf = settings.get("unconstrained_file_access") - sd = settings.get("staticdir") - if not sd: + if not settings.staticdir: raise FileAccessDenied("File access disabled.") - sd = os.path.normpath(os.path.abspath(sd)) + sd = os.path.normpath(os.path.abspath(settings.staticdir)) s = os.path.expanduser(self.path) - s = os.path.normpath(os.path.abspath(os.path.join(sd, s))) - if not uf and not s.startswith(sd): + s = os.path.normpath( + os.path.abspath(os.path.join(settings.staticdir, s)) + ) + uf = settings.unconstrained_file_access + if not uf and not s.startswith(settings.staticdir): raise FileAccessDenied( "File access outside of configured directory" ) @@ -594,9 +613,28 @@ class Path(_Component): return Path(self.value.freeze(settings)) +class WS(_Component): + def __init__(self, value): + self.value = value + + @classmethod + def expr(klass): + spec = pp.Literal("ws") + spec = spec.setParseAction(lambda x: klass(*x)) + return spec + + def values(self, settings): + return "ws" + + def spec(self): + return "ws" + + def freeze(self, settings): + return self + + class Method(_Component): methods = [ - "ws", "get", "head", "post", @@ -797,29 +835,35 @@ class _Message(object): def __init__(self, tokens): self.tokens = tokens - def _get_tokens(self, klass): + def toks(self, klass): + """ + Fetch all tokens that are instances of klass + """ return [i for i in self.tokens if isinstance(i, klass)] - def _get_token(self, klass): - l = self._get_tokens(klass) + def tok(self, klass): + """ + Fetch first token that is an instance of klass + """ + l = self.toks(klass) if l: return l[0] @property def raw(self): - return bool(self._get_token(Raw)) + return bool(self.tok(Raw)) @property def actions(self): - return self._get_tokens(_Action) + return self.toks(_Action) @property def body(self): - return self._get_token(Body) + return self.tok(Body) @property def headers(self): - return self._get_tokens(_Header) + return self.toks(_Header) def length(self, settings): """ @@ -883,8 +927,8 @@ class _Message(object): vals.append(self.body.value.get_generator(settings)) return vals - def freeze(self, settings, **kwargs): - r = self.resolve(settings, **kwargs) + def freeze(self, settings): + r = self.resolve(settings) return self.__class__([i.freeze(settings) for i in r.tokens]) def __repr__(self): @@ -908,17 +952,26 @@ class Response(_Message): ) logattrs = ["code", "reason", "version", "body"] + @property + def ws(self): + return self.tok(WS) + @property def code(self): - return self._get_token(Code) + return self.tok(Code) @property def reason(self): - return self._get_token(Reason) + return self.tok(Reason) def preamble(self, settings): l = [self.version, " "] - l.extend(self.code.values(settings)) + if self.code: + l.extend(self.code.values(settings)) + code = int(self.code.code) + elif self.ws: + l.extend(Code(101).values(settings)) + code = 101 l.append(" ") if self.reason: l.extend(self.reason.values(settings)) @@ -926,7 +979,7 @@ class Response(_Message): l.append( LiteralGenerator( http_status.RESPONSES.get( - int(self.code.code), + code, "Unknown code" ) ) @@ -935,6 +988,22 @@ class Response(_Message): def resolve(self, settings): tokens = self.tokens[:] + if self.ws: + if not settings.websocket_key: + raise RenderError( + "No websocket key - have we seen a client handshake?" + ) + if not self.code: + tokens.insert( + 1, + Code(101) + ) + hdrs = websockets.server_handshake_headers(settings.websocket_key) + for i in hdrs.lst: + if not utils.get_header(i[0], self.headers): + tokens.append( + Header(ValueLiteral(i[0]), ValueLiteral(i[1])) + ) if not self.raw: if not utils.get_header("Content-Length", self.headers): if not self.body: @@ -958,7 +1027,12 @@ class Response(_Message): atom = pp.MatchFirst(parts) resp = pp.And( [ - Code.expr(), + pp.MatchFirst( + [ + WS.expr() + pp.Optional(Sep + Code.expr()), + Code.expr(), + ] + ), pp.ZeroOrMore(Sep + atom) ] ) @@ -982,17 +1056,21 @@ class Request(_Message): ) logattrs = ["method", "path", "body"] + @property + def ws(self): + return self.tok(WS) + @property def method(self): - return self._get_token(Method) + return self.tok(Method) @property def path(self): - return self._get_token(Path) + return self.tok(Path) @property def pathodspec(self): - return self._get_token(PathodSpec) + return self.tok(PathodSpec) def preamble(self, settings): v = self.method.values(settings) @@ -1004,10 +1082,14 @@ class Request(_Message): v.append(self.version) return v - def resolve(self, settings, **kwargs): + def resolve(self, settings): tokens = self.tokens[:] - if self.method.string().lower() == "ws": - tokens[0] = Method("get") + if self.ws: + if not self.method: + tokens.insert( + 1, + Method("get") + ) for i in websockets.client_handshake_headers().lst: if not utils.get_header(i[0], self.headers): tokens.append( @@ -1023,13 +1105,12 @@ class Request(_Message): ValueLiteral(str(length)), ) ) - request_host = kwargs.get("request_host") - if request_host: + if settings.request_host: if not utils.get_header("Host", self.headers): tokens.append( Header( ValueLiteral("Host"), - ValueLiteral(request_host) + ValueLiteral(settings.request_host) ) ) intermediate = self.__class__(tokens) @@ -1043,7 +1124,12 @@ class Request(_Message): atom = pp.MatchFirst(parts) resp = pp.And( [ - Method.expr(), + pp.MatchFirst( + [ + WS.expr() + pp.Optional(Sep + Method.expr()), + Method.expr(), + ] + ), Sep, Path.expr(), pp.ZeroOrMore(Sep + atom) diff --git a/libpathod/pathoc.py b/libpathod/pathoc.py index 0d8ec8f94..cf9be5b9d 100644 --- a/libpathod/pathoc.py +++ b/libpathod/pathoc.py @@ -108,9 +108,10 @@ class Pathoc(tcp.TCPClient): ignorecodes: Sequence of return codes to ignore """ tcp.TCPClient.__init__(self, address) - self.settings = dict( + self.settings = language.Settings( staticdir = os.getcwd(), unconstrained_file_access = True, + request_host = self.address.host ) self.ssl, self.sni = ssl, sni self.clientcert = clientcert @@ -201,15 +202,14 @@ class Pathoc(tcp.TCPClient): if self.showresp: self.rfile.start_log() try: - req = language.serve( - r, - self.wfile, - self.settings, - request_host = self.address.host - ) + req = language.serve(r, self.wfile, self.settings) self.wfile.flush() resp = list( - http.read_response(self.rfile, r.method.string(), None) + http.read_response( + self.rfile, + req["method"], + None + ) ) resp.append(self.sslinfo) resp = Response(*resp) @@ -290,7 +290,7 @@ def main(args): # pragma: nocover ) if args.explain or args.memo: playlist = [ - i.freeze(p.settings, request_host=p.address.host) for i in playlist + i.freeze(p.settings) for i in playlist ] if args.memo: newlist = [] diff --git a/libpathod/pathod.py b/libpathod/pathod.py index 1c23baaef..0c6267772 100644 --- a/libpathod/pathod.py +++ b/libpathod/pathod.py @@ -66,10 +66,10 @@ class PathodHandler(tcp.BaseHandler): self.sni = connection.get_servername() def serve_crafted(self, crafted): - c = self.server.check_policy(crafted, self.server.request_settings) + c = self.server.check_policy(crafted, self.server.settings) if c: err = language.make_error_response(c) - language.serve(err, self.wfile, self.server.request_settings) + language.serve(err, self.wfile, self.server.settings) log = dict( type="error", msg=c @@ -77,12 +77,12 @@ class PathodHandler(tcp.BaseHandler): return False, log if self.server.explain and not isinstance(crafted, language.PathodErrorResponse): - crafted = crafted.freeze(self.server.request_settings) + crafted = crafted.freeze(self.server.settings) self.info(">> Spec: %s" % crafted.spec()) response_log = language.serve( crafted, self.wfile, - self.server.request_settings + self.server.settings ) if response_log["disconnect"]: return False, response_log @@ -199,7 +199,7 @@ class PathodHandler(tcp.BaseHandler): return again, retlog elif self.server.noweb: crafted = language.make_error_response("Access Denied") - language.serve(crafted, self.wfile, self.server.request_settings) + language.serve(crafted, self.wfile, self.server.settings) return False, dict( type="error", msg="Access denied: web interface disabled" @@ -323,6 +323,10 @@ class Pathod(tcp.TCPServer): self.logid = 0 self.anchors = anchors + self.settings = language.Settings( + staticdir = self.staticdir + ) + def check_policy(self, req, settings): """ A policy check that verifies the request size is withing limits. @@ -337,12 +341,6 @@ class Pathod(tcp.TCPServer): return "Pauses have been disabled." return False - @property - def request_settings(self): - return dict( - staticdir=self.staticdir - ) - def handle_client_connection(self, request, client_address): h = PathodHandler(request, client_address, self) try: diff --git a/libpathod/templates/docs_lang.html b/libpathod/templates/docs_lang.html index 4ed7f151d..e67b13c56 100644 --- a/libpathod/templates/docs_lang.html +++ b/libpathod/templates/docs_lang.html @@ -11,6 +11,7 @@
@@ -199,6 +200,43 @@
+
+ +

Requests and responses can be decorated with the ws prefix to + create a websockets client or server handshake. Since the websocket + specifier implies a request method (GET) and a response code (102), + these can optionally be omitted. All other request and response + features can be applied, and websocket-specific headers can be + over-ridden explicitly.

+ +

Request

+ +
ws:[method:]path:[colon-separated list of features]

+ +

This will generate a wsocket client handshake with a GET method:

+ +
ws:/

+ +

This will do the same, but using the (invalid) PUT method:

+ +
ws:put:/

+ + +

Response

+ +
ws[:code:][colon-separated list of features]

+ +

This will generate a simple protocol acceptance with a 101 response + code:

+ +
ws

+ +

This will do the same, but using the (invalid) 202 code:

+ +
ws:202

+ +
+ diff --git a/libpathod/templates/request_previewform.html b/libpathod/templates/request_previewform.html index 607bfefdb..d30837359 100644 --- a/libpathod/templates/request_previewform.html +++ b/libpathod/templates/request_previewform.html @@ -1,5 +1,5 @@
- get:/:b@100,ascii:ir,@1 100 ASCII bytes as the body, and randomly inject a random byte + + ws:/ + Initiate a websocket handshake. + diff --git a/libpathod/templates/response_previewform.html b/libpathod/templates/response_previewform.html index fbc3de5a8..285510150 100644 --- a/libpathod/templates/response_previewform.html +++ b/libpathod/templates/response_previewform.html @@ -1,5 +1,5 @@ - 100 ASCII bytes as the body, randomly generated 100k header name, with the value 'foo'. + + + ws + + A websocket connection acceptance handshake. + diff --git a/test/test_language.py b/test/test_language.py index 4dd3d8acc..28e26e109 100644 --- a/test/test_language.py +++ b/test/test_language.py @@ -101,7 +101,7 @@ class TestValueGenerate: def test_freeze(self): v = language.ValueGenerate(100, "b", "ascii") - f = v.freeze({}) + f = v.freeze(language.Settings()) assert len(f.val) == 100 @@ -121,16 +121,26 @@ class TestValueFile: with open(p, "wb") as f: f.write("x" * 10000) - assert v.get_generator(dict(staticdir=t)) + assert v.get_generator(language.Settings(staticdir=t)) v = language.Value.parseString(" 100 def test_path_generator(self): - r = parse_request("GET:@100").freeze({}) + r = parse_request("GET:@100").freeze(language.Settings()) assert len(r.spec()) > 100 def test_websocket(self): - r = parse_request('ws:"/foo"') - res = r.resolve({}) - assert utils.get_header("upgrade", res.headers) + r = parse_request('ws:/path/') + res = r.resolve(language.Settings()) + assert res.method.string().lower() == "get" + assert res.tok(language.Path).value.val == "/path/" + assert res.tok(language.Method).value.val.lower() == "get" + assert utils.get_header("Upgrade", res.headers).value.val == "websocket" + + r = parse_request('ws:put:/path/') + res = r.resolve(language.Settings()) + assert r.method.string().lower() == "put" + assert res.tok(language.Path).value.val == "/path/" + assert res.tok(language.Method).value.val.lower() == "put" + assert utils.get_header("Upgrade", res.headers).value.val == "websocket" + class TestWriteValues: @@ -725,13 +751,13 @@ class TestResponse: r = language.parse_response("400:b'foo':r") language.serve(r, s, {}) v = s.getvalue() - assert not "Content-Length" in v + assert "Content-Length" not in v def test_length(self): def testlen(x): s = cStringIO.StringIO() - language.serve(x, s, {}) - assert x.length({}) == len(s.getvalue()) + language.serve(x, s, language.Settings()) + assert x.length(language.Settings()) == len(s.getvalue()) testlen(language.parse_response("400:m'msg':r")) testlen(language.parse_response("400:m'msg':h'foo'='bar':r")) testlen(language.parse_response("400:m'msg':h'foo'='bar':b@100b:r")) @@ -797,6 +823,12 @@ class TestResponse: rt("400") rt("400:da") + def test_websockets(self): + r = language.parse_response("ws") + tutils.raises("no websocket key", r.resolve, language.Settings()) + res = r.resolve(language.Settings(websocket_key="foo")) + assert res.code.string() == "101" + def test_read_file(): tutils.raises(language.FileAccessDenied, language.read_file, {}, "=/foo") diff --git a/test/test_pathoc.py b/test/test_pathoc.py index 520752319..e14450b93 100644 --- a/test/test_pathoc.py +++ b/test/test_pathoc.py @@ -74,7 +74,7 @@ class _TestDaemon: for i in requests: r = language.parse_requests(i)[0] if explain: - r = r.freeze({}) + r = r.freeze(language.Settings()) try: c.request(r) except (http.HttpError, tcp.NetLibError), v: diff --git a/test/test_pathod.py b/test/test_pathod.py index c32f6e84c..00634e270 100644 --- a/test/test_pathod.py +++ b/test/test_pathod.py @@ -13,7 +13,7 @@ class TestPathod(object): p.clear_log() assert len(p.get_log()) == 0 - for i in range(p.LOGBUF + 1): + for _ in range(p.LOGBUF + 1): p.add_log(dict(s="foo")) assert len(p.get_log()) <= p.LOGBUF