mirror of https://github.com/mahmoud/boltons.git
Merge branch 'multifile' of git://github.com/durin42/boltons into durin42-multifile
This commit is contained in:
commit
c178f02aaf
|
@ -16,7 +16,7 @@ from abc import (
|
|||
abstractproperty,
|
||||
)
|
||||
from errno import EINVAL
|
||||
from io import BytesIO
|
||||
from io import BytesIO, TextIOBase
|
||||
from codecs import EncodedFile
|
||||
from tempfile import TemporaryFile
|
||||
|
||||
|
@ -404,3 +404,41 @@ class SpooledStringIO(SpooledIOBase):
|
|||
total += len(ret)
|
||||
self.buffer.seek(pos)
|
||||
return total
|
||||
|
||||
|
||||
class MultiFileReader(object):
|
||||
|
||||
def __init__(self, *fileobjs):
|
||||
if all(isinstance(f, TextIOBase) for f in fileobjs):
|
||||
self._joiner = ''
|
||||
elif any(isinstance(f, TextIOBase) for f in fileobjs):
|
||||
raise ValueError('All arguments to MultiFileReader must be either '
|
||||
'bytes IO or text IO, not a mix')
|
||||
else:
|
||||
self._joiner = b''
|
||||
self._fileobjs = fileobjs
|
||||
self._index = 0
|
||||
|
||||
def read(self, amt=None):
|
||||
if not amt:
|
||||
return self._joiner.join(f.read() for f in self._fileobjs)
|
||||
parts = []
|
||||
while amt > 0 and self._index < len(self._fileobjs):
|
||||
parts.append(self._fileobjs[self._index].read(amt))
|
||||
got = len(parts[-1])
|
||||
if got < amt:
|
||||
self._index += 1
|
||||
amt -= got
|
||||
return self._joiner.join(parts)
|
||||
|
||||
def seek(self, offset, whence=os.SEEK_SET):
|
||||
if whence != os.SEEK_SET:
|
||||
raise NotImplementedError(
|
||||
'fileprepender does not support anything other'
|
||||
' than os.SEEK_SET for whence on seek()')
|
||||
if offset != 0:
|
||||
raise NotImplementedError(
|
||||
'fileprepender only supports seeking to start, but that '
|
||||
'could be fixed if you need it')
|
||||
for f in self._fileobjs:
|
||||
f.seek(0)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import io
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
|
@ -391,3 +392,23 @@ class TestSpooledStringIO(TestCase, BaseTestMixin, AssertionsMixin):
|
|||
self.spooled_flo.write(test_str)
|
||||
self.spooled_flo.seek(0)
|
||||
self.assertEqual(self.spooled_flo.read(3), test_str)
|
||||
|
||||
|
||||
class TestMultiFileReader(TestCase):
|
||||
def test_read_seek_bytes(self):
|
||||
r = ioutils.MultiFileReader(io.BytesIO(b'narf'), io.BytesIO(b'troz'))
|
||||
self.assertEqual([b'nar', b'ftr', b'oz'],
|
||||
list(iter(lambda: r.read(3), b'')))
|
||||
r.seek(0)
|
||||
self.assertEqual(b'narftroz', r.read())
|
||||
|
||||
def test_read_seek_text(self):
|
||||
r = ioutils.MultiFileReader(io.StringIO(u'narf'), io.StringIO(u'troz'))
|
||||
self.assertEqual([u'nar', u'ftr', u'oz'],
|
||||
list(iter(lambda: r.read(3), u'')))
|
||||
r.seek(0)
|
||||
self.assertEqual(u'narftroz', r.read())
|
||||
|
||||
def test_no_mixed_bytes_and_text(self):
|
||||
with self.assertRaises(ValueError):
|
||||
ioutils.MultiFileReader(io.BytesIO(b'narf'), io.StringIO(u'troz'))
|
||||
|
|
Loading…
Reference in New Issue