Improve tell() and seek() implementation

- Improvements to tell codepoint handling for
  SpooledStringIO
- Seek now works on codepoints for SpooledStringIO
This commit is contained in:
Brant Watson 2016-10-25 12:30:18 -05:00
parent 1a69c00278
commit 43b058d52e
2 changed files with 93 additions and 19 deletions

View File

@ -28,6 +28,14 @@ else:
text_type = unicode
binary_type = str
READ_CHUNK_SIZE = 21333
"""
Number of bytes to read at a time. The value is ~ 1/3rd of 64k which means that
the value will easily fit in the L2 cache of most processors even if every
codepoint in a string is three bytes long which makes it a nice fast default
value.
"""
class SpooledIOBase(object):
"""
@ -53,6 +61,10 @@ class SpooledIOBase(object):
def write(self, s):
"""Write into the buffer"""
@abstractmethod
def seek(self, pos, mode=0):
"""Seek to a specific point in a file"""
@abstractmethod
def readline(self, length=None):
"""Returns the next available line"""
@ -93,9 +105,6 @@ class SpooledIOBase(object):
def _file(self):
return self.buffer
def seek(self, pos, mode=0):
return self.buffer.seek(pos, mode)
def close(self):
return self.buffer.close()
@ -203,6 +212,9 @@ class SpooledBytesIO(SpooledIOBase):
if self.tell() >= self._max_size:
self.rollover()
def seek(self, pos, mode=0):
return self.buffer.seek(pos, mode)
def readline(self, length=None):
return self.buffer.readline(length)
@ -281,6 +293,25 @@ class SpooledStringIO(SpooledIOBase):
if self.tell() >= self._max_size:
self.rollover()
def seek(self, pos, mode=0):
if mode == os.SEEK_SET:
self.buffer.seek(0)
self.buffer.read(pos)
return pos
elif mode == os.SEEK_CUR:
start_pos = self.tell()
self.buffer.read(pos)
return start_pos + pos
elif mode == os.SEEK_END:
self.buffer.seek(0)
dest_position = self.len - pos
self.buffer.read(dest_position)
return dest_position
else:
raise ValueError(
"Invalid whence ({0}, should be 0, 1, or 2)".format(mode)
)
def readline(self, length=None):
return self.buffer.readline(length).decode('utf-8')
@ -314,27 +345,55 @@ class SpooledStringIO(SpooledIOBase):
pos = self.buffer.tell()
self.seek(0)
enc_pos = 0
while True:
current_pos = self.buffer.tell()
pos_read = False
current_pos = 0
while not pos_read:
if current_pos == pos:
# By chance, our enc_pos is some multiple of READ_CHUNK_SIZE.
# We're done!
break
ret = self.read(21333)
if self.buffer.tell() < pos:
enc_pos += len(ret)
continue
chunk = self.read(READ_CHUNK_SIZE)
chunk_bytes = len(chunk.encode('utf-8'))
# If the previous check fails, then we've seeked beyond the
# intended position. Go back to our position at the start of the
# loop and iterate one codepoint at a time until we reach the
# buffers position.
self.buffer.seek(current_pos)
while True:
if self.buffer.tell() == pos:
break
self.read(1)
if current_pos + chunk_bytes < pos:
current_pos += chunk_bytes
enc_pos += len(chunk)
continue
else:
# The chunk that we've read should contain the remaining bytes
# needed to put us over the pos value.
pos_read = True
else:
# We're done with the buffer at this point, so seek to the previous
# pos for happy return and error raising.
self.buffer.seek(pos)
# We've read past the underlying seek value. We now iterate our
# last retrieved chunk until the encoded character bytes plus the
# current_pos add up to the pos. Anything over that means that our
# underlying object seeked to a non code point.
for char in chunk:
enc_pos += 1
break
current_pos += len(char.encode('utf-8'))
if current_pos == pos:
break
elif current_pos >= pos:
raise IOError(
"SpooledStringIO's underlying EncodedFile seeked to "
"invalid code point!")
if current_pos != pos:
# If we made it this far and our current_pos doesn't equal our pos,
# something has gone wrong with our flo.
raise IOError(
"SpooledStringIO unable to read to previous file "
"position. Likely data truncation while we were "
"reading it.")
self.buffer.seek(pos)
return enc_pos

View File

@ -278,3 +278,18 @@ class TestSpooledStringIO(TestCase, BaseTestMixin, AssertionsMixin):
self.spooled_flo.seek(0)
self.spooled_flo.read(40)
self.assertEqual(self.spooled_flo.tell(), 40)
def test_codepoints_all_enc(self):
""""Test getting read, seek, tell, on various codepoints"""
test_str = u"\u2014\u2014\u2014"
self.spooled_flo.write(test_str)
self.spooled_flo.seek(1)
self.assertEqual(self.spooled_flo.read(), u"\u2014\u2014")
self.assertEqual(len(self.spooled_flo), len(test_str))
self.assertEqual(self.spooled_flo.tell(), 3)
def test_seek_codepoints(self):
"""Make seek() moves to positions along codepoints"""
self.spooled_flo.write(self.test_str)
ret = self.spooled_flo.seek(0, os.SEEK_END)
self.assertEqual(ret, len(self.test_str))