mirror of https://github.com/mahmoud/boltons.git
Improve tell() and seek() implementation
- Improvements to tell codepoint handling for SpooledStringIO - Seek now works on codepoints for SpooledStringIO
This commit is contained in:
parent
1a69c00278
commit
43b058d52e
|
@ -28,6 +28,14 @@ else:
|
|||
text_type = unicode
|
||||
binary_type = str
|
||||
|
||||
READ_CHUNK_SIZE = 21333
|
||||
"""
|
||||
Number of bytes to read at a time. The value is ~ 1/3rd of 64k which means that
|
||||
the value will easily fit in the L2 cache of most processors even if every
|
||||
codepoint in a string is three bytes long which makes it a nice fast default
|
||||
value.
|
||||
"""
|
||||
|
||||
|
||||
class SpooledIOBase(object):
|
||||
"""
|
||||
|
@ -53,6 +61,10 @@ class SpooledIOBase(object):
|
|||
def write(self, s):
|
||||
"""Write into the buffer"""
|
||||
|
||||
@abstractmethod
|
||||
def seek(self, pos, mode=0):
|
||||
"""Seek to a specific point in a file"""
|
||||
|
||||
@abstractmethod
|
||||
def readline(self, length=None):
|
||||
"""Returns the next available line"""
|
||||
|
@ -93,9 +105,6 @@ class SpooledIOBase(object):
|
|||
def _file(self):
|
||||
return self.buffer
|
||||
|
||||
def seek(self, pos, mode=0):
|
||||
return self.buffer.seek(pos, mode)
|
||||
|
||||
def close(self):
|
||||
return self.buffer.close()
|
||||
|
||||
|
@ -203,6 +212,9 @@ class SpooledBytesIO(SpooledIOBase):
|
|||
if self.tell() >= self._max_size:
|
||||
self.rollover()
|
||||
|
||||
def seek(self, pos, mode=0):
|
||||
return self.buffer.seek(pos, mode)
|
||||
|
||||
def readline(self, length=None):
|
||||
return self.buffer.readline(length)
|
||||
|
||||
|
@ -281,6 +293,25 @@ class SpooledStringIO(SpooledIOBase):
|
|||
if self.tell() >= self._max_size:
|
||||
self.rollover()
|
||||
|
||||
def seek(self, pos, mode=0):
|
||||
if mode == os.SEEK_SET:
|
||||
self.buffer.seek(0)
|
||||
self.buffer.read(pos)
|
||||
return pos
|
||||
elif mode == os.SEEK_CUR:
|
||||
start_pos = self.tell()
|
||||
self.buffer.read(pos)
|
||||
return start_pos + pos
|
||||
elif mode == os.SEEK_END:
|
||||
self.buffer.seek(0)
|
||||
dest_position = self.len - pos
|
||||
self.buffer.read(dest_position)
|
||||
return dest_position
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid whence ({0}, should be 0, 1, or 2)".format(mode)
|
||||
)
|
||||
|
||||
def readline(self, length=None):
|
||||
return self.buffer.readline(length).decode('utf-8')
|
||||
|
||||
|
@ -314,27 +345,55 @@ class SpooledStringIO(SpooledIOBase):
|
|||
pos = self.buffer.tell()
|
||||
self.seek(0)
|
||||
enc_pos = 0
|
||||
while True:
|
||||
current_pos = self.buffer.tell()
|
||||
pos_read = False
|
||||
|
||||
current_pos = 0
|
||||
|
||||
while not pos_read:
|
||||
if current_pos == pos:
|
||||
# By chance, our enc_pos is some multiple of READ_CHUNK_SIZE.
|
||||
# We're done!
|
||||
break
|
||||
|
||||
ret = self.read(21333)
|
||||
if self.buffer.tell() < pos:
|
||||
enc_pos += len(ret)
|
||||
continue
|
||||
chunk = self.read(READ_CHUNK_SIZE)
|
||||
chunk_bytes = len(chunk.encode('utf-8'))
|
||||
|
||||
# If the previous check fails, then we've seeked beyond the
|
||||
# intended position. Go back to our position at the start of the
|
||||
# loop and iterate one codepoint at a time until we reach the
|
||||
# buffers position.
|
||||
self.buffer.seek(current_pos)
|
||||
while True:
|
||||
if self.buffer.tell() == pos:
|
||||
break
|
||||
self.read(1)
|
||||
if current_pos + chunk_bytes < pos:
|
||||
current_pos += chunk_bytes
|
||||
enc_pos += len(chunk)
|
||||
continue
|
||||
else:
|
||||
# The chunk that we've read should contain the remaining bytes
|
||||
# needed to put us over the pos value.
|
||||
pos_read = True
|
||||
else:
|
||||
# We're done with the buffer at this point, so seek to the previous
|
||||
# pos for happy return and error raising.
|
||||
self.buffer.seek(pos)
|
||||
|
||||
# We've read past the underlying seek value. We now iterate our
|
||||
# last retrieved chunk until the encoded character bytes plus the
|
||||
# current_pos add up to the pos. Anything over that means that our
|
||||
# underlying object seeked to a non code point.
|
||||
for char in chunk:
|
||||
enc_pos += 1
|
||||
break
|
||||
current_pos += len(char.encode('utf-8'))
|
||||
|
||||
if current_pos == pos:
|
||||
break
|
||||
elif current_pos >= pos:
|
||||
raise IOError(
|
||||
"SpooledStringIO's underlying EncodedFile seeked to "
|
||||
"invalid code point!")
|
||||
|
||||
if current_pos != pos:
|
||||
# If we made it this far and our current_pos doesn't equal our pos,
|
||||
# something has gone wrong with our flo.
|
||||
raise IOError(
|
||||
"SpooledStringIO unable to read to previous file "
|
||||
"position. Likely data truncation while we were "
|
||||
"reading it.")
|
||||
|
||||
self.buffer.seek(pos)
|
||||
return enc_pos
|
||||
|
||||
|
|
|
@ -278,3 +278,18 @@ class TestSpooledStringIO(TestCase, BaseTestMixin, AssertionsMixin):
|
|||
self.spooled_flo.seek(0)
|
||||
self.spooled_flo.read(40)
|
||||
self.assertEqual(self.spooled_flo.tell(), 40)
|
||||
|
||||
def test_codepoints_all_enc(self):
|
||||
""""Test getting read, seek, tell, on various codepoints"""
|
||||
test_str = u"\u2014\u2014\u2014"
|
||||
self.spooled_flo.write(test_str)
|
||||
self.spooled_flo.seek(1)
|
||||
self.assertEqual(self.spooled_flo.read(), u"\u2014\u2014")
|
||||
self.assertEqual(len(self.spooled_flo), len(test_str))
|
||||
self.assertEqual(self.spooled_flo.tell(), 3)
|
||||
|
||||
def test_seek_codepoints(self):
|
||||
"""Make seek() moves to positions along codepoints"""
|
||||
self.spooled_flo.write(self.test_str)
|
||||
ret = self.spooled_flo.seek(0, os.SEEK_END)
|
||||
self.assertEqual(ret, len(self.test_str))
|
||||
|
|
Loading…
Reference in New Issue