diff --git a/ansible_mitogen/process.py b/ansible_mitogen/process.py index 1880f769..eb9bd2ff 100644 --- a/ansible_mitogen/process.py +++ b/ansible_mitogen/process.py @@ -134,7 +134,7 @@ class MuxProcess(object): """ Construct a Router, Broker, and mitogen.unix listener """ - self.router = mitogen.master.Router() + self.router = mitogen.master.Router(max_message_size=4096*1048576) self.router.responder.whitelist_prefix('ansible') self.router.responder.whitelist_prefix('ansible_mitogen') mitogen.core.listen(self.router.broker, 'shutdown', self.on_broker_shutdown) diff --git a/mitogen/core.py b/mitogen/core.py index 01d343d6..29e04a2a 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -812,6 +812,12 @@ class Stream(BasicStream): self._input_buf[0][:self.HEADER_LEN], ) + if msg_len > self._router.max_message_size: + LOG.error('Maximum message size exceeded (got %d, max %d)', + msg_len, self._router.max_message_size) + self.on_disconnect(broker) + return False + total_len = msg_len + self.HEADER_LEN if self._input_buf_len < total_len: _vv and IOLOG.debug( @@ -1191,6 +1197,7 @@ class IoLogger(BasicStream): class Router(object): context_class = Context + max_message_size = 128 * 1048576 def __init__(self, broker): self.broker = broker @@ -1274,6 +1281,11 @@ class Router(object): def _async_route(self, msg, stream=None): _vv and IOLOG.debug('%r._async_route(%r, %r)', self, msg, stream) + if len(msg.data) > self.max_message_size: + LOG.error('message too large (max %d bytes): %r', + self.max_message_size, msg) + return + # Perform source verification. if stream is not None: expected_stream = self._stream_by_id.get(msg.auth_id, @@ -1438,7 +1450,9 @@ class ExternalContext(object): _v and LOG.debug('%r: parent stream is gone, dying.', self) self.broker.shutdown() - def _setup_master(self, profiling, parent_id, context_id, in_fd, out_fd): + def _setup_master(self, max_message_size, profiling, parent_id, + context_id, in_fd, out_fd): + Router.max_message_size = max_message_size self.profiling = profiling if profiling: enable_profiling() @@ -1571,9 +1585,11 @@ class ExternalContext(object): self.dispatch_stopped = True def main(self, parent_ids, context_id, debug, profiling, log_level, - in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True, - setup_package=True, importer=None, whitelist=(), blacklist=()): - self._setup_master(profiling, parent_ids[0], context_id, in_fd, out_fd) + max_message_size, in_fd=100, out_fd=1, core_src_fd=101, + setup_stdio=True, setup_package=True, importer=None, + whitelist=(), blacklist=()): + self._setup_master(max_message_size, profiling, parent_ids[0], + context_id, in_fd, out_fd) try: try: self._setup_logging(debug, log_level) diff --git a/mitogen/fakessh.py b/mitogen/fakessh.py index f5dcbe1c..e07916ad 100644 --- a/mitogen/fakessh.py +++ b/mitogen/fakessh.py @@ -343,14 +343,15 @@ def run(dest, router, args, deadline=None, econtext=None): fp.write(inspect.getsource(mitogen.core)) fp.write('\n') fp.write('ExternalContext().main(**%r)\n' % ({ - 'parent_ids': parent_ids, 'context_id': context_id, - 'debug': getattr(router, 'debug', False), - 'profiling': getattr(router, 'profiling', False), - 'log_level': mitogen.parent.get_log_level(), - 'in_fd': sock2.fileno(), - 'out_fd': sock2.fileno(), 'core_src_fd': None, + 'debug': getattr(router, 'debug', False), + 'in_fd': sock2.fileno(), + 'log_level': mitogen.parent.get_log_level(), + 'max_message_size': router.max_message_size, + 'out_fd': sock2.fileno(), + 'parent_ids': parent_ids, + 'profiling': getattr(router, 'profiling', False), 'setup_stdio': False, },)) finally: diff --git a/mitogen/fork.py b/mitogen/fork.py index 36e4f0ca..e4a8625a 100644 --- a/mitogen/fork.py +++ b/mitogen/fork.py @@ -85,9 +85,11 @@ class Stream(mitogen.parent.Stream): #: User-supplied function for cleaning up child process state. on_fork = None - def construct(self, old_router, on_fork=None, debug=False, profiling=False): + def construct(self, old_router, max_message_size, on_fork=None, + debug=False, profiling=False): # fork method only supports a tiny subset of options. - super(Stream, self).construct(debug=debug, profiling=profiling) + super(Stream, self).construct(max_message_size=max_message_size, + debug=debug, profiling=profiling) self.on_fork = on_fork responder = getattr(old_router, 'responder', None) diff --git a/mitogen/master.py b/mitogen/master.py index 4359a732..0cf5d451 100644 --- a/mitogen/master.py +++ b/mitogen/master.py @@ -646,9 +646,11 @@ class Router(mitogen.parent.Router): debug = False profiling = False - def __init__(self, broker=None): + def __init__(self, broker=None, max_message_size=None): if broker is None: broker = self.broker_class() + if max_message_size: + self.max_message_size = max_message_size super(Router, self).__init__(broker) self.upgrade() diff --git a/mitogen/parent.py b/mitogen/parent.py index 8a9a186a..599a4603 100644 --- a/mitogen/parent.py +++ b/mitogen/parent.py @@ -337,6 +337,10 @@ class Stream(mitogen.core.Stream): #: Set to the child's PID by connect(). pid = None + #: Passed via Router wrapper methods, must eventually be passed to + #: ExternalContext.main(). + max_message_size = None + def __init__(self, *args, **kwargs): super(Stream, self).__init__(*args, **kwargs) self.sent_modules = set(['mitogen', 'mitogen.core']) @@ -344,12 +348,13 @@ class Stream(mitogen.core.Stream): #: during disconnection. self.routes = set([self.remote_id]) - def construct(self, remote_name=None, python_path=None, debug=False, - connect_timeout=None, profiling=False, + def construct(self, max_message_size, remote_name=None, python_path=None, + debug=False, connect_timeout=None, profiling=False, old_router=None, **kwargs): """Get the named context running on the local machine, creating it if it does not exist.""" super(Stream, self).construct(**kwargs) + self.max_message_size = max_message_size if python_path: self.python_path = python_path if sys.platform == 'darwin' and self.python_path == '/usr/bin/python': @@ -367,6 +372,7 @@ class Stream(mitogen.core.Stream): self.remote_name = remote_name self.debug = debug self.profiling = profiling + self.max_message_size = max_message_size self.connect_deadline = time.time() + self.connect_timeout def on_shutdown(self, broker): @@ -441,6 +447,7 @@ class Stream(mitogen.core.Stream): ] def get_main_kwargs(self): + assert self.max_message_size is not None parent_ids = mitogen.parent_ids[:] parent_ids.insert(0, mitogen.context_id) return { @@ -451,6 +458,7 @@ class Stream(mitogen.core.Stream): 'log_level': get_log_level(), 'whitelist': self._router.get_module_whitelist(), 'blacklist': self._router.get_module_blacklist(), + 'max_message_size': self.max_message_size, } def get_preamble(self): @@ -703,7 +711,9 @@ class Router(mitogen.core.Router): def _connect(self, klass, name=None, **kwargs): context_id = self.allocate_id() context = self.context_class(self, context_id) - stream = klass(self, context_id, old_router=self, **kwargs) + kwargs['old_router'] = self + kwargs['max_message_size'] = self.max_message_size + stream = klass(self, context_id, **kwargs) if name is not None: stream.name = name stream.connect() diff --git a/tests/router_test.py b/tests/router_test.py index 3f460c3f..c6b4e2df 100644 --- a/tests/router_test.py +++ b/tests/router_test.py @@ -1,4 +1,6 @@ import Queue +import StringIO +import logging import subprocess import time @@ -8,7 +10,16 @@ import testlib import mitogen.master import mitogen.utils -mitogen.utils.log_to_file() + +@mitogen.core.takes_router +def return_router_max_message_size(router): + return router.max_message_size + + +def send_n_sized_reply(sender, n): + sender.send(' ' * n) + return 123 + class AddHandlerTest(unittest2.TestCase): klass = mitogen.master.Router @@ -21,6 +32,44 @@ class AddHandlerTest(unittest2.TestCase): self.assertEquals(queue.get(timeout=5), mitogen.core._DEAD) +class MessageSizeTest(testlib.BrokerMixin, unittest2.TestCase): + klass = mitogen.master.Router + + def test_local_exceeded(self): + router = self.klass(broker=self.broker, max_message_size=4096) + recv = mitogen.core.Receiver(router) + + logs = testlib.LogCapturer() + logs.start() + + sem = mitogen.core.Latch() + router.route(mitogen.core.Message.pickled(' '*8192)) + router.broker.defer(sem.put, ' ') # wlil always run after _async_route + sem.get() + + expect = 'message too large (max 4096 bytes)' + self.assertTrue(expect in logs.stop()) + + def test_remote_configured(self): + router = self.klass(broker=self.broker, max_message_size=4096) + remote = router.fork() + size = remote.call(return_router_max_message_size) + self.assertEquals(size, 4096) + + def test_remote_exceeded(self): + # Ensure new contexts receive a router with the same value. + router = self.klass(broker=self.broker, max_message_size=4096) + recv = mitogen.core.Receiver(router) + + logs = testlib.LogCapturer() + logs.start() + + remote = router.fork() + remote.call(send_n_sized_reply, recv.to_sender(), 8192) + + expect = 'message too large (max 4096 bytes)' + self.assertTrue(expect in logs.stop()) + + if __name__ == '__main__': unittest2.main() - diff --git a/tests/testlib.py b/tests/testlib.py index fd41298b..5d959b8d 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -1,4 +1,6 @@ +import StringIO +import logging import os import random import re @@ -113,6 +115,24 @@ def wait_for_port( % (host, port)) +class LogCapturer(object): + def __init__(self, name=None): + self.sio = StringIO.StringIO() + self.logger = logging.getLogger(name) + self.handler = logging.StreamHandler(self.sio) + self.old_propagate = self.logger.propagate + self.old_handlers = self.logger.handlers + + def start(self): + self.logger.handlers = [self.handler] + self.logger.propagate = False + + def stop(self): + self.logger.handlers = self.old_handlers + self.logger.propagate = self.old_propagate + return self.sio.getvalue() + + class TestCase(unittest2.TestCase): def assertRaises(self, exc, func, *args, **kwargs): """Like regular assertRaises, except return the exception that was @@ -156,19 +176,25 @@ class DockerizedSshDaemon(object): self.container.remove() -class RouterMixin(object): +class BrokerMixin(object): broker_class = mitogen.master.Broker - router_class = mitogen.master.Router def setUp(self): - super(RouterMixin, self).setUp() + super(BrokerMixin, self).setUp() self.broker = self.broker_class() - self.router = self.router_class(self.broker) def tearDown(self): self.broker.shutdown() self.broker.join() - super(RouterMixin, self).tearDown() + super(BrokerMixin, self).tearDown() + + +class RouterMixin(BrokerMixin): + router_class = mitogen.master.Router + + def setUp(self): + super(RouterMixin, self).setUp() + self.router = self.router_class(self.broker) class DockerMixin(RouterMixin):