diff --git a/boltons/ioutils.py b/boltons/ioutils.py index 0fdc8e1..5c22612 100644 --- a/boltons/ioutils.py +++ b/boltons/ioutils.py @@ -10,13 +10,13 @@ ways. """ import os import sys +from io import BytesIO from abc import ( ABCMeta, abstractmethod, abstractproperty, ) from errno import EINVAL -from io import BytesIO, TextIOBase from codecs import EncodedFile from tempfile import TemporaryFile @@ -406,6 +406,20 @@ class SpooledStringIO(SpooledIOBase): return total +def is_text_fileobj(fileobj): + if hasattr(fileobj, 'encoding'): + # codecs.open and io.TextIOBase + return True + if hasattr(fileobj, 'getvalue'): + # StringIO.StringIO / cStringIO.StringIO / io.StringIO + try: + if isinstance(fileobj.getvalue(), type(u'')): + return True + except Exception: + pass + return False + + class MultiFileReader(object): def __init__(self, *fileobjs): @@ -413,13 +427,14 @@ class MultiFileReader(object): callable(getattr(f, 'seek', None)) for f in fileobjs]): raise TypeError('MultiFileReader expected file-like objects' ' with .read() and .seek()') - if all([hasattr(f, 'encoding') for f in fileobjs]): + if all([is_text_fileobj(f) for f in fileobjs]): # codecs.open and io.TextIOBase self._joiner = u'' - elif any([hasattr(f, 'encoding') for f in fileobjs]): + elif any([is_text_fileobj(f) for f in fileobjs]): raise ValueError('All arguments to MultiFileReader must handle' ' bytes OR text, not a mix') else: + # open/file and io.BytesIO self._joiner = b'' self._fileobjs = fileobjs self._index = 0 diff --git a/tests/test_ioutils.py b/tests/test_ioutils.py index 149ce24..a798dcb 100644 --- a/tests/test_ioutils.py +++ b/tests/test_ioutils.py @@ -4,10 +4,18 @@ import sys import codecs import random import string + +try: + from StringIO import StringIO +except: + # py3 + StringIO = io.StringIO + from tempfile import mkdtemp from unittest import TestCase from zipfile import ZipFile, ZIP_DEFLATED + from boltons import ioutils CUR_FILE_PATH = os.path.abspath(__file__) @@ -407,7 +415,9 @@ class TestMultiFileReader(TestCase): self.assertEqual(b'narftroz', r.read()) def test_read_seek_text(self): - r = ioutils.MultiFileReader(io.StringIO(u'narf'), io.StringIO(u'troz')) + # also tests StringIO.StringIO on py2 + r = ioutils.MultiFileReader(StringIO(u'narf'), + io.StringIO(u'troz')) self.assertEqual([u'nar', u'ftr', u'oz'], list(iter(lambda: r.read(3), u''))) r.seek(0)