Add maximum message size checks. Closes #151.
This commit is contained in:
parent
e1af2db4ae
commit
1ff27ada49
|
@ -134,7 +134,7 @@ class MuxProcess(object):
|
|||
"""
|
||||
Construct a Router, Broker, and mitogen.unix listener
|
||||
"""
|
||||
self.router = mitogen.master.Router()
|
||||
self.router = mitogen.master.Router(max_message_size=4096*1048576)
|
||||
self.router.responder.whitelist_prefix('ansible')
|
||||
self.router.responder.whitelist_prefix('ansible_mitogen')
|
||||
mitogen.core.listen(self.router.broker, 'shutdown', self.on_broker_shutdown)
|
||||
|
|
|
@ -812,6 +812,12 @@ class Stream(BasicStream):
|
|||
self._input_buf[0][:self.HEADER_LEN],
|
||||
)
|
||||
|
||||
if msg_len > self._router.max_message_size:
|
||||
LOG.error('Maximum message size exceeded (got %d, max %d)',
|
||||
msg_len, self._router.max_message_size)
|
||||
self.on_disconnect(broker)
|
||||
return False
|
||||
|
||||
total_len = msg_len + self.HEADER_LEN
|
||||
if self._input_buf_len < total_len:
|
||||
_vv and IOLOG.debug(
|
||||
|
@ -1191,6 +1197,7 @@ class IoLogger(BasicStream):
|
|||
|
||||
class Router(object):
|
||||
context_class = Context
|
||||
max_message_size = 128 * 1048576
|
||||
|
||||
def __init__(self, broker):
|
||||
self.broker = broker
|
||||
|
@ -1274,6 +1281,11 @@ class Router(object):
|
|||
|
||||
def _async_route(self, msg, stream=None):
|
||||
_vv and IOLOG.debug('%r._async_route(%r, %r)', self, msg, stream)
|
||||
if len(msg.data) > self.max_message_size:
|
||||
LOG.error('message too large (max %d bytes): %r',
|
||||
self.max_message_size, msg)
|
||||
return
|
||||
|
||||
# Perform source verification.
|
||||
if stream is not None:
|
||||
expected_stream = self._stream_by_id.get(msg.auth_id,
|
||||
|
@ -1438,7 +1450,9 @@ class ExternalContext(object):
|
|||
_v and LOG.debug('%r: parent stream is gone, dying.', self)
|
||||
self.broker.shutdown()
|
||||
|
||||
def _setup_master(self, profiling, parent_id, context_id, in_fd, out_fd):
|
||||
def _setup_master(self, max_message_size, profiling, parent_id,
|
||||
context_id, in_fd, out_fd):
|
||||
Router.max_message_size = max_message_size
|
||||
self.profiling = profiling
|
||||
if profiling:
|
||||
enable_profiling()
|
||||
|
@ -1571,9 +1585,11 @@ class ExternalContext(object):
|
|||
self.dispatch_stopped = True
|
||||
|
||||
def main(self, parent_ids, context_id, debug, profiling, log_level,
|
||||
in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True,
|
||||
setup_package=True, importer=None, whitelist=(), blacklist=()):
|
||||
self._setup_master(profiling, parent_ids[0], context_id, in_fd, out_fd)
|
||||
max_message_size, in_fd=100, out_fd=1, core_src_fd=101,
|
||||
setup_stdio=True, setup_package=True, importer=None,
|
||||
whitelist=(), blacklist=()):
|
||||
self._setup_master(max_message_size, profiling, parent_ids[0],
|
||||
context_id, in_fd, out_fd)
|
||||
try:
|
||||
try:
|
||||
self._setup_logging(debug, log_level)
|
||||
|
|
|
@ -343,14 +343,15 @@ def run(dest, router, args, deadline=None, econtext=None):
|
|||
fp.write(inspect.getsource(mitogen.core))
|
||||
fp.write('\n')
|
||||
fp.write('ExternalContext().main(**%r)\n' % ({
|
||||
'parent_ids': parent_ids,
|
||||
'context_id': context_id,
|
||||
'debug': getattr(router, 'debug', False),
|
||||
'profiling': getattr(router, 'profiling', False),
|
||||
'log_level': mitogen.parent.get_log_level(),
|
||||
'in_fd': sock2.fileno(),
|
||||
'out_fd': sock2.fileno(),
|
||||
'core_src_fd': None,
|
||||
'debug': getattr(router, 'debug', False),
|
||||
'in_fd': sock2.fileno(),
|
||||
'log_level': mitogen.parent.get_log_level(),
|
||||
'max_message_size': router.max_message_size,
|
||||
'out_fd': sock2.fileno(),
|
||||
'parent_ids': parent_ids,
|
||||
'profiling': getattr(router, 'profiling', False),
|
||||
'setup_stdio': False,
|
||||
},))
|
||||
finally:
|
||||
|
|
|
@ -85,9 +85,11 @@ class Stream(mitogen.parent.Stream):
|
|||
#: User-supplied function for cleaning up child process state.
|
||||
on_fork = None
|
||||
|
||||
def construct(self, old_router, on_fork=None, debug=False, profiling=False):
|
||||
def construct(self, old_router, max_message_size, on_fork=None,
|
||||
debug=False, profiling=False):
|
||||
# fork method only supports a tiny subset of options.
|
||||
super(Stream, self).construct(debug=debug, profiling=profiling)
|
||||
super(Stream, self).construct(max_message_size=max_message_size,
|
||||
debug=debug, profiling=profiling)
|
||||
self.on_fork = on_fork
|
||||
|
||||
responder = getattr(old_router, 'responder', None)
|
||||
|
|
|
@ -646,9 +646,11 @@ class Router(mitogen.parent.Router):
|
|||
debug = False
|
||||
profiling = False
|
||||
|
||||
def __init__(self, broker=None):
|
||||
def __init__(self, broker=None, max_message_size=None):
|
||||
if broker is None:
|
||||
broker = self.broker_class()
|
||||
if max_message_size:
|
||||
self.max_message_size = max_message_size
|
||||
super(Router, self).__init__(broker)
|
||||
self.upgrade()
|
||||
|
||||
|
|
|
@ -337,6 +337,10 @@ class Stream(mitogen.core.Stream):
|
|||
#: Set to the child's PID by connect().
|
||||
pid = None
|
||||
|
||||
#: Passed via Router wrapper methods, must eventually be passed to
|
||||
#: ExternalContext.main().
|
||||
max_message_size = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Stream, self).__init__(*args, **kwargs)
|
||||
self.sent_modules = set(['mitogen', 'mitogen.core'])
|
||||
|
@ -344,12 +348,13 @@ class Stream(mitogen.core.Stream):
|
|||
#: during disconnection.
|
||||
self.routes = set([self.remote_id])
|
||||
|
||||
def construct(self, remote_name=None, python_path=None, debug=False,
|
||||
connect_timeout=None, profiling=False,
|
||||
def construct(self, max_message_size, remote_name=None, python_path=None,
|
||||
debug=False, connect_timeout=None, profiling=False,
|
||||
old_router=None, **kwargs):
|
||||
"""Get the named context running on the local machine, creating it if
|
||||
it does not exist."""
|
||||
super(Stream, self).construct(**kwargs)
|
||||
self.max_message_size = max_message_size
|
||||
if python_path:
|
||||
self.python_path = python_path
|
||||
if sys.platform == 'darwin' and self.python_path == '/usr/bin/python':
|
||||
|
@ -367,6 +372,7 @@ class Stream(mitogen.core.Stream):
|
|||
self.remote_name = remote_name
|
||||
self.debug = debug
|
||||
self.profiling = profiling
|
||||
self.max_message_size = max_message_size
|
||||
self.connect_deadline = time.time() + self.connect_timeout
|
||||
|
||||
def on_shutdown(self, broker):
|
||||
|
@ -441,6 +447,7 @@ class Stream(mitogen.core.Stream):
|
|||
]
|
||||
|
||||
def get_main_kwargs(self):
|
||||
assert self.max_message_size is not None
|
||||
parent_ids = mitogen.parent_ids[:]
|
||||
parent_ids.insert(0, mitogen.context_id)
|
||||
return {
|
||||
|
@ -451,6 +458,7 @@ class Stream(mitogen.core.Stream):
|
|||
'log_level': get_log_level(),
|
||||
'whitelist': self._router.get_module_whitelist(),
|
||||
'blacklist': self._router.get_module_blacklist(),
|
||||
'max_message_size': self.max_message_size,
|
||||
}
|
||||
|
||||
def get_preamble(self):
|
||||
|
@ -703,7 +711,9 @@ class Router(mitogen.core.Router):
|
|||
def _connect(self, klass, name=None, **kwargs):
|
||||
context_id = self.allocate_id()
|
||||
context = self.context_class(self, context_id)
|
||||
stream = klass(self, context_id, old_router=self, **kwargs)
|
||||
kwargs['old_router'] = self
|
||||
kwargs['max_message_size'] = self.max_message_size
|
||||
stream = klass(self, context_id, **kwargs)
|
||||
if name is not None:
|
||||
stream.name = name
|
||||
stream.connect()
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import Queue
|
||||
import StringIO
|
||||
import logging
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
|
@ -8,7 +10,16 @@ import testlib
|
|||
import mitogen.master
|
||||
import mitogen.utils
|
||||
|
||||
mitogen.utils.log_to_file()
|
||||
|
||||
@mitogen.core.takes_router
|
||||
def return_router_max_message_size(router):
|
||||
return router.max_message_size
|
||||
|
||||
|
||||
def send_n_sized_reply(sender, n):
|
||||
sender.send(' ' * n)
|
||||
return 123
|
||||
|
||||
|
||||
class AddHandlerTest(unittest2.TestCase):
|
||||
klass = mitogen.master.Router
|
||||
|
@ -21,6 +32,44 @@ class AddHandlerTest(unittest2.TestCase):
|
|||
self.assertEquals(queue.get(timeout=5), mitogen.core._DEAD)
|
||||
|
||||
|
||||
class MessageSizeTest(testlib.BrokerMixin, unittest2.TestCase):
|
||||
klass = mitogen.master.Router
|
||||
|
||||
def test_local_exceeded(self):
|
||||
router = self.klass(broker=self.broker, max_message_size=4096)
|
||||
recv = mitogen.core.Receiver(router)
|
||||
|
||||
logs = testlib.LogCapturer()
|
||||
logs.start()
|
||||
|
||||
sem = mitogen.core.Latch()
|
||||
router.route(mitogen.core.Message.pickled(' '*8192))
|
||||
router.broker.defer(sem.put, ' ') # wlil always run after _async_route
|
||||
sem.get()
|
||||
|
||||
expect = 'message too large (max 4096 bytes)'
|
||||
self.assertTrue(expect in logs.stop())
|
||||
|
||||
def test_remote_configured(self):
|
||||
router = self.klass(broker=self.broker, max_message_size=4096)
|
||||
remote = router.fork()
|
||||
size = remote.call(return_router_max_message_size)
|
||||
self.assertEquals(size, 4096)
|
||||
|
||||
def test_remote_exceeded(self):
|
||||
# Ensure new contexts receive a router with the same value.
|
||||
router = self.klass(broker=self.broker, max_message_size=4096)
|
||||
recv = mitogen.core.Receiver(router)
|
||||
|
||||
logs = testlib.LogCapturer()
|
||||
logs.start()
|
||||
|
||||
remote = router.fork()
|
||||
remote.call(send_n_sized_reply, recv.to_sender(), 8192)
|
||||
|
||||
expect = 'message too large (max 4096 bytes)'
|
||||
self.assertTrue(expect in logs.stop())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest2.main()
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
|
||||
import StringIO
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
|
@ -113,6 +115,24 @@ def wait_for_port(
|
|||
% (host, port))
|
||||
|
||||
|
||||
class LogCapturer(object):
|
||||
def __init__(self, name=None):
|
||||
self.sio = StringIO.StringIO()
|
||||
self.logger = logging.getLogger(name)
|
||||
self.handler = logging.StreamHandler(self.sio)
|
||||
self.old_propagate = self.logger.propagate
|
||||
self.old_handlers = self.logger.handlers
|
||||
|
||||
def start(self):
|
||||
self.logger.handlers = [self.handler]
|
||||
self.logger.propagate = False
|
||||
|
||||
def stop(self):
|
||||
self.logger.handlers = self.old_handlers
|
||||
self.logger.propagate = self.old_propagate
|
||||
return self.sio.getvalue()
|
||||
|
||||
|
||||
class TestCase(unittest2.TestCase):
|
||||
def assertRaises(self, exc, func, *args, **kwargs):
|
||||
"""Like regular assertRaises, except return the exception that was
|
||||
|
@ -156,19 +176,25 @@ class DockerizedSshDaemon(object):
|
|||
self.container.remove()
|
||||
|
||||
|
||||
class RouterMixin(object):
|
||||
class BrokerMixin(object):
|
||||
broker_class = mitogen.master.Broker
|
||||
router_class = mitogen.master.Router
|
||||
|
||||
def setUp(self):
|
||||
super(RouterMixin, self).setUp()
|
||||
super(BrokerMixin, self).setUp()
|
||||
self.broker = self.broker_class()
|
||||
self.router = self.router_class(self.broker)
|
||||
|
||||
def tearDown(self):
|
||||
self.broker.shutdown()
|
||||
self.broker.join()
|
||||
super(RouterMixin, self).tearDown()
|
||||
super(BrokerMixin, self).tearDown()
|
||||
|
||||
|
||||
class RouterMixin(BrokerMixin):
|
||||
router_class = mitogen.master.Router
|
||||
|
||||
def setUp(self):
|
||||
super(RouterMixin, self).setUp()
|
||||
self.router = self.router_class(self.broker)
|
||||
|
||||
|
||||
class DockerMixin(RouterMixin):
|
||||
|
|
Loading…
Reference in New Issue