diff --git a/ansible_mitogen/connection.py b/ansible_mitogen/connection.py index 5775bbed..a8bb74c7 100644 --- a/ansible_mitogen/connection.py +++ b/ansible_mitogen/connection.py @@ -29,6 +29,7 @@ from __future__ import absolute_import from __future__ import unicode_literals +import errno import logging import os import pprint @@ -1007,6 +1008,11 @@ class Connection(ansible.plugins.connection.ConnectionBase): #: slightly more overhead, so just randomly subtract 4KiB. SMALL_FILE_LIMIT = mitogen.core.CHUNK_SIZE - 4096 + def _throw_io_error(self, e, path): + if e.args[0] == errno.ENOENT: + s = 'file or module does not exist: ' + path + raise ansible.errors.AnsibleFileNotFound(s) + def put_file(self, in_path, out_path): """ Implement put_file() by streamily transferring the file via @@ -1017,7 +1023,12 @@ class Connection(ansible.plugins.connection.ConnectionBase): :param str out_path: Remote filesystem path to write. """ - st = os.stat(in_path) + try: + st = os.stat(in_path) + except OSError as e: + self._throw_io_error(e, in_path) + raise + if not stat.S_ISREG(st.st_mode): raise IOError('%r is not a regular file.' % (in_path,)) @@ -1025,17 +1036,22 @@ class Connection(ansible.plugins.connection.ConnectionBase): # rather than introducing an extra RTT for the child to request it from # FileService. if st.st_size <= self.SMALL_FILE_LIMIT: - fp = open(in_path, 'rb') try: - s = fp.read(self.SMALL_FILE_LIMIT + 1) - finally: - fp.close() + fp = open(in_path, 'rb') + try: + s = fp.read(self.SMALL_FILE_LIMIT + 1) + finally: + fp.close() + except OSError: + self._throw_io_error(e, in_path) + raise # Ensure did not grow during read. if len(s) == st.st_size: return self.put_data(out_path, s, mode=st.st_mode, utimes=(st.st_atime, st.st_mtime)) + self._connect() self.parent.call_service( service_name='mitogen.service.FileService', method_name='register', diff --git a/ansible_mitogen/process.py b/ansible_mitogen/process.py index 6e18a863..219d78a5 100644 --- a/ansible_mitogen/process.py +++ b/ansible_mitogen/process.py @@ -154,13 +154,16 @@ class MuxProcess(object): _instance = None @classmethod - def start(cls): + def start(cls, _init_logging=True): """ Arrange for the subprocess to be started, if it is not already running. The parent process picks a UNIX socket path the child will use prior to fork, creates a socketpair used essentially as a semaphore, then blocks waiting for the child to indicate the UNIX socket is ready for use. + + :param bool _init_logging: + For testing, if :data:`False`, don't initialize logging. """ if cls.worker_sock is not None: return @@ -180,7 +183,8 @@ class MuxProcess(object): cls.original_env = dict(os.environ) cls.child_pid = os.fork() - ansible_mitogen.logging.setup() + if _init_logging: + ansible_mitogen.logging.setup() if cls.child_pid: cls.child_sock.close() cls.child_sock = None diff --git a/tests/ansible/tests/connection_test.py b/tests/ansible/tests/connection_test.py index 33b60695..aaf4bf42 100644 --- a/tests/ansible/tests/connection_test.py +++ b/tests/ansible/tests/connection_test.py @@ -1,19 +1,33 @@ from __future__ import absolute_import +import os import os.path import subprocess import tempfile +import time + import unittest2 import mock +import ansible.errors +import ansible.playbook.play_context +import mitogen.core import ansible_mitogen.connection +import ansible_mitogen.plugins.connection.mitogen_local +import ansible_mitogen.process import testlib LOGGER_NAME = ansible_mitogen.target.LOG.name +# TODO: fixtureize +import mitogen.utils +mitogen.utils.log_to_file() +ansible_mitogen.process.MuxProcess.start(_init_logging=False) + + class OptionalIntTest(unittest2.TestCase): func = staticmethod(ansible_mitogen.connection.optional_int) @@ -34,5 +48,84 @@ class OptionalIntTest(unittest2.TestCase): self.assertEquals(None, self.func({1:2})) +class ConnectionMixin(object): + klass = ansible_mitogen.plugins.connection.mitogen_local.Connection + + def make_connection(self): + play_context = ansible.playbook.play_context.PlayContext() + return self.klass(play_context, new_stdin=False) + + def wait_for_completion(self): + # put_data() is asynchronous, must wait for operation to happen. Do + # that by making RPC for some junk that must run on the thread once op + # completes. + self.conn.get_chain().call(os.getpid) + + def setUp(self): + super(ConnectionMixin, self).setUp() + self.conn = self.make_connection() + + def tearDown(self): + self.conn.close() + super(ConnectionMixin, self).tearDown() + + +class PutDataTest(ConnectionMixin, unittest2.TestCase): + def test_out_path(self): + path = tempfile.mktemp(prefix='mitotest') + contents = mitogen.core.b('contents') + + self.conn.put_data(path, contents) + self.wait_for_completion() + self.assertEquals(contents, open(path, 'rb').read()) + os.unlink(path) + + def test_mode(self): + path = tempfile.mktemp(prefix='mitotest') + contents = mitogen.core.b('contents') + + self.conn.put_data(path, contents, mode=int('0123', 8)) + self.wait_for_completion() + st = os.stat(path) + self.assertEquals(int('0123', 8), st.st_mode & int('0777', 8)) + os.unlink(path) + + +class PutFileTest(ConnectionMixin, unittest2.TestCase): + @classmethod + def setUpClass(cls): + super(PutFileTest, cls).setUpClass() + cls.big_path = tempfile.mktemp(prefix='mitotestbig') + open(cls.big_path, 'w').write('x'*1048576) + + @classmethod + def tearDownClass(cls): + os.unlink(cls.big_path) + super(PutFileTest, cls).tearDownClass() + + def test_out_path_tiny(self): + path = tempfile.mktemp(prefix='mitotest') + self.conn.put_file(in_path=__file__, out_path=path) + self.wait_for_completion() + self.assertEquals(open(path, 'rb').read(), + open(__file__, 'rb').read()) + + os.unlink(path) + + def test_out_path_big(self): + path = tempfile.mktemp(prefix='mitotest') + self.conn.put_file(in_path=self.big_path, out_path=path) + self.wait_for_completion() + self.assertEquals(open(path, 'rb').read(), + open(self.big_path, 'rb').read()) + #self._compare_times_modes(path, __file__) + os.unlink(path) + + def test_big_in_path_not_found(self): + path = tempfile.mktemp(prefix='mitotest') + self.assertRaises(ansible.errors.AnsibleFileNotFound, + lambda: self.conn.put_file(in_path='/nonexistent', out_path=path)) + + if __name__ == '__main__': unittest2.main()