diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index bcbc8ea5a..89493e790 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -21,7 +21,7 @@ import sys, os, string, socket, time import shutil, tempfile, threading import optparse, SocketServer -import utils, flow, certutils, version, wsgi +import utils, flow, certutils, version, wsgi, tcpserver from OpenSSL import SSL @@ -50,7 +50,7 @@ class ProxyConfig: def read_headers(fp): """ Read a set of headers from a file pointer. Stop once a blank line - is reached. Return a ODict object. + is reached. Return a ODictCaseless object. """ ret = [] name = '' @@ -374,13 +374,13 @@ class ServerConnection: pass -class ProxyHandler(SocketServer.StreamRequestHandler): - def __init__(self, config, request, client_address, server, q): +class ProxyHandler(tcpserver.BaseHandler): + def __init__(self, config, connection, client_address, server, q): self.mqueue = q self.config = config self.server_conn = None self.proxy_connect_state = None - SocketServer.StreamRequestHandler.__init__(self, request, client_address, server) + tcpserver.BaseHandler.__init__(self, connection, client_address, server) def handle(self): cc = flow.ClientConnect(self.client_address) @@ -390,7 +390,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler): cc.close = True cd = flow.ClientDisconnect(cc) cd._send(self.mqueue) - self.finish() def server_connect(self, scheme, host, port): sc = self.server_conn @@ -554,18 +553,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler): self.wfile.write(d) self.wfile.flush() - def terminate(self, connection, wfile, rfile): - self.request.close() - try: - if not getattr(wfile, "closed", False): - wfile.flush() - connection.close() - except IOError: - pass - - def finish(self): - self.terminate(self.connection, self.wfile, self.rfile) - def send_error(self, code, body): try: import BaseHTTPServer @@ -584,10 +571,8 @@ class ProxyHandler(SocketServer.StreamRequestHandler): class ProxyServerError(Exception): pass -ServerBase = SocketServer.ThreadingTCPServer -ServerBase.daemon_threads = True # Terminate workers when main thread terminates -class ProxyServer(ServerBase): - request_queue_size = 20 + +class ProxyServer(tcpserver.TCPServer): allow_reuse_address = True bound = True def __init__(self, config, port, address=''): @@ -596,7 +581,7 @@ class ProxyServer(ServerBase): """ self.config, self.port, self.address = config, port, address try: - ServerBase.__init__(self, (address, port), ProxyHandler) + tcpserver.TCPServer.__init__(self, (address, port)) except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) self.masterq = None @@ -611,11 +596,11 @@ class ProxyServer(ServerBase): def set_mqueue(self, q): self.masterq = q - def finish_request(self, request, client_address): - self.RequestHandlerClass(self.config, request, client_address, self, self.masterq) + def handle_connection(self, request, client_address): + ProxyHandler(self.config, request, client_address, self, self.masterq) def shutdown(self): - ServerBase.shutdown(self) + tcpserver.TCPServer.shutdown(self) try: shutil.rmtree(self.certdir) except OSError: diff --git a/libmproxy/tcpserver.py b/libmproxy/tcpserver.py new file mode 100644 index 000000000..bf7ed0b43 --- /dev/null +++ b/libmproxy/tcpserver.py @@ -0,0 +1,88 @@ +import select, socket, threading + +class BaseHandler: + rbufsize = -1 + wbufsize = 0 + def __init__(self, connection, client_address, server): + self.connection = connection + self.rfile = self.connection.makefile('rb', self.rbufsize) + self.wfile = self.connection.makefile('wb', self.wbufsize) + + self.client_address = client_address + self.server = server + self.handle() + self.finish() + + def finish(self): + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.connection.close() + self.wfile.close() + self.rfile.close() + except IOError: + pass + + def handle(self): + raise NotImplementedError + + +class TCPServer: + request_queue_size = 20 + def __init__(self, server_address): + self.server_address = server_address + self.__is_shut_down = threading.Event() + self.__shutdown_request = False + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind(self.server_address) + self.server_address = self.socket.getsockname() + self.socket.listen(self.request_queue_size) + + def fileno(self): + return self.socket.fileno() + + def request_thread(self, request, client_address): + try: + self.handle_connection(request, client_address) + request.close() + except: + self.handle_error(request, client_address) + request.close() + + def serve_forever(self, poll_interval=0.5): + self.__is_shut_down.clear() + try: + while not self.__shutdown_request: + r, w, e = select.select([self], [], [], poll_interval) + if self in r: + try: + request, client_address = self.socket.accept() + except socket.error: + return + try: + t = threading.Thread(target = self.request_thread, + args = (request, client_address)) + t.setDaemon (1) + t.start() + except: + self.handle_error(request, client_address) + request.close() + finally: + self.__shutdown_request = False + self.__is_shut_down.set() + + def shutdown(self): + self.__shutdown_request = True + self.__is_shut_down.wait() + + def handle_error(self, request, client_address): + print '-'*40 + print 'Exception happened during processing of request from', + print client_address + import traceback + traceback.print_exc() # XXX But this goes to stderr! + print '-'*40 + + def handle_connection(self, request, client_address): + raise NotImplementedError