From 60693e011cc0602d83a29ac37fa7263eda107f61 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Wed, 5 Jun 2013 21:43:16 -0400 Subject: [PATCH] Catch StreamClosedErrors in WebSocketHandler and abort. When the stream is closed with buffered data, the close callback won't be run until all buffered data is consumed, but any attempt to write to the stream will fail, as will reading past the end of the buffer. This requires a try/except around each read or write, analogous to the one introduced in HTTPServer in commit 3258726f. Closes #604. Closes #661. --- tornado/test/websocket_test.py | 20 +++++++++++- tornado/websocket.py | 59 ++++++++++++++++++++++------------ 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index f416fea4..0c5a4747 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -1,3 +1,5 @@ +from tornado.concurrent import Future +from tornado import gen from tornado.httpclient import HTTPError from tornado.log import gen_log from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog @@ -6,9 +8,15 @@ from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketErro class EchoHandler(WebSocketHandler): + def initialize(self, close_future): + self.close_future = close_future + def on_message(self, message): self.write_message(message, isinstance(message, bytes)) + def on_close(self): + self.close_future.set_result(None) + class NonWebSocketHandler(RequestHandler): def get(self): @@ -17,8 +25,9 @@ class NonWebSocketHandler(RequestHandler): class WebSocketTest(AsyncHTTPTestCase): def get_app(self): + self.close_future = Future() return Application([ - ('/echo', EchoHandler), + ('/echo', EchoHandler, dict(close_future=self.close_future)), ('/non_ws', NonWebSocketHandler), ]) @@ -67,3 +76,12 @@ class WebSocketTest(AsyncHTTPTestCase): io_loop=self.io_loop, connect_timeout=0.01) self.assertEqual(cm.exception.code, 599) + + @gen_test + def test_websocket_close_buffered_data(self): + ws = yield websocket_connect( + 'ws://localhost:%d/echo' % self.get_http_port()) + ws.write_message('hello') + ws.write_message('world') + ws.stream.close() + yield self.close_future diff --git a/tornado/websocket.py b/tornado/websocket.py index 8435e28a..1eef4019 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -35,6 +35,7 @@ from tornado.concurrent import Future from tornado.escape import utf8, native_str from tornado import httpclient from tornado.ioloop import IOLoop +from tornado.iostream import StreamClosedError from tornado.log import gen_log, app_log from tornado.netutil import Resolver from tornado import simple_httpclient @@ -588,7 +589,10 @@ class WebSocketProtocol13(WebSocketProtocol): opcode = 0x1 message = tornado.escape.utf8(message) assert isinstance(message, bytes_type) - self._write_frame(True, opcode, message) + try: + self._write_frame(True, opcode, message) + except StreamClosedError: + self._abort() def write_ping(self, data): """Send ping frame.""" @@ -596,7 +600,10 @@ class WebSocketProtocol13(WebSocketProtocol): self._write_frame(True, 0x9, data) def _receive_frame(self): - self.stream.read_bytes(2, self._on_frame_start) + try: + self.stream.read_bytes(2, self._on_frame_start) + except StreamClosedError: + self._abort() def _on_frame_start(self, data): header, payloadlen = struct.unpack("BB", data) @@ -614,34 +621,46 @@ class WebSocketProtocol13(WebSocketProtocol): # control frames must have payload < 126 self._abort() return - if payloadlen < 126: - self._frame_length = payloadlen + try: + if payloadlen < 126: + self._frame_length = payloadlen + if self._masked_frame: + self.stream.read_bytes(4, self._on_masking_key) + else: + self.stream.read_bytes(self._frame_length, self._on_frame_data) + elif payloadlen == 126: + self.stream.read_bytes(2, self._on_frame_length_16) + elif payloadlen == 127: + self.stream.read_bytes(8, self._on_frame_length_64) + except StreamClosedError: + self._abort() + + def _on_frame_length_16(self, data): + self._frame_length = struct.unpack("!H", data)[0] + try: if self._masked_frame: self.stream.read_bytes(4, self._on_masking_key) else: self.stream.read_bytes(self._frame_length, self._on_frame_data) - elif payloadlen == 126: - self.stream.read_bytes(2, self._on_frame_length_16) - elif payloadlen == 127: - self.stream.read_bytes(8, self._on_frame_length_64) - - def _on_frame_length_16(self, data): - self._frame_length = struct.unpack("!H", data)[0] - if self._masked_frame: - self.stream.read_bytes(4, self._on_masking_key) - else: - self.stream.read_bytes(self._frame_length, self._on_frame_data) + except StreamClosedError: + self._abort() def _on_frame_length_64(self, data): self._frame_length = struct.unpack("!Q", data)[0] - if self._masked_frame: - self.stream.read_bytes(4, self._on_masking_key) - else: - self.stream.read_bytes(self._frame_length, self._on_frame_data) + try: + if self._masked_frame: + self.stream.read_bytes(4, self._on_masking_key) + else: + self.stream.read_bytes(self._frame_length, self._on_frame_data) + except StreamClosedError: + self._abort() def _on_masking_key(self, data): self._frame_mask = data - self.stream.read_bytes(self._frame_length, self._on_masked_frame_data) + try: + self.stream.read_bytes(self._frame_length, self._on_masked_frame_data) + except StreamClosedError: + self._abort() def _apply_mask(self, mask, data): mask = array.array("B", mask)