Create our own TCP server class.
We're going to need more control for advanced features and speed, and we can also ditch some of the idiocies in the SocketServer module.
This commit is contained in:
parent
c7952371b7
commit
8ae64337ed
|
@ -21,7 +21,7 @@
|
|||
import sys, os, string, socket, time
|
||||
import shutil, tempfile, threading
|
||||
import optparse, SocketServer
|
||||
import utils, flow, certutils, version, wsgi
|
||||
import utils, flow, certutils, version, wsgi, tcpserver
|
||||
from OpenSSL import SSL
|
||||
|
||||
|
||||
|
@ -50,7 +50,7 @@ class ProxyConfig:
|
|||
def read_headers(fp):
|
||||
"""
|
||||
Read a set of headers from a file pointer. Stop once a blank line
|
||||
is reached. Return a ODict object.
|
||||
is reached. Return a ODictCaseless object.
|
||||
"""
|
||||
ret = []
|
||||
name = ''
|
||||
|
@ -374,13 +374,13 @@ class ServerConnection:
|
|||
pass
|
||||
|
||||
|
||||
class ProxyHandler(SocketServer.StreamRequestHandler):
|
||||
def __init__(self, config, request, client_address, server, q):
|
||||
class ProxyHandler(tcpserver.BaseHandler):
|
||||
def __init__(self, config, connection, client_address, server, q):
|
||||
self.mqueue = q
|
||||
self.config = config
|
||||
self.server_conn = None
|
||||
self.proxy_connect_state = None
|
||||
SocketServer.StreamRequestHandler.__init__(self, request, client_address, server)
|
||||
tcpserver.BaseHandler.__init__(self, connection, client_address, server)
|
||||
|
||||
def handle(self):
|
||||
cc = flow.ClientConnect(self.client_address)
|
||||
|
@ -390,7 +390,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
|
|||
cc.close = True
|
||||
cd = flow.ClientDisconnect(cc)
|
||||
cd._send(self.mqueue)
|
||||
self.finish()
|
||||
|
||||
def server_connect(self, scheme, host, port):
|
||||
sc = self.server_conn
|
||||
|
@ -554,18 +553,6 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
|
|||
self.wfile.write(d)
|
||||
self.wfile.flush()
|
||||
|
||||
def terminate(self, connection, wfile, rfile):
|
||||
self.request.close()
|
||||
try:
|
||||
if not getattr(wfile, "closed", False):
|
||||
wfile.flush()
|
||||
connection.close()
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
def finish(self):
|
||||
self.terminate(self.connection, self.wfile, self.rfile)
|
||||
|
||||
def send_error(self, code, body):
|
||||
try:
|
||||
import BaseHTTPServer
|
||||
|
@ -584,10 +571,8 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
|
|||
|
||||
class ProxyServerError(Exception): pass
|
||||
|
||||
ServerBase = SocketServer.ThreadingTCPServer
|
||||
ServerBase.daemon_threads = True # Terminate workers when main thread terminates
|
||||
class ProxyServer(ServerBase):
|
||||
request_queue_size = 20
|
||||
|
||||
class ProxyServer(tcpserver.TCPServer):
|
||||
allow_reuse_address = True
|
||||
bound = True
|
||||
def __init__(self, config, port, address=''):
|
||||
|
@ -596,7 +581,7 @@ class ProxyServer(ServerBase):
|
|||
"""
|
||||
self.config, self.port, self.address = config, port, address
|
||||
try:
|
||||
ServerBase.__init__(self, (address, port), ProxyHandler)
|
||||
tcpserver.TCPServer.__init__(self, (address, port))
|
||||
except socket.error, v:
|
||||
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
|
||||
self.masterq = None
|
||||
|
@ -611,11 +596,11 @@ class ProxyServer(ServerBase):
|
|||
def set_mqueue(self, q):
|
||||
self.masterq = q
|
||||
|
||||
def finish_request(self, request, client_address):
|
||||
self.RequestHandlerClass(self.config, request, client_address, self, self.masterq)
|
||||
def handle_connection(self, request, client_address):
|
||||
ProxyHandler(self.config, request, client_address, self, self.masterq)
|
||||
|
||||
def shutdown(self):
|
||||
ServerBase.shutdown(self)
|
||||
tcpserver.TCPServer.shutdown(self)
|
||||
try:
|
||||
shutil.rmtree(self.certdir)
|
||||
except OSError:
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
import select, socket, threading
|
||||
|
||||
class BaseHandler:
|
||||
rbufsize = -1
|
||||
wbufsize = 0
|
||||
def __init__(self, connection, client_address, server):
|
||||
self.connection = connection
|
||||
self.rfile = self.connection.makefile('rb', self.rbufsize)
|
||||
self.wfile = self.connection.makefile('wb', self.wbufsize)
|
||||
|
||||
self.client_address = client_address
|
||||
self.server = server
|
||||
self.handle()
|
||||
self.finish()
|
||||
|
||||
def finish(self):
|
||||
try:
|
||||
if not getattr(self.wfile, "closed", False):
|
||||
self.wfile.flush()
|
||||
self.connection.close()
|
||||
self.wfile.close()
|
||||
self.rfile.close()
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
def handle(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TCPServer:
|
||||
request_queue_size = 20
|
||||
def __init__(self, server_address):
|
||||
self.server_address = server_address
|
||||
self.__is_shut_down = threading.Event()
|
||||
self.__shutdown_request = False
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.bind(self.server_address)
|
||||
self.server_address = self.socket.getsockname()
|
||||
self.socket.listen(self.request_queue_size)
|
||||
|
||||
def fileno(self):
|
||||
return self.socket.fileno()
|
||||
|
||||
def request_thread(self, request, client_address):
|
||||
try:
|
||||
self.handle_connection(request, client_address)
|
||||
request.close()
|
||||
except:
|
||||
self.handle_error(request, client_address)
|
||||
request.close()
|
||||
|
||||
def serve_forever(self, poll_interval=0.5):
|
||||
self.__is_shut_down.clear()
|
||||
try:
|
||||
while not self.__shutdown_request:
|
||||
r, w, e = select.select([self], [], [], poll_interval)
|
||||
if self in r:
|
||||
try:
|
||||
request, client_address = self.socket.accept()
|
||||
except socket.error:
|
||||
return
|
||||
try:
|
||||
t = threading.Thread(target = self.request_thread,
|
||||
args = (request, client_address))
|
||||
t.setDaemon (1)
|
||||
t.start()
|
||||
except:
|
||||
self.handle_error(request, client_address)
|
||||
request.close()
|
||||
finally:
|
||||
self.__shutdown_request = False
|
||||
self.__is_shut_down.set()
|
||||
|
||||
def shutdown(self):
|
||||
self.__shutdown_request = True
|
||||
self.__is_shut_down.wait()
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
print '-'*40
|
||||
print 'Exception happened during processing of request from',
|
||||
print client_address
|
||||
import traceback
|
||||
traceback.print_exc() # XXX But this goes to stderr!
|
||||
print '-'*40
|
||||
|
||||
def handle_connection(self, request, client_address):
|
||||
raise NotImplementedError
|
Loading…
Reference in New Issue