import os
from io import TextIOBase


class MultiFileReader(object):
    """Read several file-like objects as one concatenated, read-only stream.

    The wrapped objects are consumed in the order given.  They must all be
    of the same kind: either every object is an ``io.TextIOBase`` instance
    (reads return text) or none of them is (reads return bytes).

    Args:
        *fileobjs: Readable file-like objects supporting ``read([amt])``
            and, if :meth:`seek` is going to be used, ``seek(0)``.

    Raises:
        ValueError: If text and non-text file objects are mixed.
    """

    def __init__(self, *fileobjs):
        if all(isinstance(f, TextIOBase) for f in fileobjs):
            # Note: all() of an empty sequence is True, so a reader built
            # from no file objects behaves as an empty text stream.
            self._joiner = ''
        elif any(isinstance(f, TextIOBase) for f in fileobjs):
            raise ValueError('All arguments to MultiFileReader must be either '
                             'bytes IO or text IO, not a mix')
        else:
            self._joiner = b''
        self._fileobjs = fileobjs
        # Position in ``_fileobjs`` of the file currently being drained by
        # sized reads.
        self._index = 0

    def read(self, amt=None):
        """Read up to *amt* units (bytes or characters) across the files.

        Args:
            amt: Maximum amount to read.  When omitted, ``None`` or ``0``
                (any falsy value), whatever remains in every wrapped file
                is read and returned joined together.

        Returns:
            The data read, as text or bytes depending on the wrapped
            files; an empty value once all files are exhausted.
        """
        if not amt:
            return self._joiner.join(f.read() for f in self._fileobjs)
        parts = []
        while amt > 0 and self._index < len(self._fileobjs):
            chunk = self._fileobjs[self._index].read(amt)
            if not chunk:
                # Only an empty read means end-of-file.  A merely short
                # read (allowed for raw/unbuffered streams) must not make
                # us abandon the rest of the current file.
                self._index += 1
                continue
            parts.append(chunk)
            amt -= len(chunk)
        return self._joiner.join(parts)

    def seek(self, offset, whence=os.SEEK_SET):
        """Rewind the reader to the very beginning.

        Only ``seek(0)`` / ``seek(0, os.SEEK_SET)`` is supported.

        Args:
            offset: Must be ``0``.
            whence: Must be ``os.SEEK_SET``.

        Raises:
            NotImplementedError: For any other *offset* or *whence*.
        """
        if whence != os.SEEK_SET:
            raise NotImplementedError(
                'MultiFileReader does not support anything other'
                ' than os.SEEK_SET for whence on seek()')
        if offset != 0:
            raise NotImplementedError(
                'MultiFileReader only supports seeking to start, but that '
                'could be fixed if you need it')
        for f in self._fileobjs:
            f.seek(0)
        # Sized reads resume from ``_index``; without this reset a reader
        # that had been drained would keep returning empty after seek(0).
        self._index = 0