diff --git a/libmproxy/netlib.py b/libmproxy/netlib.py index 7294cbe0b..08ccba091 100644 --- a/libmproxy/netlib.py +++ b/libmproxy/netlib.py @@ -1,4 +1,4 @@ -import select, socket, threading, traceback +import select, socket, threading, traceback, sys from OpenSSL import SSL @@ -20,8 +20,6 @@ class FileLike: while len(result) < length: try: data = self.o.read(length) - except AttributeError: - break except SSL.ZeroReturnError: break if not data: @@ -52,7 +50,7 @@ class FileLike: class TCPClient: def __init__(self, ssl, host, port, clientcert): self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert - self.sock, self.rfile, self.wfile = None, None, None + self.connection, self.rfile, self.wfile = None, None, None self.cert = None self.connect() @@ -73,7 +71,7 @@ class TCPClient: self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) - self.sock = server + self.connection = server class BaseHandler: @@ -105,10 +103,10 @@ class BaseHandler: self.connection.close() self.wfile.close() self.rfile.close() - except IOError: + except IOError: # pragma: no cover pass - def handle(self): + def handle(self): # pragma: no cover raise NotImplementedError @@ -123,6 +121,7 @@ class TCPServer: self.socket.bind(self.server_address) self.server_address = self.socket.getsockname() self.socket.listen(self.request_queue_size) + self.port = self.socket.getsockname()[1] def request_thread(self, request, client_address): try: @@ -143,9 +142,11 @@ class TCPServer: except socket.error: return try: - t = threading.Thread(target = self.request_thread, - args = (request, client_address)) - t.setDaemon (1) + t = threading.Thread( + target = self.request_thread, + args = (request, client_address) + ) + t.setDaemon(1) t.start() except: self.handle_error(request, client_address) @@ -159,16 +160,16 @@ class TCPServer: self.__is_shut_down.wait() self.handle_shutdown() - def handle_error(self, request, client_address): + def handle_error(self, request, client_address, fp=sys.stderr): """ Called when handle_connection raises an exception. """ - print >> sys.stderr, '-'*40 - print >> sys.stderr, "Error processing of request from %s"%client_address - traceback.print_exc() - print >> sys.stderr, '-'*40 + print >> fp, '-'*40 + print >> fp, "Error processing of request from %s:%s"%client_address + print >> fp, traceback.format_exc() + print >> fp, '-'*40 - def handle_connection(self, request, client_address): + def handle_connection(self, request, client_address): # pragma: no cover """ Called after client connection. """ diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 54536b39a..9ebe01539 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -255,7 +255,7 @@ class ServerConnection(netlib.TCPClient): netlib.TCPClient.__init__( self, True if scheme == "https" else False, - host, + host, port, clientcert ) @@ -305,7 +305,7 @@ class ServerConnection(netlib.TCPClient): try: if not self.wfile.closed: self.wfile.flush() - self.sock.close() + self.connection.close() except IOError: pass diff --git a/test/test_netlib.py b/test/test_netlib.py index 2b76c9cfd..12aa2acc1 100644 --- a/test/test_netlib.py +++ b/test/test_netlib.py @@ -1,5 +1,81 @@ -import cStringIO +import cStringIO, threading, Queue from libmproxy import netlib +import tutils + +class ServerThread(threading.Thread): + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase: + @classmethod + def setupAll(cls): + cls.server = ServerThread(cls.makeserver()) + cls.server.start() + + @classmethod + def teardownAll(cls): + cls.server.shutdown() + + +class THandler(netlib.BaseHandler): + def handle(self): + v = self.rfile.readline() + if v.startswith("echo"): + self.wfile.write(v) + elif v.startswith("error"): + raise ValueError("Testing an error.") + self.wfile.flush() + + +class TServer(netlib.TCPServer): + def __init__(self, addr, q): + netlib.TCPServer.__init__(self, addr) + self.q = q + + def handle_connection(self, request, client_address): + THandler(request, client_address, self) + + def handle_error(self, request, client_address): + s = cStringIO.StringIO() + netlib.TCPServer.handle_error(self, request, client_address, s) + self.q.put(s.getvalue()) + + +class TestServer(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), cls.q) + cls.port = s.port + return s + + def test_echo(self): + testval = "echo!\n" + c = netlib.TCPClient(False, "127.0.0.1", self.port, None) + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + def test_error(self): + testval = "error!\n" + c = netlib.TCPClient(False, "127.0.0.1", self.port, None) + c.wfile.write(testval) + c.wfile.flush() + assert "Testing an error" in self.q.get() + + + +class TestTCPClient: + def test_conerr(self): + tutils.raises(netlib.NetLibError, netlib.TCPClient, False, "127.0.0.1", 0, None) class TestFileLike: @@ -12,4 +88,8 @@ class TestFileLike: # Test __getattr__ assert s.isatty + def test_limit(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = netlib.FileLike(s) + assert s.readline(3) == "foo"