From 66b06d7cf351ff0ce2c70bf1cf29a033e17b56f3 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sun, 15 Jun 2014 23:35:02 -0400 Subject: [PATCH] Relax restrictions on HTTP methods in WebSocketHandler. Methods like set_status are now disallowed once the websocket handshake has begun, but may be used before then. This applies to application overrides of prepare() and to WebSocketHandler.get's internal error handling. Closes #1065. --- tornado/test/websocket_test.py | 12 ++++++++++ tornado/websocket.py | 40 ++++++++++++++++------------------ 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index fd0b08ca..7b3c34ce 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -47,6 +47,13 @@ class EchoHandler(TestWebSocketHandler): class HeaderHandler(TestWebSocketHandler): def open(self): + try: + # In a websocket context, many RequestHandler methods + # raise RuntimeErrors. + self.set_status(503) + raise Exception("did not get expected exception") + except RuntimeError: + pass self.write_message(self.request.headers.get('X-Test', '')) @@ -71,6 +78,11 @@ class WebSocketTest(AsyncHTTPTestCase): dict(close_future=self.close_future)), ]) + def test_http_request(self): + # WS server, HTTP client. + response = self.fetch('/echo') + self.assertEqual(response.code, 400) + @gen_test def test_websocket_gen(self): ws = yield websocket_connect( diff --git a/tornado/websocket.py b/tornado/websocket.py index 2704c26c..19196b88 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -115,22 +115,17 @@ class WebSocketHandler(tornado.web.RequestHandler): self.ws_connection = None self.close_code = None self.close_reason = None + self.stream = None @tornado.web.asynchronous def get(self, *args, **kwargs): self.open_args = args self.open_kwargs = kwargs - self.stream = self.request.connection.detach() - self.stream.set_close_callback(self.on_connection_close) - # Upgrade header should be present and should be equal to WebSocket if self.request.headers.get("Upgrade", "").lower() != 'websocket': - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 400 Bad Request\r\n\r\n" - "Can \"Upgrade\" only to \"WebSocket\"." - )) - self.stream.close() + self.set_status(400) + self.finish("Can \"Upgrade\" only to \"WebSocket\".") return # Connection header should be upgrade. Some proxy servers/load balancers @@ -138,11 +133,8 @@ class WebSocketHandler(tornado.web.RequestHandler): headers = self.request.headers connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(",")) if 'upgrade' not in connection: - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 400 Bad Request\r\n\r\n" - "\"Connection\" must be \"Upgrade\"." - )) - self.stream.close() + self.set_status(400) + self.finish("\"Connection\" must be \"Upgrade\".") return # Handle WebSocket Origin naming convention differences @@ -159,12 +151,13 @@ class WebSocketHandler(tornado.web.RequestHandler): # according to check_origin. When the origin is None, we assume it # did not come from a browser and that it can be passed on. if origin is not None and not self.check_origin(origin): - self.stream.write(tornado.escape.utf8( - "HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n" - )) - self.stream.close() + self.set_status(403) + self.finish("Cross origin websockets not allowed") return + self.stream = self.request.connection.detach() + self.stream.set_close_callback(self.on_connection_close) + if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"): self.ws_connection = WebSocketProtocol13(self) self.ws_connection.accept_connection() @@ -346,9 +339,6 @@ class WebSocketHandler(tornado.web.RequestHandler): """ return "wss" if self.request.protocol == "https" else "ws" - def _not_supported(self, *args, **kwargs): - raise Exception("Method not supported for Web Sockets") - def on_connection_close(self): if self.ws_connection: self.ws_connection.on_connection_close() @@ -356,9 +346,17 @@ class WebSocketHandler(tornado.web.RequestHandler): self.on_close() +def _wrap_method(method): + def _disallow_for_websocket(self, *args, **kwargs): + if self.stream is None: + method(self, *args, **kwargs) + else: + raise RuntimeError("Method not supported for Web Sockets") + return _disallow_for_websocket for method in ["write", "redirect", "set_header", "send_error", "set_cookie", "set_status", "flush", "finish"]: - setattr(WebSocketHandler, method, WebSocketHandler._not_supported) + setattr(WebSocketHandler, method, + _wrap_method(getattr(WebSocketHandler, method))) class WebSocketProtocol(object):