Catch StreamClosedErrors in WebSocketHandler and abort.

When the stream is closed with buffered data, the close callback won't
be run until all buffered data is consumed, but any attempt to write
to the stream will fail, as will reading past the end of the buffer.
This requires a try/except around each read or write, analogous to the
one introduced in HTTPServer in commit 3258726f.

Closes #604.
Closes #661.
This commit is contained in:
Ben Darnell 2013-06-05 21:43:16 -04:00
parent b61bc5a79c
commit 60693e011c
2 changed files with 58 additions and 21 deletions

View File

@ -1,3 +1,5 @@
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError
from tornado.log import gen_log
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
@ -6,9 +8,15 @@ from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketErro
class EchoHandler(WebSocketHandler):
def initialize(self, close_future):
self.close_future = close_future
def on_message(self, message):
self.write_message(message, isinstance(message, bytes))
def on_close(self):
self.close_future.set_result(None)
class NonWebSocketHandler(RequestHandler):
def get(self):
@ -17,8 +25,9 @@ class NonWebSocketHandler(RequestHandler):
class WebSocketTest(AsyncHTTPTestCase):
def get_app(self):
self.close_future = Future()
return Application([
('/echo', EchoHandler),
('/echo', EchoHandler, dict(close_future=self.close_future)),
('/non_ws', NonWebSocketHandler),
])
@ -67,3 +76,12 @@ class WebSocketTest(AsyncHTTPTestCase):
io_loop=self.io_loop,
connect_timeout=0.01)
self.assertEqual(cm.exception.code, 599)
@gen_test
def test_websocket_close_buffered_data(self):
ws = yield websocket_connect(
'ws://localhost:%d/echo' % self.get_http_port())
ws.write_message('hello')
ws.write_message('world')
ws.stream.close()
yield self.close_future

View File

@ -35,6 +35,7 @@ from tornado.concurrent import Future
from tornado.escape import utf8, native_str
from tornado import httpclient
from tornado.ioloop import IOLoop
from tornado.iostream import StreamClosedError
from tornado.log import gen_log, app_log
from tornado.netutil import Resolver
from tornado import simple_httpclient
@ -588,7 +589,10 @@ class WebSocketProtocol13(WebSocketProtocol):
opcode = 0x1
message = tornado.escape.utf8(message)
assert isinstance(message, bytes_type)
self._write_frame(True, opcode, message)
try:
self._write_frame(True, opcode, message)
except StreamClosedError:
self._abort()
def write_ping(self, data):
"""Send ping frame."""
@ -596,7 +600,10 @@ class WebSocketProtocol13(WebSocketProtocol):
self._write_frame(True, 0x9, data)
def _receive_frame(self):
self.stream.read_bytes(2, self._on_frame_start)
try:
self.stream.read_bytes(2, self._on_frame_start)
except StreamClosedError:
self._abort()
def _on_frame_start(self, data):
header, payloadlen = struct.unpack("BB", data)
@ -614,34 +621,46 @@ class WebSocketProtocol13(WebSocketProtocol):
# control frames must have payload < 126
self._abort()
return
if payloadlen < 126:
self._frame_length = payloadlen
try:
if payloadlen < 126:
self._frame_length = payloadlen
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
elif payloadlen == 126:
self.stream.read_bytes(2, self._on_frame_length_16)
elif payloadlen == 127:
self.stream.read_bytes(8, self._on_frame_length_64)
except StreamClosedError:
self._abort()
def _on_frame_length_16(self, data):
self._frame_length = struct.unpack("!H", data)[0]
try:
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
elif payloadlen == 126:
self.stream.read_bytes(2, self._on_frame_length_16)
elif payloadlen == 127:
self.stream.read_bytes(8, self._on_frame_length_64)
def _on_frame_length_16(self, data):
self._frame_length = struct.unpack("!H", data)[0]
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
except StreamClosedError:
self._abort()
def _on_frame_length_64(self, data):
self._frame_length = struct.unpack("!Q", data)[0]
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
try:
if self._masked_frame:
self.stream.read_bytes(4, self._on_masking_key)
else:
self.stream.read_bytes(self._frame_length, self._on_frame_data)
except StreamClosedError:
self._abort()
def _on_masking_key(self, data):
self._frame_mask = data
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
try:
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
except StreamClosedError:
self._abort()
def _apply_mask(self, mask, data):
mask = array.array("B", mask)