Create our own TCP server class.

We're going to need more control for advanced features and speed, and we can
also ditch some of the idiocies in the SocketServer module.
This commit is contained in:
Aldo Cortesi 2012-06-16 11:40:44 +12:00
parent c7952371b7
commit 8ae64337ed
2 changed files with 99 additions and 26 deletions

View File

@ -21,7 +21,7 @@
import sys, os, string, socket, time
import shutil, tempfile, threading
import optparse, SocketServer
import utils, flow, certutils, version, wsgi
import utils, flow, certutils, version, wsgi, tcpserver
from OpenSSL import SSL
@ -50,7 +50,7 @@ class ProxyConfig:
def read_headers(fp):
"""
Read a set of headers from a file pointer. Stop once a blank line
is reached. Return a ODict object.
is reached. Return a ODictCaseless object.
"""
ret = []
name = ''
@ -374,13 +374,13 @@ class ServerConnection:
pass
class ProxyHandler(SocketServer.StreamRequestHandler):
def __init__(self, config, request, client_address, server, q):
class ProxyHandler(tcpserver.BaseHandler):
def __init__(self, config, connection, client_address, server, q):
self.mqueue = q
self.config = config
self.server_conn = None
self.proxy_connect_state = None
SocketServer.StreamRequestHandler.__init__(self, request, client_address, server)
tcpserver.BaseHandler.__init__(self, connection, client_address, server)
def handle(self):
cc = flow.ClientConnect(self.client_address)
@ -390,7 +390,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
cc.close = True
cd = flow.ClientDisconnect(cc)
cd._send(self.mqueue)
self.finish()
def server_connect(self, scheme, host, port):
sc = self.server_conn
@ -554,18 +553,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
self.wfile.write(d)
self.wfile.flush()
def terminate(self, connection, wfile, rfile):
self.request.close()
try:
if not getattr(wfile, "closed", False):
wfile.flush()
connection.close()
except IOError:
pass
def finish(self):
self.terminate(self.connection, self.wfile, self.rfile)
def send_error(self, code, body):
try:
import BaseHTTPServer
@ -584,10 +571,8 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
class ProxyServerError(Exception): pass
ServerBase = SocketServer.ThreadingTCPServer
ServerBase.daemon_threads = True # Terminate workers when main thread terminates
class ProxyServer(ServerBase):
request_queue_size = 20
class ProxyServer(tcpserver.TCPServer):
allow_reuse_address = True
bound = True
def __init__(self, config, port, address=''):
@ -596,7 +581,7 @@ class ProxyServer(ServerBase):
"""
self.config, self.port, self.address = config, port, address
try:
ServerBase.__init__(self, (address, port), ProxyHandler)
tcpserver.TCPServer.__init__(self, (address, port))
except socket.error, v:
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
self.masterq = None
@ -611,11 +596,11 @@ class ProxyServer(ServerBase):
def set_mqueue(self, q):
self.masterq = q
def finish_request(self, request, client_address):
self.RequestHandlerClass(self.config, request, client_address, self, self.masterq)
def handle_connection(self, request, client_address):
ProxyHandler(self.config, request, client_address, self, self.masterq)
def shutdown(self):
ServerBase.shutdown(self)
tcpserver.TCPServer.shutdown(self)
try:
shutil.rmtree(self.certdir)
except OSError:

88
libmproxy/tcpserver.py Normal file
View File

@ -0,0 +1,88 @@
import select, socket, threading
class BaseHandler:
rbufsize = -1
wbufsize = 0
def __init__(self, connection, client_address, server):
self.connection = connection
self.rfile = self.connection.makefile('rb', self.rbufsize)
self.wfile = self.connection.makefile('wb', self.wbufsize)
self.client_address = client_address
self.server = server
self.handle()
self.finish()
def finish(self):
try:
if not getattr(self.wfile, "closed", False):
self.wfile.flush()
self.connection.close()
self.wfile.close()
self.rfile.close()
except IOError:
pass
def handle(self):
raise NotImplementedError
class TCPServer:
request_queue_size = 20
def __init__(self, server_address):
self.server_address = server_address
self.__is_shut_down = threading.Event()
self.__shutdown_request = False
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(self.server_address)
self.server_address = self.socket.getsockname()
self.socket.listen(self.request_queue_size)
def fileno(self):
return self.socket.fileno()
def request_thread(self, request, client_address):
try:
self.handle_connection(request, client_address)
request.close()
except:
self.handle_error(request, client_address)
request.close()
def serve_forever(self, poll_interval=0.5):
self.__is_shut_down.clear()
try:
while not self.__shutdown_request:
r, w, e = select.select([self], [], [], poll_interval)
if self in r:
try:
request, client_address = self.socket.accept()
except socket.error:
return
try:
t = threading.Thread(target = self.request_thread,
args = (request, client_address))
t.setDaemon (1)
t.start()
except:
self.handle_error(request, client_address)
request.close()
finally:
self.__shutdown_request = False
self.__is_shut_down.set()
def shutdown(self):
self.__shutdown_request = True
self.__is_shut_down.wait()
def handle_error(self, request, client_address):
print '-'*40
print 'Exception happened during processing of request from',
print client_address
import traceback
traceback.print_exc() # XXX But this goes to stderr!
print '-'*40
def handle_connection(self, request, client_address):
raise NotImplementedError