diff --git a/tornado/test/websocket_test.py b/tornado/test/websocket_test.py index f416fea4..0c5a4747 100644 --- a/tornado/test/websocket_test.py +++ b/tornado/test/websocket_test.py @@ -1,3 +1,5 @@ +from tornado.concurrent import Future +from tornado import gen from tornado.httpclient import HTTPError from tornado.log import gen_log from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog @@ -6,9 +8,15 @@ from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketErro class EchoHandler(WebSocketHandler): + def initialize(self, close_future): + self.close_future = close_future + def on_message(self, message): self.write_message(message, isinstance(message, bytes)) + def on_close(self): + self.close_future.set_result(None) + class NonWebSocketHandler(RequestHandler): def get(self): @@ -17,8 +25,9 @@ class NonWebSocketHandler(RequestHandler): class WebSocketTest(AsyncHTTPTestCase): def get_app(self): + self.close_future = Future() return Application([ - ('/echo', EchoHandler), + ('/echo', EchoHandler, dict(close_future=self.close_future)), ('/non_ws', NonWebSocketHandler), ]) @@ -67,3 +76,12 @@ class WebSocketTest(AsyncHTTPTestCase): io_loop=self.io_loop, connect_timeout=0.01) self.assertEqual(cm.exception.code, 599) + + @gen_test + def test_websocket_close_buffered_data(self): + ws = yield websocket_connect( + 'ws://localhost:%d/echo' % self.get_http_port()) + ws.write_message('hello') + ws.write_message('world') + ws.stream.close() + yield self.close_future diff --git a/tornado/websocket.py b/tornado/websocket.py index 8435e28a..1eef4019 100644 --- a/tornado/websocket.py +++ b/tornado/websocket.py @@ -35,6 +35,7 @@ from tornado.concurrent import Future from tornado.escape import utf8, native_str from tornado import httpclient from tornado.ioloop import IOLoop +from tornado.iostream import StreamClosedError from tornado.log import gen_log, app_log from tornado.netutil import Resolver from tornado import simple_httpclient @@ -588,7 +589,10 @@ class WebSocketProtocol13(WebSocketProtocol): opcode = 0x1 message = tornado.escape.utf8(message) assert isinstance(message, bytes_type) - self._write_frame(True, opcode, message) + try: + self._write_frame(True, opcode, message) + except StreamClosedError: + self._abort() def write_ping(self, data): """Send ping frame.""" @@ -596,7 +600,10 @@ class WebSocketProtocol13(WebSocketProtocol): self._write_frame(True, 0x9, data) def _receive_frame(self): - self.stream.read_bytes(2, self._on_frame_start) + try: + self.stream.read_bytes(2, self._on_frame_start) + except StreamClosedError: + self._abort() def _on_frame_start(self, data): header, payloadlen = struct.unpack("BB", data) @@ -614,34 +621,46 @@ class WebSocketProtocol13(WebSocketProtocol): # control frames must have payload < 126 self._abort() return - if payloadlen < 126: - self._frame_length = payloadlen + try: + if payloadlen < 126: + self._frame_length = payloadlen + if self._masked_frame: + self.stream.read_bytes(4, self._on_masking_key) + else: + self.stream.read_bytes(self._frame_length, self._on_frame_data) + elif payloadlen == 126: + self.stream.read_bytes(2, self._on_frame_length_16) + elif payloadlen == 127: + self.stream.read_bytes(8, self._on_frame_length_64) + except StreamClosedError: + self._abort() + + def _on_frame_length_16(self, data): + self._frame_length = struct.unpack("!H", data)[0] + try: if self._masked_frame: self.stream.read_bytes(4, self._on_masking_key) else: self.stream.read_bytes(self._frame_length, self._on_frame_data) - elif payloadlen == 126: - self.stream.read_bytes(2, self._on_frame_length_16) - elif payloadlen == 127: - self.stream.read_bytes(8, self._on_frame_length_64) - - def _on_frame_length_16(self, data): - self._frame_length = struct.unpack("!H", data)[0] - if self._masked_frame: - self.stream.read_bytes(4, self._on_masking_key) - else: - self.stream.read_bytes(self._frame_length, self._on_frame_data) + except StreamClosedError: + self._abort() def _on_frame_length_64(self, data): self._frame_length = struct.unpack("!Q", data)[0] - if self._masked_frame: - self.stream.read_bytes(4, self._on_masking_key) - else: - self.stream.read_bytes(self._frame_length, self._on_frame_data) + try: + if self._masked_frame: + self.stream.read_bytes(4, self._on_masking_key) + else: + self.stream.read_bytes(self._frame_length, self._on_frame_data) + except StreamClosedError: + self._abort() def _on_masking_key(self, data): self._frame_mask = data - self.stream.read_bytes(self._frame_length, self._on_masked_frame_data) + try: + self.stream.read_bytes(self._frame_length, self._on_masked_frame_data) + except StreamClosedError: + self._abort() def _apply_mask(self, mask, data): mask = array.array("B", mask)