diff --git a/netlib/http.py b/netlib/http.py index 2c9e69cb2..0f2caa5a3 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,5 +1,5 @@ import string, urlparse, binascii -import odict +import odict, utils class HttpError(Exception): def __init__(self, code, msg): @@ -12,6 +12,22 @@ class HttpError(Exception): class HttpErrorConnClosed(HttpError): pass +def _is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def _is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -42,17 +58,11 @@ def parse_url(url): path = urlparse.urlunparse(('', '', path, params, query, fragment)) if not path.startswith("/"): path = "/" + path - try: - host.decode("idna") - except ValueError: + if not _is_valid_host(host): return None - if "\0" in host: + if not utils.isascii(path): return None - try: - path.decode("ascii") - except ValueError: - return None - if not 0 <= port <= 65535: + if not _is_valid_port(port): return None return scheme, host, port, path @@ -236,6 +246,10 @@ def parse_init_connect(line): port = int(port) except ValueError: return None + if not _is_valid_port(port): + return None + if not _is_valid_host(host): + return None return host, port, httpversion @@ -260,7 +274,8 @@ def parse_init_http(line): if not v: return None method, url, httpversion = v - + if not utils.isascii(url): + return None if not (url.startswith("/") or url == "*"): return None return method, url, httpversion diff --git a/netlib/utils.py b/netlib/utils.py index 7621a1dc0..61fd54aef 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,4 +1,12 @@ +def isascii(s): + try: + s.decode("ascii") + except ValueError: + return False + return True + + def cleanBin(s, fixspacing=False): """ Cleans binary data to make it safe to display. If fixspacing is True, diff --git a/test/test_http.py b/test/test_http.py index f7d861fd8..e98a891f8 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -136,6 +136,8 @@ def test_parse_http_protocol(): def test_parse_init_connect(): assert http.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not http.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not http.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") assert not http.parse_init_connect("bogus") assert not http.parse_init_connect("GET host.com:443 HTTP/1.0") assert not http.parse_init_connect("CONNECT host.com443 HTTP/1.0") @@ -164,11 +166,10 @@ def test_parse_init_http(): assert m == "GET" assert u == "/test" assert httpversion == (1, 1) - assert not http.parse_init_http("invalid") assert not http.parse_init_http("GET invalid HTTP/1.1") assert not http.parse_init_http("GET /test foo/1.1") - + assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") class TestReadHeaders: def _read(self, data, verbatim=False):