Reduce tell complexity by tracking position

Instead of calculating the position for tell by
sequentially accessing the file, track the position
on read, write, and seek operations.
This commit is contained in:
Brant Watson 2016-10-26 14:29:10 -05:00
parent 0b37b865af
commit fd77cb81a7
2 changed files with 27 additions and 66 deletions

View File

@ -277,10 +277,15 @@ class SpooledStringIO(SpooledIOBase):
... isinstance(f.read(), ioutils.text_type)
True
"""
"""
def __init__(self, *args, **kwargs):
self._tell = 0
super(SpooledStringIO, self).__init__(*args, **kwargs)
def read(self, n=-1):
return self.buffer.read(n).decode('utf-8')
ret = self.buffer.read(n).decode('utf-8')
self._tell = self.tell() + len(ret)
return ret
def write(self, s):
if not isinstance(s, text_type):
@ -288,10 +293,11 @@ class SpooledStringIO(SpooledIOBase):
text_type.__name__,
type(s).__name__
))
self.buffer.write(s.encode('utf-8'))
if self.tell() >= self._max_size:
current_pos = self.tell()
if self.buffer.tell() + len(s.encode('utf-8')) >= self._max_size:
self.rollover()
self.buffer.write(s.encode('utf-8'))
self._tell = current_pos + len(s)
def _traverse_codepoints(self, current_position, n):
"""Traverse from current position to the right n codepoints"""
@ -324,25 +330,33 @@ class SpooledStringIO(SpooledIOBase):
# Seek to position from the start of the file
if mode == os.SEEK_SET:
self.buffer.seek(0)
return self._traverse_codepoints(0, pos)
self._traverse_codepoints(0, pos)
self._tell = pos
# Seek to new position relative to current position
elif mode == os.SEEK_CUR:
start_pos = self.tell()
return self._traverse_codepoints(self.tell(), pos)
self._traverse_codepoints(self.tell(), pos)
self._tell = start_pos + pos
elif mode == os.SEEK_END:
self.buffer.seek(0)
dest_position = self.len - pos
return self._traverse_codepoints(0, dest_position)
self._traverse_codepoints(0, dest_position)
self._tell = dest_position
else:
raise ValueError(
"Invalid whence ({0}, should be 0, 1, or 2)".format(mode)
)
return self.tell()
def readline(self, length=None):
return self.buffer.readline(length).decode('utf-8')
ret = self.buffer.readline(length).decode('utf-8')
self._tell = self.tell() + len(ret)
return ret
def readlines(self, sizehint=0):
return [x.decode('utf-8') for x in self.buffer.readlines(sizehint)]
ret = [x.decode('utf-8') for x in self.buffer.readlines(sizehint)]
self._tell = self.tell() + sum((len(x) for x in ret))
return ret
@property
def buffer(self):
@ -368,61 +382,7 @@ class SpooledStringIO(SpooledIOBase):
def tell(self):
"""Return the codepoint position"""
pos = self.buffer.tell()
self.seek(0)
enc_pos = 0
pos_read = False
current_pos = 0
while not pos_read:
if current_pos == pos:
# By chance, our enc_pos is some multiple of READ_CHUNK_SIZE.
# We're done!
break
chunk = self.read(READ_CHUNK_SIZE)
chunk_bytes = len(chunk.encode('utf-8'))
if current_pos + chunk_bytes < pos:
current_pos += chunk_bytes
enc_pos += len(chunk)
continue
else:
# The chunk that we've read should contain the remaining bytes
# needed to put us over the pos value.
pos_read = True
else:
# We've read past the underlying seek value. We now iterate our
# last retrieved chunk until the encoded character bytes plus the
# current_pos add up to the pos. Anything over that means that our
# underlying object seeked to a non code point.
for char in chunk:
enc_pos += 1
current_pos += len(char.encode('utf-8'))
if current_pos >= pos:
break
self.buffer.seek(pos)
if current_pos > pos:
# This means that we've read a character that started before the
# seek location, and ended after the seek location meaning that the
# cursor isn't aligned to the characters.
raise IOError(
"SpooledStringIO's underlying EncodedFile seeked to "
"invalid code point!")
if current_pos < pos:
# If we made it this far and our current_pos doesn't equal our pos,
# something has gone wrong with our flo.
raise IOError(
"SpooledStringIO unable to read to previous file "
"position. Likely data truncation while we were "
"reading it.")
return enc_pos
return self._tell
@property
def len(self):
@ -431,7 +391,7 @@ class SpooledStringIO(SpooledIOBase):
self.seek(0)
total = 0
while True:
ret = self.read(21333)
ret = self.read(READ_CHUNK_SIZE)
if not ret:
break
total += len(ret)

View File

@ -322,6 +322,7 @@ class TestSpooledStringIO(TestCase, BaseTestMixin, AssertionsMixin):
test_str = u"\u2014\u2014\u2014"
self.spooled_flo.write(test_str)
self.spooled_flo.seek(1)
self.assertEqual(self.spooled_flo.tell(), 1)
ret = self.spooled_flo.seek(2, os.SEEK_CUR)
self.assertEqual(ret, 3)