mirror of https://github.com/mahmoud/boltons.git
Reduce tell complexity by tracking position
Instead of calculating the position for tell by sequentially accessing the file, track the position on read, write, and seek operations.
This commit is contained in:
parent
0b37b865af
commit
fd77cb81a7
|
@ -277,10 +277,15 @@ class SpooledStringIO(SpooledIOBase):
|
|||
... isinstance(f.read(), ioutils.text_type)
|
||||
True
|
||||
|
||||
"""
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._tell = 0
|
||||
super(SpooledStringIO, self).__init__(*args, **kwargs)
|
||||
|
||||
def read(self, n=-1):
|
||||
return self.buffer.read(n).decode('utf-8')
|
||||
ret = self.buffer.read(n).decode('utf-8')
|
||||
self._tell = self.tell() + len(ret)
|
||||
return ret
|
||||
|
||||
def write(self, s):
|
||||
if not isinstance(s, text_type):
|
||||
|
@ -288,10 +293,11 @@ class SpooledStringIO(SpooledIOBase):
|
|||
text_type.__name__,
|
||||
type(s).__name__
|
||||
))
|
||||
|
||||
self.buffer.write(s.encode('utf-8'))
|
||||
if self.tell() >= self._max_size:
|
||||
current_pos = self.tell()
|
||||
if self.buffer.tell() + len(s.encode('utf-8')) >= self._max_size:
|
||||
self.rollover()
|
||||
self.buffer.write(s.encode('utf-8'))
|
||||
self._tell = current_pos + len(s)
|
||||
|
||||
def _traverse_codepoints(self, current_position, n):
|
||||
"""Traverse from current position to the right n codepoints"""
|
||||
|
@ -324,25 +330,33 @@ class SpooledStringIO(SpooledIOBase):
|
|||
# Seek to position from the start of the file
|
||||
if mode == os.SEEK_SET:
|
||||
self.buffer.seek(0)
|
||||
return self._traverse_codepoints(0, pos)
|
||||
self._traverse_codepoints(0, pos)
|
||||
self._tell = pos
|
||||
# Seek to new position relative to current position
|
||||
elif mode == os.SEEK_CUR:
|
||||
start_pos = self.tell()
|
||||
return self._traverse_codepoints(self.tell(), pos)
|
||||
self._traverse_codepoints(self.tell(), pos)
|
||||
self._tell = start_pos + pos
|
||||
elif mode == os.SEEK_END:
|
||||
self.buffer.seek(0)
|
||||
dest_position = self.len - pos
|
||||
return self._traverse_codepoints(0, dest_position)
|
||||
self._traverse_codepoints(0, dest_position)
|
||||
self._tell = dest_position
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid whence ({0}, should be 0, 1, or 2)".format(mode)
|
||||
)
|
||||
return self.tell()
|
||||
|
||||
def readline(self, length=None):
|
||||
return self.buffer.readline(length).decode('utf-8')
|
||||
ret = self.buffer.readline(length).decode('utf-8')
|
||||
self._tell = self.tell() + len(ret)
|
||||
return ret
|
||||
|
||||
def readlines(self, sizehint=0):
|
||||
return [x.decode('utf-8') for x in self.buffer.readlines(sizehint)]
|
||||
ret = [x.decode('utf-8') for x in self.buffer.readlines(sizehint)]
|
||||
self._tell = self.tell() + sum((len(x) for x in ret))
|
||||
return ret
|
||||
|
||||
@property
|
||||
def buffer(self):
|
||||
|
@ -368,61 +382,7 @@ class SpooledStringIO(SpooledIOBase):
|
|||
|
||||
def tell(self):
|
||||
"""Return the codepoint position"""
|
||||
pos = self.buffer.tell()
|
||||
self.seek(0)
|
||||
enc_pos = 0
|
||||
pos_read = False
|
||||
|
||||
current_pos = 0
|
||||
|
||||
while not pos_read:
|
||||
if current_pos == pos:
|
||||
# By chance, our enc_pos is some multiple of READ_CHUNK_SIZE.
|
||||
# We're done!
|
||||
break
|
||||
|
||||
chunk = self.read(READ_CHUNK_SIZE)
|
||||
chunk_bytes = len(chunk.encode('utf-8'))
|
||||
|
||||
if current_pos + chunk_bytes < pos:
|
||||
current_pos += chunk_bytes
|
||||
enc_pos += len(chunk)
|
||||
continue
|
||||
else:
|
||||
# The chunk that we've read should contain the remaining bytes
|
||||
# needed to put us over the pos value.
|
||||
pos_read = True
|
||||
else:
|
||||
# We've read past the underlying seek value. We now iterate our
|
||||
# last retrieved chunk until the encoded character bytes plus the
|
||||
# current_pos add up to the pos. Anything over that means that our
|
||||
# underlying object seeked to a non code point.
|
||||
for char in chunk:
|
||||
enc_pos += 1
|
||||
current_pos += len(char.encode('utf-8'))
|
||||
|
||||
if current_pos >= pos:
|
||||
break
|
||||
|
||||
self.buffer.seek(pos)
|
||||
|
||||
if current_pos > pos:
|
||||
# This means that we've read a character that started before the
|
||||
# seek location, and ended after the seek location meaning that the
|
||||
# cursor isn't aligned to the characters.
|
||||
raise IOError(
|
||||
"SpooledStringIO's underlying EncodedFile seeked to "
|
||||
"invalid code point!")
|
||||
|
||||
if current_pos < pos:
|
||||
# If we made it this far and our current_pos doesn't equal our pos,
|
||||
# something has gone wrong with our flo.
|
||||
raise IOError(
|
||||
"SpooledStringIO unable to read to previous file "
|
||||
"position. Likely data truncation while we were "
|
||||
"reading it.")
|
||||
|
||||
return enc_pos
|
||||
return self._tell
|
||||
|
||||
@property
|
||||
def len(self):
|
||||
|
@ -431,7 +391,7 @@ class SpooledStringIO(SpooledIOBase):
|
|||
self.seek(0)
|
||||
total = 0
|
||||
while True:
|
||||
ret = self.read(21333)
|
||||
ret = self.read(READ_CHUNK_SIZE)
|
||||
if not ret:
|
||||
break
|
||||
total += len(ret)
|
||||
|
|
|
@ -322,6 +322,7 @@ class TestSpooledStringIO(TestCase, BaseTestMixin, AssertionsMixin):
|
|||
test_str = u"\u2014\u2014\u2014"
|
||||
self.spooled_flo.write(test_str)
|
||||
self.spooled_flo.seek(1)
|
||||
self.assertEqual(self.spooled_flo.tell(), 1)
|
||||
ret = self.spooled_flo.seek(2, os.SEEK_CUR)
|
||||
self.assertEqual(ret, 3)
|
||||
|
||||
|
|
Loading…
Reference in New Issue