Handle port numbers in host header

from: https://github.com/mitmproxy/netlib/pull/121
This commit is contained in:
Shadab Zafar 2016-02-17 08:48:59 +05:30
parent 887ecf8896
commit 6f96da08c9
2 changed files with 27 additions and 3 deletions

View File

@ -1,5 +1,6 @@
from __future__ import absolute_import, print_function, division
import re
import warnings
import six
@ -12,6 +13,10 @@ from .. import encoding
from .headers import Headers
from .message import Message, _native, _always_bytes, MessageData
# Splits a Host header value into a "host" group and an optional "port" group.
# "host" is either a run of non-colon characters or a bracketed literal such as
# "[::1]", which handles the edge case of IPv6 addresses containing colons.
# "port" is the run of digits in a trailing ":<digits>" suffix, when present.
# https://bugzilla.mozilla.org/show_bug.cgi?id=45891
host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
class RequestData(MessageData):
def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None,
@ -159,6 +164,18 @@ class Request(Message):
def url(self, url):
self.scheme, self.host, self.port, self.path = utils.parse_url(url)
def _parse_host_header(self):
    """
    Split the Host header into its host and port components.

    Returns:
        A ``(host, port)`` tuple. Both items are ``None`` when the request
        has no Host header. ``port`` is an ``int``, or ``None`` when the
        header carries no port. A header value that does not match
        ``host_header_re`` is returned unchanged as the host, with a
        ``None`` port.
    """
    if "host" not in self.headers:
        return None, None

    raw_value = self.headers["host"]
    matched = host_header_re.match(raw_value)
    if not matched:
        return raw_value, None

    # Drop bracket characters from both ends (e.g. around an IPv6 literal).
    hostname = matched.group("host").strip("[]")
    port_text = matched.group("port")
    if port_text:
        return hostname, int(port_text)
    return hostname, None
@property
def pretty_host(self):
"""
@ -166,16 +183,19 @@ class Request(Message):
This is useful in transparent mode where :py:attr:`host` is only an IP address,
but may not reflect the actual destination as the Host header could be spoofed.
"""
return self.headers.get("host", self.host)
return self._parse_host_header()[0] or self.host
@property
def pretty_url(self):
"""
Like :py:attr:`url`, but preferring the host and port found in the Host header.

The header is split by ``_parse_host_header()``; whichever of host or port it
does not supply falls back to :py:attr:`host` or :py:attr:`port` respectively.
"""
host, port = self._parse_host_header()
# Fall back to the request's own values for anything the header did not provide.
host = host or self.host
port = port or self.port
if self.first_line_format == "authority":
# NOTE(review): the next two returns use pretty_host/self.port and look like the
# lines this commit removes (diff markers are not shown in this view) - confirm.
return "%s:%d" % (self.pretty_host, self.port)
return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path)
# Authority-form requests are rendered as "host:port"; all others go through
# utils.unparse_url with the header-derived host and port.
return "%s:%d" % (host, port)
return utils.unparse_url(self.scheme, host, port, self.path)
@property
def query(self):

View File

@ -106,6 +106,8 @@ class TestRequestUtils(object):
request = treq()
assert request.pretty_host == "address"
assert request.host == "address"
request.headers["host"] = "other:22"
assert request.pretty_host == "other"
request.headers["host"] = "other"
assert request.pretty_host == "other"
assert request.host == "address"
@ -123,6 +125,8 @@ class TestRequestUtils(object):
assert request.pretty_url == "http://address:22/path"
request.headers["host"] = "other"
assert request.pretty_url == "http://other:22/path"
request.headers["host"] = "other:33"
assert request.pretty_url == "http://other:33/path"
def test_pretty_url_authority(self):
request = treq(first_line_format="authority")