diff --git a/libmproxy/authentication.py b/libmproxy/authentication.py index c928ebbdb..675f5dc5d 100644 --- a/libmproxy/authentication.py +++ b/libmproxy/authentication.py @@ -32,8 +32,8 @@ class BasicProxyAuth(NullProxyAuth): if not auth_value: return False try: - scheme, username, password = self.parse_authorization_header(auth_value[0]) - except: + scheme, username, password = self.parse_auth_value(auth_value[0]) + except ValueError: return False if scheme.lower()!='basic': return False @@ -45,12 +45,23 @@ class BasicProxyAuth(NullProxyAuth): def auth_challenge_headers(self): return {'Proxy-Authenticate':'Basic realm="%s"'%self.realm} - def parse_authorization_header(self, auth_value): + def unparse_auth_value(self, scheme, username, password): + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + + def parse_auth_value(self, auth_value): words = auth_value.split() + if len(words) != 2: + raise ValueError("Invalid basic auth credential.") scheme = words[0] - user = binascii.a2b_base64(words[1]) - username, password = user.split(':') - return scheme, username, password + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + raise ValueError("Invalid basic auth credential: user:password pair not valid base64: %s"%words[1]) + parts = user.split(':') + if len(parts) != 2: + raise ValueError("Invalid basic auth credential: decoded user:password pair not valid: %s"%user) + return scheme, parts[0], parts[1] class PasswordManager(): diff --git a/test/test_authentication.py b/test/test_authentication.py index 25714263a..cc797d68b 100644 --- a/test/test_authentication.py +++ b/test/test_authentication.py @@ -1,5 +1,7 @@ +import binascii from libmproxy import authentication from netlib import odict +import tutils class TestNullProxyAuth: @@ -16,3 +18,13 @@ class TestBasicProxyAuth: assert ba.auth_challenge_headers() assert not ba.authenticate(h) + def test_parse_auth_value(self): + ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager()) + vals = ("basic", "foo", "bar") + assert ba.parse_auth_value(ba.unparse_auth_value(*vals)) == vals + tutils.raises(ValueError, ba.parse_auth_value, "") + tutils.raises(ValueError, ba.parse_auth_value, "foo bar") + + v = "basic " + binascii.b2a_base64("foo") + tutils.raises(ValueError, ba.parse_auth_value, v) +