diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 26e3c2c46..9a227010d 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -1,8 +1,28 @@ from __future__ import (absolute_import, print_function, division) from argparse import Action, ArgumentTypeError +import binascii from .. import http +def parse_http_basic_auth(s): + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + class NullProxyAuth(object): @@ -47,7 +67,7 @@ class BasicProxyAuth(NullProxyAuth): auth_value = headers.get(self.AUTH_HEADER, []) if not auth_value: return False - parts = http.http1.parse_http_basic_auth(auth_value[0]) + parts = parse_http_basic_auth(auth_value[0]) if not parts: return False scheme, username, password = parts diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 0f7a0bd36..97c119a95 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -85,22 +85,9 @@ def read_chunked(fp, limit, is_request): return -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks - - def has_chunked_encoding(headers): return "chunked" in [ - i.lower() for i in get_header_tokens(headers, "transfer-encoding") + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") ] @@ -123,28 +110,6 @@ def parse_http_protocol(s): return major, minor -def parse_http_basic_auth(s): - # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]) - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v - - def parse_init(line): try: method, url, protocol = string.split(line) @@ -221,7 +186,7 @@ def connection_close(httpversion, headers): """ # At first, check if we have an explicit Connection header. if "connection" in headers: - toks = get_header_tokens(headers, "connection") + toks = http.get_header_tokens(headers, "connection") if "close" in toks: return True elif "keep-alive" in toks: diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e7e84fe35..a62c93e3e 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -49,7 +49,6 @@ def is_valid_host(host): return True - def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -92,3 +91,16 @@ def parse_url(url): if not is_valid_port(port): return None return scheme, host, port, path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 05e828311..d0a2ee025 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -71,13 +71,13 @@ def test_connection_close(): def test_get_header_tokens(): h = odict.ODictCaseless() - assert protocol.get_header_tokens(h, "foo") == [] + assert http.get_header_tokens(h, "foo") == [] h["foo"] = ["bar"] - assert protocol.get_header_tokens(h, "foo") == ["bar"] + assert http.get_header_tokens(h, "foo") == ["bar"] h["foo"] = ["bar, voing"] - assert protocol.get_header_tokens(h, "foo") == ["bar", "voing"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing"] h["foo"] = ["bar, voing", "oink"] - assert protocol.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] def test_read_http_body_request(): @@ -357,17 +357,6 @@ def test_read_response(): assert tst(data, "GET", None, include_body=False).content is None -def test_parse_http_basic_auth(): - vals = ("basic", "foo", "bar") - assert protocol.parse_http_basic_auth( - protocol.assemble_http_basic_auth(*vals) - ) == vals - assert not protocol.parse_http_basic_auth("") - assert not protocol.parse_http_basic_auth("foo bar") - v = "basic " + binascii.b2a_base64("foo") - assert not protocol.parse_http_basic_auth(v) - - def test_get_request_line(): r = cStringIO.StringIO("\nfoo") assert protocol.get_request_line(r) == "foo" diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index c0dae1a28..8f231643f 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -1,8 +1,21 @@ +import binascii + from netlib import odict, http from netlib.http import authentication from .. import tutils +def test_parse_http_basic_auth(): + vals = ("basic", "foo", "bar") + assert http.authentication.parse_http_basic_auth( + http.authentication.assemble_http_basic_auth(*vals) + ) == vals + assert not http.authentication.parse_http_basic_auth("") + assert not http.authentication.parse_http_basic_auth("foo bar") + v = "basic " + binascii.b2a_base64("foo") + assert not http.authentication.parse_http_basic_auth(v) + + class TestPassManNonAnon: def test_simple(self): @@ -23,7 +36,7 @@ class TestPassManHtpasswd: pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) vals = ("basic", "test", "test") - http.http1.assemble_http_basic_auth(*vals) + authentication.assemble_http_basic_auth(*vals) assert pm.test("test", "test") assert not pm.test("test", "foo") assert not pm.test("foo", "test") @@ -62,7 +75,7 @@ class TestBasicProxyAuth: hdrs = odict.ODictCaseless() vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert ba.authenticate(hdrs) ba.clean(hdrs) @@ -75,12 +88,12 @@ class TestBasicProxyAuth: assert not ba.authenticate(hdrs) vals = ("foo", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs) ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs)