improved type checking for MultiFileReader (enables more file-like objects including codecs.open etc.), more tests to match. also corrected some error messages with appropriate class name.

This commit is contained in:
Mahmoud Hashemi 2017-07-29 17:33:59 -07:00
parent ae56836176
commit b61d5af99e
2 changed files with 44 additions and 11 deletions

View File

@ -1,4 +1,4 @@
# -*- coding: UTF-8 -*-
# -*- coding: utf-8 -*-
# Coding decl above needed for rendering the emdash properly in the
# documentation.
@ -409,11 +409,16 @@ class SpooledStringIO(SpooledIOBase):
class MultiFileReader(object):
def __init__(self, *fileobjs):
if all(isinstance(f, TextIOBase) for f in fileobjs):
self._joiner = ''
elif any(isinstance(f, TextIOBase) for f in fileobjs):
raise ValueError('All arguments to MultiFileReader must be either '
'bytes IO or text IO, not a mix')
if not all([callable(getattr(f, 'read', None)) and
callable(getattr(f, 'seek', None)) for f in fileobjs]):
raise TypeError('MultiFileReader expected file-like objects'
' with .read() and .seek()')
if all([hasattr(f, 'encoding') for f in fileobjs]):
# codecs.open and io.TextIOBase
self._joiner = u''
elif any([hasattr(f, 'encoding') for f in fileobjs]):
raise ValueError('All arguments to MultiFileReader must handle'
' bytes OR text, not a mix')
else:
self._joiner = b''
self._fileobjs = fileobjs
@ -434,11 +439,9 @@ class MultiFileReader(object):
def seek(self, offset, whence=os.SEEK_SET):
if whence != os.SEEK_SET:
raise NotImplementedError(
'fileprepender does not support anything other'
' than os.SEEK_SET for whence on seek()')
'MultiFileReader.seek() only supports os.SEEK_SET')
if offset != 0:
raise NotImplementedError(
'fileprepender only supports seeking to start, but that '
'could be fixed if you need it')
'MultiFileReader only supports seeking to start at this time')
for f in self._fileobjs:
f.seek(0)

View File

@ -1,14 +1,18 @@
import io
import os
import sys
import codecs
import random
import string
import sys
from tempfile import mkdtemp
from unittest import TestCase
from zipfile import ZipFile, ZIP_DEFLATED
from boltons import ioutils
CUR_FILE_PATH = os.path.abspath(__file__)
# Python2/3 compat
if sys.version_info[0] == 3:
text_type = str
@ -412,3 +416,29 @@ class TestMultiFileReader(TestCase):
def test_no_mixed_bytes_and_text(self):
self.assertRaises(ValueError, ioutils.MultiFileReader,
io.BytesIO(b'narf'), io.StringIO(u'troz'))
def test_open(self):
with open(CUR_FILE_PATH, 'r') as f:
r_file_str = f.read()
with open(CUR_FILE_PATH, 'r') as f1:
with open(CUR_FILE_PATH, 'r') as f2:
mfr = ioutils.MultiFileReader(f1, f2)
r_double_file_str = mfr.read()
assert r_double_file_str == (r_file_str * 2)
with open(CUR_FILE_PATH, 'rb') as f:
rb_file_str = f.read()
with open(CUR_FILE_PATH, 'rb') as f1:
with open(CUR_FILE_PATH, 'rb') as f2:
mfr = ioutils.MultiFileReader(f1, f2)
rb_double_file_str = mfr.read()
assert rb_double_file_str == (rb_file_str * 2)
utf8_file_str = codecs.open(CUR_FILE_PATH, encoding='utf8').read()
f1, f2 = (codecs.open(CUR_FILE_PATH, encoding='utf8'),
codecs.open(CUR_FILE_PATH, encoding='utf8'))
mfr = ioutils.MultiFileReader(f1, f2)
utf8_double_file_str = mfr.read()
assert utf8_double_file_str == (utf8_file_str * 2)