Rename our tcpserver to netlib, expand to include client network functions.

This commit is contained in:
Aldo Cortesi 2012-06-16 13:38:10 +12:00
parent 8ae64337ed
commit 4e53f1ee90
4 changed files with 128 additions and 103 deletions

View File

@ -1,4 +1,80 @@
import select, socket, threading
from OpenSSL import SSL
class NetLibError(Exception): pass
class FileLike:
def __init__(self, o):
self.o = o
def __getattr__(self, attr):
return getattr(self.o, attr)
def flush(self):
pass
def read(self, length):
result = ''
while len(result) < length:
try:
data = self.o.read(length)
except AttributeError:
break
except SSL.ZeroReturnError:
break
if not data:
break
result += data
return result
def write(self, v):
self.o.sendall(v)
def readline(self, size = None):
result = ''
bytes_read = 0
while True:
if size is not None and bytes_read >= size:
break
ch = self.read(1)
bytes_read += 1
if not ch:
break
else:
result += ch
if ch == '\n':
break
return result
class TCPClient:
def __init__(self, ssl, host, port, clientcert):
self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert
self.sock, self.rfile, self.wfile = None, None, None
self.cert = None
self.connect()
def connect(self):
try:
addr = socket.gethostbyname(self.host)
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if self.ssl:
context = SSL.Context(SSL.SSLv23_METHOD)
if self.clientcert:
context.use_certificate_file(self.clientcert)
server = SSL.Connection(context, server)
server.connect((addr, self.port))
if self.ssl:
self.cert = server.get_peer_certificate()
self.rfile, self.wfile = FileLike(server), FileLike(server)
else:
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
except socket.error, err:
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
self.sock = server
class BaseHandler:
rbufsize = -1
@ -13,6 +89,15 @@ class BaseHandler:
self.handle()
self.finish()
def convert_to_ssl(self, cert, key):
ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_privatekey_file(key)
ctx.use_certificate_file(cert)
self.connection = SSL.Connection(ctx, self.connection)
self.connection.set_accept_state()
self.rfile = FileLike(self.connection)
self.wfile = FileLike(self.connection)
def finish(self):
try:
if not getattr(self.wfile, "closed", False):

View File

@ -21,7 +21,7 @@
import sys, os, string, socket, time
import shutil, tempfile, threading
import optparse, SocketServer
import utils, flow, certutils, version, wsgi, tcpserver
import utils, flow, certutils, version, wsgi, netlib
from OpenSSL import SSL
@ -232,50 +232,6 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit):
return read_http_body(rfile, headers, False, limit)
class FileLike:
def __init__(self, o):
self.o = o
def __getattr__(self, attr):
return getattr(self.o, attr)
def flush(self):
pass
def read(self, length):
result = ''
while len(result) < length:
try:
data = self.o.read(length)
except AttributeError:
break
except SSL.ZeroReturnError:
break
if not data:
break
result += data
return result
def write(self, v):
self.o.sendall(v)
def readline(self, size = None):
result = ''
bytes_read = 0
while True:
if size is not None and bytes_read >= size:
break
ch = self.read(1)
bytes_read += 1
if not ch:
break
else:
result += ch
if ch == '\n':
break
return result
class RequestReplayThread(threading.Thread):
def __init__(self, config, flow, masterq):
self.config, self.flow, self.masterq = config, flow, masterq
@ -291,41 +247,27 @@ class RequestReplayThread(threading.Thread):
except ProxyError, v:
err = flow.Error(self.flow.request, v.msg)
err._send(self.masterq)
except netlib.NetLibError, v:
raise ProxyError(502, v)
class ServerConnection:
class ServerConnection(netlib.TCPClient):
def __init__(self, config, scheme, host, port):
self.config, self.scheme, self.host, self.port = config, scheme, host, port
self.cert = None
self.sock, self.rfile, self.wfile = None, None, None
self.connect()
clientcert = None
if config.clientcerts:
path = os.path.join(config.clientcerts, self.host) + ".pem"
if os.path.exists(clientcert):
clientcert = path
netlib.TCPClient.__init__(
self,
True if scheme == "https" else False,
host,
port,
clientcert
)
self.config, self.scheme = config, scheme
self.requestcount = 0
def connect(self):
try:
addr = socket.gethostbyname(self.host)
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if self.scheme == "https":
if self.config.clientcerts:
clientcert = os.path.join(self.config.clientcerts, self.host) + ".pem"
if not os.path.exists(clientcert):
clientcert = None
else:
clientcert = None
context = SSL.Context(SSL.SSLv23_METHOD)
if clientcert:
context.use_certificate_file(clientcert)
server = SSL.Connection(context, server)
server.connect((addr, self.port))
if self.scheme == "https":
self.cert = server.get_peer_certificate()
self.rfile, self.wfile = FileLike(server), FileLike(server)
else:
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
except socket.error, err:
raise ProxyError(502, 'Error connecting to "%s": %s' % (self.host, err))
self.sock = server
def send(self, request):
self.requestcount += 1
try:
@ -374,13 +316,13 @@ class ServerConnection:
pass
class ProxyHandler(tcpserver.BaseHandler):
class ProxyHandler(netlib.BaseHandler):
def __init__(self, config, connection, client_address, server, q):
self.mqueue = q
self.config = config
self.server_conn = None
self.proxy_connect_state = None
tcpserver.BaseHandler.__init__(self, connection, client_address, server)
netlib.BaseHandler.__init__(self, connection, client_address, server)
def handle(self):
cc = flow.ClientConnect(self.client_address)
@ -397,7 +339,10 @@ class ProxyHandler(tcpserver.BaseHandler):
sc.terminate()
self.server_conn = None
if not self.server_conn:
self.server_conn = ServerConnection(self.config, scheme, host, port)
try:
self.server_conn = ServerConnection(self.config, scheme, host, port)
except netlib.NetLibError, v:
raise ProxyError(502, v)
def handle_request(self, cc):
try:
@ -473,15 +418,6 @@ class ProxyHandler(tcpserver.BaseHandler):
raise ProxyError(502, "mitmproxy: Unable to generate dummy cert.")
return ret
def convert_to_ssl(self, cert):
ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_privatekey_file(self.config.certfile or self.config.cacert)
ctx.use_certificate_file(cert)
self.connection = SSL.Connection(ctx, self.connection)
self.connection.set_accept_state()
self.rfile = FileLike(self.connection)
self.wfile = FileLike(self.connection)
def read_request(self, client_conn):
line = self.rfile.readline()
if line == "\r\n" or line == "\n": # Possible leftover from previous message
@ -494,7 +430,7 @@ class ProxyHandler(tcpserver.BaseHandler):
if port in self.config.transparent_proxy["sslports"]:
scheme = "https"
certfile = self.find_cert(host, port)
self.convert_to_ssl(certfile)
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
else:
scheme = "http"
method, path, httpversion = parse_init_http(line)
@ -527,7 +463,7 @@ class ProxyHandler(tcpserver.BaseHandler):
)
self.wfile.flush()
certfile = self.find_cert(host, port)
self.convert_to_ssl(certfile)
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
self.proxy_connect_state = (host, port, httpversion)
line = self.rfile.readline(line)
if self.proxy_connect_state:
@ -572,7 +508,7 @@ class ProxyHandler(tcpserver.BaseHandler):
class ProxyServerError(Exception): pass
class ProxyServer(tcpserver.TCPServer):
class ProxyServer(netlib.TCPServer):
allow_reuse_address = True
bound = True
def __init__(self, config, port, address=''):
@ -581,7 +517,7 @@ class ProxyServer(tcpserver.TCPServer):
"""
self.config, self.port, self.address = config, port, address
try:
tcpserver.TCPServer.__init__(self, (address, port))
netlib.TCPServer.__init__(self, (address, port))
except socket.error, v:
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
self.masterq = None
@ -600,7 +536,7 @@ class ProxyServer(tcpserver.TCPServer):
ProxyHandler(self.config, request, client_address, self, self.masterq)
def shutdown(self):
tcpserver.TCPServer.shutdown(self)
netlib.TCPServer.shutdown(self)
try:
shutil.rmtree(self.certdir)
except OSError:

15
test/test_netlib.py Normal file
View File

@ -0,0 +1,15 @@
import cStringIO
from libmproxy import netlib
class TestFileLike:
def test_wrap(self):
s = cStringIO.StringIO("foobar\nfoobar")
s = netlib.FileLike(s)
s.flush()
assert s.readline() == "foobar\n"
assert s.readline() == "foobar"
# Test __getattr__
assert s.isatty

View File

@ -60,17 +60,6 @@ def test_read_http_body():
assert len(proxy.read_http_body(s, h, True, 100)) == 7
class TestFileLike:
def test_wrap(self):
s = cStringIO.StringIO("foobar\nfoobar")
s = proxy.FileLike(s)
s.flush()
assert s.readline() == "foobar\n"
assert s.readline() == "foobar"
# Test __getattr__
assert s.isatty
class TestProxyError:
def test_simple(self):
p = proxy.ProxyError(111, "msg")