Catch StreamClosedErrors in WebSocketHandler and abort.
When the stream is closed with buffered data, the close callback won't
be run until all buffered data is consumed, but any attempt to write
to the stream will fail, as will reading past the end of the buffer.
This requires a try/except around each read or write, analogous to the
one introduced in HTTPServer in commit 3258726f
.
Closes #604.
Closes #661.
This commit is contained in:
parent
b61bc5a79c
commit
60693e011c
|
@ -1,3 +1,5 @@
|
|||
from tornado.concurrent import Future
|
||||
from tornado import gen
|
||||
from tornado.httpclient import HTTPError
|
||||
from tornado.log import gen_log
|
||||
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
|
||||
|
@ -6,9 +8,15 @@ from tornado.websocket import WebSocketHandler, websocket_connect, WebSocketErro
|
|||
|
||||
|
||||
class EchoHandler(WebSocketHandler):
|
||||
def initialize(self, close_future):
|
||||
self.close_future = close_future
|
||||
|
||||
def on_message(self, message):
|
||||
self.write_message(message, isinstance(message, bytes))
|
||||
|
||||
def on_close(self):
|
||||
self.close_future.set_result(None)
|
||||
|
||||
|
||||
class NonWebSocketHandler(RequestHandler):
|
||||
def get(self):
|
||||
|
@ -17,8 +25,9 @@ class NonWebSocketHandler(RequestHandler):
|
|||
|
||||
class WebSocketTest(AsyncHTTPTestCase):
|
||||
def get_app(self):
|
||||
self.close_future = Future()
|
||||
return Application([
|
||||
('/echo', EchoHandler),
|
||||
('/echo', EchoHandler, dict(close_future=self.close_future)),
|
||||
('/non_ws', NonWebSocketHandler),
|
||||
])
|
||||
|
||||
|
@ -67,3 +76,12 @@ class WebSocketTest(AsyncHTTPTestCase):
|
|||
io_loop=self.io_loop,
|
||||
connect_timeout=0.01)
|
||||
self.assertEqual(cm.exception.code, 599)
|
||||
|
||||
@gen_test
|
||||
def test_websocket_close_buffered_data(self):
|
||||
ws = yield websocket_connect(
|
||||
'ws://localhost:%d/echo' % self.get_http_port())
|
||||
ws.write_message('hello')
|
||||
ws.write_message('world')
|
||||
ws.stream.close()
|
||||
yield self.close_future
|
||||
|
|
|
@ -35,6 +35,7 @@ from tornado.concurrent import Future
|
|||
from tornado.escape import utf8, native_str
|
||||
from tornado import httpclient
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import StreamClosedError
|
||||
from tornado.log import gen_log, app_log
|
||||
from tornado.netutil import Resolver
|
||||
from tornado import simple_httpclient
|
||||
|
@ -588,7 +589,10 @@ class WebSocketProtocol13(WebSocketProtocol):
|
|||
opcode = 0x1
|
||||
message = tornado.escape.utf8(message)
|
||||
assert isinstance(message, bytes_type)
|
||||
self._write_frame(True, opcode, message)
|
||||
try:
|
||||
self._write_frame(True, opcode, message)
|
||||
except StreamClosedError:
|
||||
self._abort()
|
||||
|
||||
def write_ping(self, data):
|
||||
"""Send ping frame."""
|
||||
|
@ -596,7 +600,10 @@ class WebSocketProtocol13(WebSocketProtocol):
|
|||
self._write_frame(True, 0x9, data)
|
||||
|
||||
def _receive_frame(self):
|
||||
self.stream.read_bytes(2, self._on_frame_start)
|
||||
try:
|
||||
self.stream.read_bytes(2, self._on_frame_start)
|
||||
except StreamClosedError:
|
||||
self._abort()
|
||||
|
||||
def _on_frame_start(self, data):
|
||||
header, payloadlen = struct.unpack("BB", data)
|
||||
|
@ -614,34 +621,46 @@ class WebSocketProtocol13(WebSocketProtocol):
|
|||
# control frames must have payload < 126
|
||||
self._abort()
|
||||
return
|
||||
if payloadlen < 126:
|
||||
self._frame_length = payloadlen
|
||||
try:
|
||||
if payloadlen < 126:
|
||||
self._frame_length = payloadlen
|
||||
if self._masked_frame:
|
||||
self.stream.read_bytes(4, self._on_masking_key)
|
||||
else:
|
||||
self.stream.read_bytes(self._frame_length, self._on_frame_data)
|
||||
elif payloadlen == 126:
|
||||
self.stream.read_bytes(2, self._on_frame_length_16)
|
||||
elif payloadlen == 127:
|
||||
self.stream.read_bytes(8, self._on_frame_length_64)
|
||||
except StreamClosedError:
|
||||
self._abort()
|
||||
|
||||
def _on_frame_length_16(self, data):
|
||||
self._frame_length = struct.unpack("!H", data)[0]
|
||||
try:
|
||||
if self._masked_frame:
|
||||
self.stream.read_bytes(4, self._on_masking_key)
|
||||
else:
|
||||
self.stream.read_bytes(self._frame_length, self._on_frame_data)
|
||||
elif payloadlen == 126:
|
||||
self.stream.read_bytes(2, self._on_frame_length_16)
|
||||
elif payloadlen == 127:
|
||||
self.stream.read_bytes(8, self._on_frame_length_64)
|
||||
|
||||
def _on_frame_length_16(self, data):
|
||||
self._frame_length = struct.unpack("!H", data)[0]
|
||||
if self._masked_frame:
|
||||
self.stream.read_bytes(4, self._on_masking_key)
|
||||
else:
|
||||
self.stream.read_bytes(self._frame_length, self._on_frame_data)
|
||||
except StreamClosedError:
|
||||
self._abort()
|
||||
|
||||
def _on_frame_length_64(self, data):
|
||||
self._frame_length = struct.unpack("!Q", data)[0]
|
||||
if self._masked_frame:
|
||||
self.stream.read_bytes(4, self._on_masking_key)
|
||||
else:
|
||||
self.stream.read_bytes(self._frame_length, self._on_frame_data)
|
||||
try:
|
||||
if self._masked_frame:
|
||||
self.stream.read_bytes(4, self._on_masking_key)
|
||||
else:
|
||||
self.stream.read_bytes(self._frame_length, self._on_frame_data)
|
||||
except StreamClosedError:
|
||||
self._abort()
|
||||
|
||||
def _on_masking_key(self, data):
|
||||
self._frame_mask = data
|
||||
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
|
||||
try:
|
||||
self.stream.read_bytes(self._frame_length, self._on_masked_frame_data)
|
||||
except StreamClosedError:
|
||||
self._abort()
|
||||
|
||||
def _apply_mask(self, mask, data):
|
||||
mask = array.array("B", mask)
|
||||
|
|
Loading…
Reference in New Issue