Rename our tcpserver to netlib, expand to include client network functions.
This commit is contained in:
parent
8ae64337ed
commit
4e53f1ee90
|
@ -1,4 +1,80 @@
|
|||
import select, socket, threading
|
||||
from OpenSSL import SSL
|
||||
|
||||
|
||||
class NetLibError(Exception): pass
|
||||
|
||||
|
||||
class FileLike:
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.o, attr)
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def read(self, length):
|
||||
result = ''
|
||||
while len(result) < length:
|
||||
try:
|
||||
data = self.o.read(length)
|
||||
except AttributeError:
|
||||
break
|
||||
except SSL.ZeroReturnError:
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
result += data
|
||||
return result
|
||||
|
||||
def write(self, v):
|
||||
self.o.sendall(v)
|
||||
|
||||
def readline(self, size = None):
|
||||
result = ''
|
||||
bytes_read = 0
|
||||
while True:
|
||||
if size is not None and bytes_read >= size:
|
||||
break
|
||||
ch = self.read(1)
|
||||
bytes_read += 1
|
||||
if not ch:
|
||||
break
|
||||
else:
|
||||
result += ch
|
||||
if ch == '\n':
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
class TCPClient:
|
||||
def __init__(self, ssl, host, port, clientcert):
|
||||
self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert
|
||||
self.sock, self.rfile, self.wfile = None, None, None
|
||||
self.cert = None
|
||||
self.connect()
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
addr = socket.gethostbyname(self.host)
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
if self.ssl:
|
||||
context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
if self.clientcert:
|
||||
context.use_certificate_file(self.clientcert)
|
||||
server = SSL.Connection(context, server)
|
||||
server.connect((addr, self.port))
|
||||
if self.ssl:
|
||||
self.cert = server.get_peer_certificate()
|
||||
self.rfile, self.wfile = FileLike(server), FileLike(server)
|
||||
else:
|
||||
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
|
||||
except socket.error, err:
|
||||
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
|
||||
self.sock = server
|
||||
|
||||
|
||||
class BaseHandler:
|
||||
rbufsize = -1
|
||||
|
@ -13,6 +89,15 @@ class BaseHandler:
|
|||
self.handle()
|
||||
self.finish()
|
||||
|
||||
def convert_to_ssl(self, cert, key):
|
||||
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
||||
ctx.use_privatekey_file(key)
|
||||
ctx.use_certificate_file(cert)
|
||||
self.connection = SSL.Connection(ctx, self.connection)
|
||||
self.connection.set_accept_state()
|
||||
self.rfile = FileLike(self.connection)
|
||||
self.wfile = FileLike(self.connection)
|
||||
|
||||
def finish(self):
|
||||
try:
|
||||
if not getattr(self.wfile, "closed", False):
|
|
@ -21,7 +21,7 @@
|
|||
import sys, os, string, socket, time
|
||||
import shutil, tempfile, threading
|
||||
import optparse, SocketServer
|
||||
import utils, flow, certutils, version, wsgi, tcpserver
|
||||
import utils, flow, certutils, version, wsgi, netlib
|
||||
from OpenSSL import SSL
|
||||
|
||||
|
||||
|
@ -232,50 +232,6 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit):
|
|||
return read_http_body(rfile, headers, False, limit)
|
||||
|
||||
|
||||
class FileLike:
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.o, attr)
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def read(self, length):
|
||||
result = ''
|
||||
while len(result) < length:
|
||||
try:
|
||||
data = self.o.read(length)
|
||||
except AttributeError:
|
||||
break
|
||||
except SSL.ZeroReturnError:
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
result += data
|
||||
return result
|
||||
|
||||
def write(self, v):
|
||||
self.o.sendall(v)
|
||||
|
||||
def readline(self, size = None):
|
||||
result = ''
|
||||
bytes_read = 0
|
||||
while True:
|
||||
if size is not None and bytes_read >= size:
|
||||
break
|
||||
ch = self.read(1)
|
||||
bytes_read += 1
|
||||
if not ch:
|
||||
break
|
||||
else:
|
||||
result += ch
|
||||
if ch == '\n':
|
||||
break
|
||||
return result
|
||||
|
||||
|
||||
class RequestReplayThread(threading.Thread):
|
||||
def __init__(self, config, flow, masterq):
|
||||
self.config, self.flow, self.masterq = config, flow, masterq
|
||||
|
@ -291,41 +247,27 @@ class RequestReplayThread(threading.Thread):
|
|||
except ProxyError, v:
|
||||
err = flow.Error(self.flow.request, v.msg)
|
||||
err._send(self.masterq)
|
||||
except netlib.NetLibError, v:
|
||||
raise ProxyError(502, v)
|
||||
|
||||
|
||||
class ServerConnection:
|
||||
class ServerConnection(netlib.TCPClient):
|
||||
def __init__(self, config, scheme, host, port):
|
||||
self.config, self.scheme, self.host, self.port = config, scheme, host, port
|
||||
self.cert = None
|
||||
self.sock, self.rfile, self.wfile = None, None, None
|
||||
self.connect()
|
||||
clientcert = None
|
||||
if config.clientcerts:
|
||||
path = os.path.join(config.clientcerts, self.host) + ".pem"
|
||||
if os.path.exists(clientcert):
|
||||
clientcert = path
|
||||
netlib.TCPClient.__init__(
|
||||
self,
|
||||
True if scheme == "https" else False,
|
||||
host,
|
||||
port,
|
||||
clientcert
|
||||
)
|
||||
self.config, self.scheme = config, scheme
|
||||
self.requestcount = 0
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
addr = socket.gethostbyname(self.host)
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
if self.scheme == "https":
|
||||
if self.config.clientcerts:
|
||||
clientcert = os.path.join(self.config.clientcerts, self.host) + ".pem"
|
||||
if not os.path.exists(clientcert):
|
||||
clientcert = None
|
||||
else:
|
||||
clientcert = None
|
||||
context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
if clientcert:
|
||||
context.use_certificate_file(clientcert)
|
||||
server = SSL.Connection(context, server)
|
||||
server.connect((addr, self.port))
|
||||
if self.scheme == "https":
|
||||
self.cert = server.get_peer_certificate()
|
||||
self.rfile, self.wfile = FileLike(server), FileLike(server)
|
||||
else:
|
||||
self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
|
||||
except socket.error, err:
|
||||
raise ProxyError(502, 'Error connecting to "%s": %s' % (self.host, err))
|
||||
self.sock = server
|
||||
|
||||
def send(self, request):
|
||||
self.requestcount += 1
|
||||
try:
|
||||
|
@ -374,13 +316,13 @@ class ServerConnection:
|
|||
pass
|
||||
|
||||
|
||||
class ProxyHandler(tcpserver.BaseHandler):
|
||||
class ProxyHandler(netlib.BaseHandler):
|
||||
def __init__(self, config, connection, client_address, server, q):
|
||||
self.mqueue = q
|
||||
self.config = config
|
||||
self.server_conn = None
|
||||
self.proxy_connect_state = None
|
||||
tcpserver.BaseHandler.__init__(self, connection, client_address, server)
|
||||
netlib.BaseHandler.__init__(self, connection, client_address, server)
|
||||
|
||||
def handle(self):
|
||||
cc = flow.ClientConnect(self.client_address)
|
||||
|
@ -397,7 +339,10 @@ class ProxyHandler(tcpserver.BaseHandler):
|
|||
sc.terminate()
|
||||
self.server_conn = None
|
||||
if not self.server_conn:
|
||||
self.server_conn = ServerConnection(self.config, scheme, host, port)
|
||||
try:
|
||||
self.server_conn = ServerConnection(self.config, scheme, host, port)
|
||||
except netlib.NetLibError, v:
|
||||
raise ProxyError(502, v)
|
||||
|
||||
def handle_request(self, cc):
|
||||
try:
|
||||
|
@ -473,15 +418,6 @@ class ProxyHandler(tcpserver.BaseHandler):
|
|||
raise ProxyError(502, "mitmproxy: Unable to generate dummy cert.")
|
||||
return ret
|
||||
|
||||
def convert_to_ssl(self, cert):
|
||||
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
||||
ctx.use_privatekey_file(self.config.certfile or self.config.cacert)
|
||||
ctx.use_certificate_file(cert)
|
||||
self.connection = SSL.Connection(ctx, self.connection)
|
||||
self.connection.set_accept_state()
|
||||
self.rfile = FileLike(self.connection)
|
||||
self.wfile = FileLike(self.connection)
|
||||
|
||||
def read_request(self, client_conn):
|
||||
line = self.rfile.readline()
|
||||
if line == "\r\n" or line == "\n": # Possible leftover from previous message
|
||||
|
@ -494,7 +430,7 @@ class ProxyHandler(tcpserver.BaseHandler):
|
|||
if port in self.config.transparent_proxy["sslports"]:
|
||||
scheme = "https"
|
||||
certfile = self.find_cert(host, port)
|
||||
self.convert_to_ssl(certfile)
|
||||
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
|
||||
else:
|
||||
scheme = "http"
|
||||
method, path, httpversion = parse_init_http(line)
|
||||
|
@ -527,7 +463,7 @@ class ProxyHandler(tcpserver.BaseHandler):
|
|||
)
|
||||
self.wfile.flush()
|
||||
certfile = self.find_cert(host, port)
|
||||
self.convert_to_ssl(certfile)
|
||||
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
|
||||
self.proxy_connect_state = (host, port, httpversion)
|
||||
line = self.rfile.readline(line)
|
||||
if self.proxy_connect_state:
|
||||
|
@ -572,7 +508,7 @@ class ProxyHandler(tcpserver.BaseHandler):
|
|||
class ProxyServerError(Exception): pass
|
||||
|
||||
|
||||
class ProxyServer(tcpserver.TCPServer):
|
||||
class ProxyServer(netlib.TCPServer):
|
||||
allow_reuse_address = True
|
||||
bound = True
|
||||
def __init__(self, config, port, address=''):
|
||||
|
@ -581,7 +517,7 @@ class ProxyServer(tcpserver.TCPServer):
|
|||
"""
|
||||
self.config, self.port, self.address = config, port, address
|
||||
try:
|
||||
tcpserver.TCPServer.__init__(self, (address, port))
|
||||
netlib.TCPServer.__init__(self, (address, port))
|
||||
except socket.error, v:
|
||||
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
|
||||
self.masterq = None
|
||||
|
@ -600,7 +536,7 @@ class ProxyServer(tcpserver.TCPServer):
|
|||
ProxyHandler(self.config, request, client_address, self, self.masterq)
|
||||
|
||||
def shutdown(self):
|
||||
tcpserver.TCPServer.shutdown(self)
|
||||
netlib.TCPServer.shutdown(self)
|
||||
try:
|
||||
shutil.rmtree(self.certdir)
|
||||
except OSError:
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
import cStringIO
|
||||
from libmproxy import netlib
|
||||
|
||||
|
||||
class TestFileLike:
|
||||
def test_wrap(self):
|
||||
s = cStringIO.StringIO("foobar\nfoobar")
|
||||
s = netlib.FileLike(s)
|
||||
s.flush()
|
||||
assert s.readline() == "foobar\n"
|
||||
assert s.readline() == "foobar"
|
||||
# Test __getattr__
|
||||
assert s.isatty
|
||||
|
||||
|
|
@ -60,17 +60,6 @@ def test_read_http_body():
|
|||
assert len(proxy.read_http_body(s, h, True, 100)) == 7
|
||||
|
||||
|
||||
class TestFileLike:
|
||||
def test_wrap(self):
|
||||
s = cStringIO.StringIO("foobar\nfoobar")
|
||||
s = proxy.FileLike(s)
|
||||
s.flush()
|
||||
assert s.readline() == "foobar\n"
|
||||
assert s.readline() == "foobar"
|
||||
# Test __getattr__
|
||||
assert s.isatty
|
||||
|
||||
|
||||
class TestProxyError:
|
||||
def test_simple(self):
|
||||
p = proxy.ProxyError(111, "msg")
|
||||
|
|
Loading…
Reference in New Issue