Relax restrictions on HTTP methods in WebSocketHandler.

Methods like set_status are now disallowed once the websocket handshake
has begun, but may be used before then.  This applies to application
overrides of prepare() and to WebSocketHandler.get's internal error
handling.

Closes #1065.
This commit is contained in:
Ben Darnell 2014-06-15 23:35:02 -04:00
parent d100cdf71d
commit 66b06d7cf3
2 changed files with 31 additions and 21 deletions

View File

@ -47,6 +47,13 @@ class EchoHandler(TestWebSocketHandler):
class HeaderHandler(TestWebSocketHandler):
def open(self):
try:
# In a websocket context, many RequestHandler methods
# raise RuntimeErrors.
self.set_status(503)
raise Exception("did not get expected exception")
except RuntimeError:
pass
self.write_message(self.request.headers.get('X-Test', ''))
@ -71,6 +78,11 @@ class WebSocketTest(AsyncHTTPTestCase):
dict(close_future=self.close_future)),
])
def test_http_request(self):
# WS server, HTTP client.
response = self.fetch('/echo')
self.assertEqual(response.code, 400)
@gen_test
def test_websocket_gen(self):
ws = yield websocket_connect(

View File

@ -115,22 +115,17 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.ws_connection = None
self.close_code = None
self.close_reason = None
self.stream = None
@tornado.web.asynchronous
def get(self, *args, **kwargs):
self.open_args = args
self.open_kwargs = kwargs
self.stream = self.request.connection.detach()
self.stream.set_close_callback(self.on_connection_close)
# Upgrade header should be present and should be equal to WebSocket
if self.request.headers.get("Upgrade", "").lower() != 'websocket':
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 400 Bad Request\r\n\r\n"
"Can \"Upgrade\" only to \"WebSocket\"."
))
self.stream.close()
self.set_status(400)
self.finish("Can \"Upgrade\" only to \"WebSocket\".")
return
# Connection header should be upgrade. Some proxy servers/load balancers
@ -138,11 +133,8 @@ class WebSocketHandler(tornado.web.RequestHandler):
headers = self.request.headers
connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
if 'upgrade' not in connection:
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 400 Bad Request\r\n\r\n"
"\"Connection\" must be \"Upgrade\"."
))
self.stream.close()
self.set_status(400)
self.finish("\"Connection\" must be \"Upgrade\".")
return
# Handle WebSocket Origin naming convention differences
@ -159,12 +151,13 @@ class WebSocketHandler(tornado.web.RequestHandler):
# according to check_origin. When the origin is None, we assume it
# did not come from a browser and that it can be passed on.
if origin is not None and not self.check_origin(origin):
self.stream.write(tornado.escape.utf8(
"HTTP/1.1 403 Cross Origin Websockets Disabled\r\n\r\n"
))
self.stream.close()
self.set_status(403)
self.finish("Cross origin websockets not allowed")
return
self.stream = self.request.connection.detach()
self.stream.set_close_callback(self.on_connection_close)
if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
self.ws_connection = WebSocketProtocol13(self)
self.ws_connection.accept_connection()
@ -346,9 +339,6 @@ class WebSocketHandler(tornado.web.RequestHandler):
"""
return "wss" if self.request.protocol == "https" else "ws"
def _not_supported(self, *args, **kwargs):
raise Exception("Method not supported for Web Sockets")
def on_connection_close(self):
if self.ws_connection:
self.ws_connection.on_connection_close()
@ -356,9 +346,17 @@ class WebSocketHandler(tornado.web.RequestHandler):
self.on_close()
def _wrap_method(method):
def _disallow_for_websocket(self, *args, **kwargs):
if self.stream is None:
method(self, *args, **kwargs)
else:
raise RuntimeError("Method not supported for Web Sockets")
return _disallow_for_websocket
for method in ["write", "redirect", "set_header", "send_error", "set_cookie",
"set_status", "flush", "finish"]:
setattr(WebSocketHandler, method, WebSocketHandler._not_supported)
setattr(WebSocketHandler, method,
_wrap_method(getattr(WebSocketHandler, method)))
class WebSocketProtocol(object):