diff --git a/libmproxy/authentication.py b/libmproxy/authentication.py index 675f5dc5d..1f1f40ae2 100644 --- a/libmproxy/authentication.py +++ b/libmproxy/authentication.py @@ -9,9 +9,15 @@ class NullProxyAuth(): self.password_manager = password_manager self.username = "" + def clean(self, headers): + """ + Clean up authentication headers, so they're not passed upstream. + """ + pass + def authenticate(self, headers): """ - Tests that the specified user is allowed to use the proxy (stub) + Tests that the user is allowed to use the proxy """ return True @@ -23,12 +29,17 @@ class NullProxyAuth(): class BasicProxyAuth(NullProxyAuth): - def __init__(self, password_manager, realm="mitmproxy"): + CHALLENGE_HEADER = 'Proxy-Authenticate' + AUTH_HEADER = 'Proxy-Authorization' + def __init__(self, password_manager, realm): NullProxyAuth.__init__(self, password_manager) - self.realm = "mitmproxy" + self.realm = realm + + def clean(self, headers): + del headers[self.AUTH_HEADER] def authenticate(self, headers): - auth_value = headers.get('Proxy-Authorization', []) + auth_value = headers.get(self.AUTH_HEADER, []) if not auth_value: return False try: @@ -43,7 +54,7 @@ class BasicProxyAuth(NullProxyAuth): return True def auth_challenge_headers(self): - return {'Proxy-Authenticate':'Basic realm="%s"'%self.realm} + return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm} def unparse_auth_value(self, scheme, username, password): v = binascii.b2a_base64(username + ":" + password) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 2c62a8808..0cba4cbca 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -356,12 +356,15 @@ class ProxyHandler(tcp.BaseHandler): headers = http.read_headers(self.rfile) if headers is None: raise ProxyError(400, "Invalid headers") - if authenticate and self.config.authenticator and not self.config.authenticator.authenticate(headers): - raise ProxyError( - 407, - "Proxy Authentication Required", - self.config.authenticator.auth_challenge_headers() - ) + if authenticate and self.config.authenticator: + if self.config.authenticator.authenticate(headers): + self.config.authenticator.clean(headers) + else: + raise ProxyError( + 407, + "Proxy Authentication Required", + self.config.authenticator.auth_challenge_headers() + ) return headers def send_response(self, response): @@ -552,7 +555,7 @@ def process_proxy_options(parser, options): password_manager = authentication.HtpasswdPasswordManager(options.auth_htpasswd) # in the meanwhile, basic auth is the only true authentication scheme we support # so just use it - authenticator = authentication.BasicProxyAuth(password_manager) + authenticator = authentication.BasicProxyAuth(password_manager, "mitmproxy") else: authenticator = authentication.NullProxyAuth(None) diff --git a/test/test_authentication.py b/test/test_authentication.py index cc797d68b..f7a5ecd34 100644 --- a/test/test_authentication.py +++ b/test/test_authentication.py @@ -9,17 +9,18 @@ class TestNullProxyAuth: na = authentication.NullProxyAuth(authentication.PermissivePasswordManager()) assert not na.auth_challenge_headers() assert na.authenticate("foo") + na.clean({}) class TestBasicProxyAuth: def test_simple(self): - ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager()) + ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager(), "test") h = odict.ODictCaseless() assert ba.auth_challenge_headers() assert not ba.authenticate(h) def test_parse_auth_value(self): - ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager()) + ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager(), "test") vals = ("basic", "foo", "bar") assert ba.parse_auth_value(ba.unparse_auth_value(*vals)) == vals tutils.raises(ValueError, ba.parse_auth_value, "") @@ -28,3 +29,30 @@ class TestBasicProxyAuth: v = "basic " + binascii.b2a_base64("foo") tutils.raises(ValueError, ba.parse_auth_value, v) + def test_authenticate_clean(self): + ba = authentication.BasicProxyAuth(authentication.PermissivePasswordManager(), "test") + + hdrs = odict.ODictCaseless() + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [ba.unparse_auth_value(*vals)] + assert ba.authenticate(hdrs) + + ba.clean(hdrs) + assert not ba.AUTH_HEADER in hdrs + + + hdrs[ba.AUTH_HEADER] = [""] + assert not ba.authenticate(hdrs) + + hdrs[ba.AUTH_HEADER] = ["foo"] + assert not ba.authenticate(hdrs) + + vals = ("foo", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [ba.unparse_auth_value(*vals)] + assert not ba.authenticate(hdrs) + + ba = authentication.BasicProxyAuth(authentication.PasswordManager(), "test") + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [ba.unparse_auth_value(*vals)] + assert not ba.authenticate(hdrs) +