def is_text_fileobj(fileobj):
    """Best-effort check for whether *fileobj* yields text rather than bytes.

    Returns ``True`` when the object has a truthy ``encoding`` attribute
    (as ``codecs.open`` and ``io.TextIOBase`` objects do) or when it has
    a ``getvalue()`` whose result is a text (unicode) string, as with
    the various ``StringIO`` flavors. Returns ``False`` otherwise.
    """
    if getattr(fileobj, 'encoding', False):
        # codecs.open and io.TextIOBase
        return True
    if getattr(fileobj, 'getvalue', False):
        # StringIO.StringIO / cStringIO.StringIO / io.StringIO
        try:
            if isinstance(fileobj.getvalue(), type(u'')):
                return True
        except Exception:
            # Deliberate best-effort probe: a getvalue() that cannot be
            # called (e.g. on a closed buffer) just means "not known to
            # be text", so fall through to False.
            pass
    return False


class MultiFileReader(object):
    """Takes a list of open files or file-like objects and provides an
    interface to read from them all contiguously. Like
    :func:`itertools.chain()`, but for reading files.

    >>> mfr = MultiFileReader(BytesIO(b'ab'), BytesIO(b'cd'), BytesIO(b'e'))
    >>> mfr.read(3).decode('ascii')
    u'abc'
    >>> mfr.read(3).decode('ascii')
    u'de'

    The constructor takes as many fileobjs as you hand it, and will
    raise a TypeError on non-file-like objects. A ValueError is raised
    when file-like objects are a mix of bytes- and text-handling
    objects (for instance, BytesIO and StringIO).
    """

    def __init__(self, *fileobjs):
        if not all([callable(getattr(f, 'read', None)) and
                    callable(getattr(f, 'seek', None)) for f in fileobjs]):
            raise TypeError('MultiFileReader expected file-like objects'
                            ' with .read() and .seek()')
        # Classify each file once; the probe may call getvalue(), which
        # copies the whole buffer, so avoid running it twice per file.
        text_flags = [is_text_fileobj(f) for f in fileobjs]
        if all(text_flags):
            # codecs.open and io.TextIOBase
            self._joiner = u''
        elif any(text_flags):
            raise ValueError('All arguments to MultiFileReader must handle'
                             ' bytes OR text, not a mix')
        else:
            # open/file and io.BytesIO
            self._joiner = b''
        self._fileobjs = fileobjs
        # Index of the file that sized reads are currently drawing from.
        self._index = 0

    def read(self, amt=None):
        """Read up to the specified *amt*, seamlessly bridging across
        files. Returns the appropriate type of string (bytes or text)
        for the input, and returns an empty string when the files are
        exhausted.

        As with regular file objects, an *amt* of ``None`` or a negative
        number reads everything that remains, and an *amt* of ``0``
        returns an empty string without consuming anything.
        """
        if amt is None or amt < 0:
            return self._joiner.join(f.read() for f in self._fileobjs)
        parts = []
        while amt > 0 and self._index < len(self._fileobjs):
            chunk = self._fileobjs[self._index].read(amt)
            if not chunk:
                # Only an empty read signals EOF. A merely short read
                # does not: some file-likes return fewer items than
                # requested while more data is still available.
                self._index += 1
                continue
            parts.append(chunk)
            amt -= len(chunk)
        return self._joiner.join(parts)

    def seek(self, offset, whence=os.SEEK_SET):
        """Enables setting position of the file cursor to a given
        *offset*. Currently only supports ``offset=0``.

        Raises NotImplementedError for any *whence* other than
        ``os.SEEK_SET`` and for any nonzero *offset*.
        """
        if whence != os.SEEK_SET:
            raise NotImplementedError(
                'MultiFileReader.seek() only supports os.SEEK_SET')
        if offset != 0:
            raise NotImplementedError(
                'MultiFileReader only supports seeking to start at this time')
        for f in self._fileobjs:
            f.seek(0)
        # Sized reads must start again from the first file; without this
        # a fully consumed reader keeps returning empty strings.
        self._index = 0
try:
    from StringIO import StringIO
except ImportError:
    # py3: there is no StringIO module, so fall back to io.StringIO
    StringIO = io.StringIO


# Path of this test module; test_open reads it as a convenient real file.
CUR_FILE_PATH = os.path.abspath(__file__)


class TestMultiFileReader(TestCase):
    """Tests for ioutils.MultiFileReader over bytes, text, and real files."""

    def test_read_seek_bytes(self):
        r = ioutils.MultiFileReader(io.BytesIO(b'narf'), io.BytesIO(b'troz'))
        self.assertEqual([b'nar', b'ftr', b'oz'],
                         list(iter(lambda: r.read(3), b'')))
        r.seek(0)
        self.assertEqual(b'narftroz', r.read())

    def test_read_seek_text(self):
        # also tests StringIO.StringIO on py2
        r = ioutils.MultiFileReader(StringIO(u'narf'),
                                    io.StringIO(u'troz'))
        self.assertEqual([u'nar', u'ftr', u'oz'],
                         list(iter(lambda: r.read(3), u'')))
        r.seek(0)
        self.assertEqual(u'narftroz', r.read())

    def test_no_mixed_bytes_and_text(self):
        self.assertRaises(ValueError, ioutils.MultiFileReader,
                          io.BytesIO(b'narf'), io.StringIO(u'troz'))

    def _assert_reads_double(self, open_file):
        """Check that reading two handles from *open_file* through a
        MultiFileReader yields the single-file contents twice over.

        Every handle is opened in a ``with`` block so it is closed even
        when an assertion fails.
        """
        with open_file() as f:
            single = f.read()
        with open_file() as f1:
            with open_file() as f2:
                doubled = ioutils.MultiFileReader(f1, f2).read()
        self.assertEqual(single * 2, doubled)

    def test_open(self):
        # text mode, binary mode, and codecs-wrapped text, in that order
        self._assert_reads_double(lambda: open(CUR_FILE_PATH, 'r'))
        self._assert_reads_double(lambda: open(CUR_FILE_PATH, 'rb'))
        self._assert_reads_double(
            lambda: codecs.open(CUR_FILE_PATH, encoding='utf8'))