From fd77cb81a73b5f5f0a57932109189e47661b215b Mon Sep 17 00:00:00 2001 From: Brant Watson Date: Wed, 26 Oct 2016 14:29:10 -0500 Subject: [PATCH] Reduce tell complexity by tracking position Instead of calculating the position for tell by sequentially accessing the file, track the position on read, write, and seek operations. --- boltons/ioutils.py | 92 ++++++++++++------------------------------- tests/test_ioutils.py | 1 + 2 files changed, 27 insertions(+), 66 deletions(-) diff --git a/boltons/ioutils.py b/boltons/ioutils.py index 20015d8..ad18bae 100644 --- a/boltons/ioutils.py +++ b/boltons/ioutils.py @@ -277,10 +277,15 @@ class SpooledStringIO(SpooledIOBase): ... isinstance(f.read(), ioutils.text_type) True - """ + """ + def __init__(self, *args, **kwargs): + self._tell = 0 + super(SpooledStringIO, self).__init__(*args, **kwargs) def read(self, n=-1): - return self.buffer.read(n).decode('utf-8') + ret = self.buffer.read(n).decode('utf-8') + self._tell = self.tell() + len(ret) + return ret def write(self, s): if not isinstance(s, text_type): @@ -288,10 +293,11 @@ class SpooledStringIO(SpooledIOBase): text_type.__name__, type(s).__name__ )) - - self.buffer.write(s.encode('utf-8')) - if self.tell() >= self._max_size: + current_pos = self.tell() + if self.buffer.tell() + len(s.encode('utf-8')) >= self._max_size: self.rollover() + self.buffer.write(s.encode('utf-8')) + self._tell = current_pos + len(s) def _traverse_codepoints(self, current_position, n): """Traverse from current position to the right n codepoints""" @@ -324,25 +330,33 @@ class SpooledStringIO(SpooledIOBase): # Seek to position from the start of the file if mode == os.SEEK_SET: self.buffer.seek(0) - return self._traverse_codepoints(0, pos) + self._traverse_codepoints(0, pos) + self._tell = pos # Seek to new position relative to current position elif mode == os.SEEK_CUR: start_pos = self.tell() - return self._traverse_codepoints(self.tell(), pos) + self._traverse_codepoints(self.tell(), pos) + self._tell = start_pos + pos elif mode == os.SEEK_END: self.buffer.seek(0) dest_position = self.len - pos - return self._traverse_codepoints(0, dest_position) + self._traverse_codepoints(0, dest_position) + self._tell = dest_position else: raise ValueError( "Invalid whence ({0}, should be 0, 1, or 2)".format(mode) ) + return self.tell() def readline(self, length=None): - return self.buffer.readline(length).decode('utf-8') + ret = self.buffer.readline(length).decode('utf-8') + self._tell = self.tell() + len(ret) + return ret def readlines(self, sizehint=0): - return [x.decode('utf-8') for x in self.buffer.readlines(sizehint)] + ret = [x.decode('utf-8') for x in self.buffer.readlines(sizehint)] + self._tell = self.tell() + sum((len(x) for x in ret)) + return ret @property def buffer(self): @@ -368,61 +382,7 @@ class SpooledStringIO(SpooledIOBase): def tell(self): """Return the codepoint position""" - pos = self.buffer.tell() - self.seek(0) - enc_pos = 0 - pos_read = False - - current_pos = 0 - - while not pos_read: - if current_pos == pos: - # By chance, our enc_pos is some multiple of READ_CHUNK_SIZE. - # We're done! - break - - chunk = self.read(READ_CHUNK_SIZE) - chunk_bytes = len(chunk.encode('utf-8')) - - if current_pos + chunk_bytes < pos: - current_pos += chunk_bytes - enc_pos += len(chunk) - continue - else: - # The chunk that we've read should contain the remaining bytes - # needed to put us over the pos value. - pos_read = True - else: - # We've read past the underlying seek value. We now iterate our - # last retrieved chunk until the encoded character bytes plus the - # current_pos add up to the pos. Anything over that means that our - # underlying object seeked to a non code point. - for char in chunk: - enc_pos += 1 - current_pos += len(char.encode('utf-8')) - - if current_pos >= pos: - break - - self.buffer.seek(pos) - - if current_pos > pos: - # This means that we've read a character that started before the - # seek location, and ended after the seek location meaning that the - # cursor isn't aligned to the characters. - raise IOError( - "SpooledStringIO's underlying EncodedFile seeked to " - "invalid code point!") - - if current_pos < pos: - # If we made it this far and our current_pos doesn't equal our pos, - # something has gone wrong with our flo. - raise IOError( - "SpooledStringIO unable to read to previous file " - "position. Likely data truncation while we were " - "reading it.") - - return enc_pos + return self._tell @property def len(self): @@ -431,7 +391,7 @@ class SpooledStringIO(SpooledIOBase): self.seek(0) total = 0 while True: - ret = self.read(21333) + ret = self.read(READ_CHUNK_SIZE) if not ret: break total += len(ret) diff --git a/tests/test_ioutils.py b/tests/test_ioutils.py index f5e5d10..c749535 100644 --- a/tests/test_ioutils.py +++ b/tests/test_ioutils.py @@ -322,6 +322,7 @@ class TestSpooledStringIO(TestCase, BaseTestMixin, AssertionsMixin): test_str = u"\u2014\u2014\u2014" self.spooled_flo.write(test_str) self.spooled_flo.seek(1) + self.assertEqual(self.spooled_flo.tell(), 1) ret = self.spooled_flo.seek(2, os.SEEK_CUR) self.assertEqual(ret, 3)