diff --git a/boltons/ioutils.py b/boltons/ioutils.py
index 2a513a6..7cdaac0 100644
--- a/boltons/ioutils.py
+++ b/boltons/ioutils.py
@@ -28,6 +28,14 @@ else:
     text_type = unicode
     binary_type = str
 
+READ_CHUNK_SIZE = 21333
+"""
+Number of bytes to read at a time. The value is ~ 1/3rd of 64k which means that
+the value will easily fit in the L2 cache of most processors even if every
+codepoint in a string is three bytes long which makes it a nice fast default
+value.
+"""
+
 
 class SpooledIOBase(object):
     """
@@ -53,6 +61,10 @@ class SpooledIOBase(object):
     def write(self, s):
         """Write into the buffer"""
 
+    @abstractmethod
+    def seek(self, pos, mode=0):
+        """Seek to a specific point in a file"""
+
     @abstractmethod
     def readline(self, length=None):
         """Returns the next available line"""
@@ -93,9 +105,6 @@ class SpooledIOBase(object):
     def _file(self):
         return self.buffer
 
-    def seek(self, pos, mode=0):
-        return self.buffer.seek(pos, mode)
-
     def close(self):
         return self.buffer.close()
 
@@ -203,6 +212,9 @@ class SpooledBytesIO(SpooledIOBase):
         if self.tell() >= self._max_size:
             self.rollover()
 
+    def seek(self, pos, mode=0):
+        return self.buffer.seek(pos, mode)
+
     def readline(self, length=None):
         return self.buffer.readline(length)
 
@@ -281,6 +293,25 @@ class SpooledStringIO(SpooledIOBase):
         if self.tell() >= self._max_size:
             self.rollover()
 
+    def seek(self, pos, mode=0):
+        if mode == os.SEEK_SET:
+            self.buffer.seek(0)
+            self.buffer.read(pos)
+            return pos
+        elif mode == os.SEEK_CUR:
+            start_pos = self.tell()
+            self.buffer.read(pos)
+            return start_pos + pos
+        elif mode == os.SEEK_END:
+            self.buffer.seek(0)
+            dest_position = self.len - pos
+            self.buffer.read(dest_position)
+            return dest_position
+        else:
+            raise ValueError(
+                "Invalid whence ({0}, should be 0, 1, or 2)".format(mode)
+            )
+
     def readline(self, length=None):
         return self.buffer.readline(length).decode('utf-8')
 
@@ -314,27 +345,55 @@ class SpooledStringIO(SpooledIOBase):
         pos = self.buffer.tell()
         self.seek(0)
         enc_pos = 0
-        while True:
-            current_pos = self.buffer.tell()
+        pos_read = False
+
+        current_pos = 0
+
+        while not pos_read:
             if current_pos == pos:
+                # By chance, our enc_pos is some multiple of READ_CHUNK_SIZE.
+                # We're done!
                 break
 
-            ret = self.read(21333)
-            if self.buffer.tell() < pos:
-                enc_pos += len(ret)
-                continue
+            chunk = self.read(READ_CHUNK_SIZE)
+            chunk_bytes = len(chunk.encode('utf-8'))
 
-            # If the previous check fails, then we've seeked beyond the
-            # intended position. Go back to our position at the start of the
-            # loop and iterate one codepoint at a time until we reach the
-            # buffers position.
-            self.buffer.seek(current_pos)
-            while True:
-                if self.buffer.tell() == pos:
-                    break
-                self.read(1)
+            if current_pos + chunk_bytes < pos:
+                current_pos += chunk_bytes
+                enc_pos += len(chunk)
+                continue
+            else:
+                # The chunk that we've read should contain the remaining bytes
+                # needed to put us over the pos value.
+                pos_read = True
+        else:
+            # We're done with the buffer at this point, so seek to the previous
+            # pos for happy return and error raising.
+            self.buffer.seek(pos)
+
+            # We've read past the underlying seek value. We now iterate our
+            # last retrieved chunk until the encoded character bytes plus the
+            # current_pos add up to the pos. Anything over that means that our
+            # underlying object seeked to a non code point.
+            for char in chunk:
                 enc_pos += 1
-            break
+                current_pos += len(char.encode('utf-8'))
+
+                if current_pos == pos:
+                    break
+                elif current_pos >= pos:
+                    raise IOError(
+                        "SpooledStringIO's underlying EncodedFile seeked to "
+                        "invalid code point!")
+
+        if current_pos != pos:
+            # If we made it this far and our current_pos doesn't equal our pos,
+            # something has gone wrong with our flo.
+            raise IOError(
+                "SpooledStringIO unable to read to previous file "
+                "position. Likely data truncation while we were "
+                "reading it.")
+
         self.buffer.seek(pos)
         return enc_pos
 
diff --git a/tests/test_ioutils.py b/tests/test_ioutils.py
index 6829fea..b3aad6b 100644
--- a/tests/test_ioutils.py
+++ b/tests/test_ioutils.py
@@ -278,3 +278,18 @@ class TestSpooledStringIO(TestCase, BaseTestMixin, AssertionsMixin):
         self.spooled_flo.seek(0)
         self.spooled_flo.read(40)
         self.assertEqual(self.spooled_flo.tell(), 40)
+
+    def test_codepoints_all_enc(self):
+        """Test getting read, seek, tell, on various codepoints"""
+        test_str = u"\u2014\u2014\u2014"
+        self.spooled_flo.write(test_str)
+        self.spooled_flo.seek(1)
+        self.assertEqual(self.spooled_flo.read(), u"\u2014\u2014")
+        self.assertEqual(len(self.spooled_flo), len(test_str))
+        self.assertEqual(self.spooled_flo.tell(), 3)
+
+    def test_seek_codepoints(self):
+        """Make sure seek() moves to positions along codepoints"""
+        self.spooled_flo.write(self.test_str)
+        ret = self.spooled_flo.seek(0, os.SEEK_END)
+        self.assertEqual(ret, len(self.test_str))