Merge branch 'better-replace'

This commit is contained in:
Maximilian Hils 2016-04-03 08:17:30 -07:00
commit 0259f47997
10 changed files with 100 additions and 65 deletions

View File

@ -26,30 +26,6 @@ class MessageMixin(object):
return self.content
return encoding.decode(ce, self.content)
def replace(self, pattern, repl, *args, **kwargs):
"""
Replaces a regular expression pattern with repl in both the headers
and the body of the message. Encoded body will be decoded
before replacement, and re-encoded afterwards.
Returns the number of replacements made.
"""
count = 0
if self.content:
with decoded(self):
self.content, count = utils.safe_subn(
pattern, repl, self.content, *args, **kwargs
)
fields = []
for name, value in self.headers.fields:
name, c = utils.safe_subn(pattern, repl, name, *args, **kwargs)
count += c
value, c = utils.safe_subn(pattern, repl, value, *args, **kwargs)
count += c
fields.append([name, value])
self.headers.fields = fields
return count
class HTTPRequest(MessageMixin, Request):
@ -165,22 +141,6 @@ class HTTPRequest(MessageMixin, Request):
def __hash__(self):
return id(self)
def replace(self, pattern, repl, *args, **kwargs):
"""
Replaces a regular expression pattern with repl in the headers, the
request path and the body of the request. Encoded content will be
decoded before replacement, and re-encoded afterwards.
Returns the number of replacements made.
"""
c = MessageMixin.replace(self, pattern, repl, *args, **kwargs)
self.path, pc = utils.safe_subn(
pattern, repl, self.path, *args, **kwargs
)
c += pc
return c
class HTTPResponse(MessageMixin, Response):
"""

View File

@ -165,12 +165,3 @@ def parse_size(s):
return int(s) * mult
except ValueError:
raise ValueError("Invalid size specification: %s" % s)
def safe_subn(pattern, repl, target, *args, **kwargs):
"""
There are Unicode conversion problems with re.subn. We try to smooth
that over by casting the pattern and replacement to strings. We really
need a better solution that is aware of the actual content ecoding.
"""
return re.subn(str(pattern), str(repl), target, *args, **kwargs)

View File

@ -6,6 +6,8 @@ See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
"""
from __future__ import absolute_import, print_function, division
import re
try:
from collections.abc import MutableMapping
except ImportError: # pragma: no cover
@ -198,4 +200,31 @@ class Headers(MutableMapping, Serializable):
@classmethod
def from_state(cls, state):
return cls([list(field) for field in state])
return cls([list(field) for field in state])
@_always_byte_args
def replace(self, pattern, repl, flags=0):
"""
Replaces a regular expression pattern with repl in each "name: value"
header line.
Returns:
The number of replacements made.
"""
pattern = re.compile(pattern, flags)
replacements = 0
fields = []
for name, value in self.fields:
line, n = pattern.subn(repl, name + b": " + value)
try:
name, value = line.split(b": ", 1)
except ValueError:
# We get a ValueError if the replacement removed the ": "
# There's not much we can do about this, so we just keep the header as-is.
pass
else:
replacements += n
fields.append([name, value])
self.fields = fields
return replacements

View File

@ -175,6 +175,25 @@ class Message(utils.Serializable):
self.headers["content-encoding"] = e
return True
def replace(self, pattern, repl, flags=0):
"""
Replaces a regular expression pattern with repl in both the headers
and the body of the message. Encoded body will be decoded
before replacement, and re-encoded afterwards.
Returns:
The number of replacements made.
"""
# TODO: Proper distinction between text and bytes.
replacements = 0
if self.content:
with decoded(self):
self.content, replacements = utils.safe_subn(
pattern, repl, self.content, flags=flags
)
replacements += self.headers.replace(pattern, repl, flags)
return replacements
# Legacy
@property

View File

@ -54,6 +54,23 @@ class Request(Message):
self.method, hostport, path
)
def replace(self, pattern, repl, flags=0):
"""
Replaces a regular expression pattern with repl in the headers, the
request path and the body of the request. Encoded content will be
decoded before replacement, and re-encoded afterwards.
Returns:
The number of replacements made.
"""
# TODO: Proper distinction between text and bytes.
c = super(Request, self).replace(pattern, repl, flags)
self.path, pc = utils.safe_subn(
pattern, repl, self.path, flags=flags
)
c += pc
return c
@property
def first_line_format(self):
"""

View File

@ -1,18 +1,8 @@
from __future__ import (absolute_import, print_function, division)
import re
import copy
import six
from .utils import Serializable
def safe_subn(pattern, repl, target, *args, **kwargs):
"""
There are Unicode conversion problems with re.subn. We try to smooth
that over by casting the pattern and replacement to strings. We really
need a better solution that is aware of the actual content ecoding.
"""
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
from .utils import Serializable, safe_subn
class ODict(Serializable):

View File

@ -414,8 +414,18 @@ def http2_read_raw_frame(rfile):
body = rfile.safe_read(length)
return [header, body]
def http2_read_frame(rfile):
header, body = http2_read_raw_frame(rfile)
frame, length = hyperframe.frame.Frame.parse_frame_header(header)
frame.parse_body(memoryview(body))
return frame
def safe_subn(pattern, repl, target, *args, **kwargs):
"""
There are Unicode conversion problems with re.subn. We try to smooth
that over by casting the pattern and replacement to strings. We really
need a better solution that is aware of the actual content ecoding.
"""
return re.subn(str(pattern), str(repl), target, *args, **kwargs)

View File

@ -99,7 +99,3 @@ def test_parse_size():
assert utils.parse_size("1g") == 1024**3
tutils.raises(ValueError, utils.parse_size, "1f")
tutils.raises(ValueError, utils.parse_size, "ak")
def test_safe_subn():
assert utils.safe_subn("foo", u"bar", "\xc2foo")

View File

@ -150,3 +150,22 @@ class TestHeaders(object):
assert headers != headers2
headers2.set_state(headers.get_state())
assert headers == headers2
def test_replace_simple(self):
headers = Headers(Host="example.com", Accept="text/plain")
replacements = headers.replace("Host: ", "X-Host: ")
assert replacements == 1
assert headers["X-Host"] == "example.com"
assert "Host" not in headers
assert headers["Accept"] == "text/plain"
def test_replace_multi(self):
headers = self._2host()
headers.replace(r"Host: example\.com", r"Host: example.de")
assert headers.get_all("Host") == ["example.de", "example.org"]
def test_replace_remove_spacer(self):
headers = Headers(Host="example.com")
replacements = headers.replace(r"Host: ", "X-Host ")
assert replacements == 0
assert headers["Host"] == "example.com"

View File

@ -166,3 +166,7 @@ class TestSerializable:
a.set_state(1)
assert a.i == 1
assert b.i == 42
def test_safe_subn():
assert utils.safe_subn("foo", u"bar", "\xc2foo")