Improve debugging of thread and other leaks

- Add basethread.BaseThread that all threads outside of test suites should use
- Add a signal handler to mitmproxy, mitmdump and mitmweb that dumps resource
information to screen when SIGUSR1 is received.
- Improve thread naming throughout to make thread dumps understandable
This commit is contained in:
Aldo Cortesi 2016-06-11 19:52:24 +12:00
parent 5b9f07c81c
commit 09edbd9492
12 changed files with 138 additions and 50 deletions

View File

@ -5,8 +5,10 @@ import threading
from six.moves import queue from six.moves import queue
from netlib import basethread
from mitmproxy import exceptions from mitmproxy import exceptions
Events = frozenset([ Events = frozenset([
"clientconnect", "clientconnect",
"clientdisconnect", "clientdisconnect",
@ -95,12 +97,13 @@ class Master(object):
self.should_exit.set() self.should_exit.set()
class ServerThread(threading.Thread): class ServerThread(basethread.BaseThread):
def __init__(self, server): def __init__(self, server):
self.server = server self.server = server
super(ServerThread, self).__init__()
address = getattr(self.server, "address", None) address = getattr(self.server, "address", None)
self.name = "ServerThread ({})".format(repr(address)) super(ServerThread, self).__init__(
"ServerThread ({})".format(repr(address))
)
def run(self): def run(self):
self.server.serve_forever() self.server.serve_forever()

View File

@ -47,6 +47,7 @@ def process_options(parser, options):
sys.exit(0) sys.exit(0)
if options.quiet: if options.quiet:
options.verbose = 0 options.verbose = 0
debug.register_info_dumper()
return config.process_proxy_options(parser, options) return config.process_proxy_options(parser, options)

View File

@ -18,6 +18,7 @@ from mitmproxy.protocol import base
from mitmproxy.protocol import http from mitmproxy.protocol import http
import netlib.http import netlib.http
from netlib import tcp from netlib import tcp
from netlib import basethread
from netlib.http import http2 from netlib.http import http2
@ -261,10 +262,12 @@ class Http2Layer(base.Layer):
self._cleanup_streams() self._cleanup_streams()
class Http2SingleStreamLayer(http._HttpTransmissionLayer, threading.Thread): class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
def __init__(self, ctx, stream_id, request_headers): def __init__(self, ctx, stream_id, request_headers):
super(Http2SingleStreamLayer, self).__init__(ctx, name="Thread-Http2SingleStreamLayer-{}".format(stream_id)) super(Http2SingleStreamLayer, self).__init__(
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
)
self.zombie = None self.zombie = None
self.client_stream_id = stream_id self.client_stream_id = stream_id
self.server_stream_id = None self.server_stream_id = None

View File

@ -1,6 +1,5 @@
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import threading
import traceback import traceback
import netlib.exceptions import netlib.exceptions
@ -8,12 +7,13 @@ from mitmproxy import controller
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import models from mitmproxy import models
from netlib.http import http1 from netlib.http import http1
from netlib import basethread
# TODO: Doesn't really belong into mitmproxy.protocol... # TODO: Doesn't really belong into mitmproxy.protocol...
class RequestReplayThread(threading.Thread): class RequestReplayThread(basethread.BaseThread):
name = "RequestReplayThread" name = "RequestReplayThread"
def __init__(self, config, flow, event_queue, should_exit): def __init__(self, config, flow, event_queue, should_exit):
@ -26,7 +26,9 @@ class RequestReplayThread(threading.Thread):
self.channel = controller.Channel(event_queue, should_exit) self.channel = controller.Channel(event_queue, should_exit)
else: else:
self.channel = None self.channel = None
super(RequestReplayThread, self).__init__() super(RequestReplayThread, self).__init__(
"RequestReplay (%s)" % flow.request.url
)
def run(self): def run(self):
r = self.flow.request r = self.flow.request

View File

@ -5,10 +5,10 @@ offload computations from mitmproxy's main master thread.
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from mitmproxy import controller from mitmproxy import controller
import threading from netlib import basethread
class ScriptThread(threading.Thread): class ScriptThread(basethread.BaseThread):
name = "ScriptThread" name = "ScriptThread"
@ -24,5 +24,8 @@ def concurrent(fn):
if not obj.reply.acked: if not obj.reply.acked:
obj.reply.ack() obj.reply.ack()
obj.reply.take() obj.reply.take()
ScriptThread(target=run).start() ScriptThread(
"script.concurrent (%s)" % fn.__name__,
target=run
).start()
return _concurrent return _concurrent

14
netlib/basethread.py Normal file
View File

@ -0,0 +1,14 @@
import time
import threading
class BaseThread(threading.Thread):
    """
    A threading.Thread that must be given a name and that remembers when
    it was created, so that thread dumps can show a readable label and an
    age for every thread.
    """

    def __init__(self, name, *args, **kwargs):
        """
        name: the thread name; forwarded to threading.Thread as ``name``.

        Any further positional and keyword arguments are passed through to
        the next __init__ in the MRO unchanged.
        """
        super(BaseThread, self).__init__(*args, name=name, **kwargs)
        # Recorded at construction time (not when start() is called).
        self._thread_started = time.time()

    def _threadinfo(self):
        """
        Return a one-line summary: the thread name followed by the whole
        number of seconds elapsed since the thread object was created.
        """
        age = int(time.time() - self._thread_started)
        return "%s - age: %is" % (self.name, age)

View File

@ -1,29 +1,76 @@
from __future__ import (absolute_import, print_function, division)
import sys
import threading
import signal
import platform import platform
import psutil
from netlib import version from netlib import version
"""
Some utilities to help with debugging.
"""
def sysinfo(): def sysinfo():
data = [ data = [
"Mitmproxy verison: %s"%version.VERSION, "Mitmproxy verison: %s" % version.VERSION,
"Python version: %s"%platform.python_version(), "Python version: %s" % platform.python_version(),
"Platform: %s"%platform.platform(), "Platform: %s" % platform.platform(),
] ]
d = platform.linux_distribution() d = platform.linux_distribution()
t = "Linux distro: %s %s %s"%d t = "Linux distro: %s %s %s" % d
if d[0]: # pragma: no-cover if d[0]: # pragma: no-cover
data.append(t) data.append(t)
d = platform.mac_ver() d = platform.mac_ver()
t = "Mac version: %s %s %s"%d t = "Mac version: %s %s %s" % d
if d[0]: # pragma: no-cover if d[0]: # pragma: no-cover
data.append(t) data.append(t)
d = platform.win32_ver() d = platform.win32_ver()
t = "Windows version: %s %s %s %s"%d t = "Windows version: %s %s %s %s" % d
if d[0]: # pragma: no-cover if d[0]: # pragma: no-cover
data.append(t) data.append(t)
return "\n".join(data) return "\n".join(data)
def dump_info(sig, frm, file=sys.stdout):  # pragma: no cover
    """
    Write a snapshot of this process's resource usage to ``file``.

    The signature matches what signal.signal() expects of a handler, so
    the function can be installed directly as a signal handler; it can
    also be called by hand.

    sig, frm: signal number and stack frame as passed to a signal handler.
        Neither is used.
    file: file-like object the report is printed to.

    The report lists the thread count, open descriptor count, memory
    usage, all live threads, open files and network connections.
    """
    p = psutil.Process()

    print("****************************************************", file=file)
    print("Summary", file=file)
    print("=======", file=file)
    print("num threads: ", p.num_threads(), file=file)
    if hasattr(p, "num_fds"):
        print("num fds: ", p.num_fds(), file=file)
    else:
        # psutil only provides num_fds() on UNIX. On Windows the closest
        # equivalent is the handle count, so report that rather than
        # failing with AttributeError.
        print("num handles: ", p.num_handles(), file=file)
    print("memory: ", p.memory_info(), file=file)

    print(file=file)
    print("Threads", file=file)
    print("=======", file=file)
    bthreads = []
    for i in threading.enumerate():
        if hasattr(i, "_threadinfo"):
            # Threads that can describe themselves are collected so they
            # can be listed in creation order below.
            bthreads.append(i)
        else:
            # Any other thread is printed straight away, by name only.
            print(i.name, file=file)
    bthreads.sort(key=lambda x: x._thread_started)
    for i in bthreads:
        print(i._threadinfo(), file=file)

    print(file=file)
    print("Files", file=file)
    print("=====", file=file)
    for i in p.open_files():
        print(i, file=file)

    print(file=file)
    print("Connections", file=file)
    print("===========", file=file)
    for i in p.connections():
        print(i, file=file)

    print("****************************************************", file=file)
def register_info_dumper():  # pragma: no cover
    """
    Install dump_info as the handler for SIGUSR1, so that sending that
    signal to the process prints a resource report.

    On platforms whose signal module does not define SIGUSR1 (Windows),
    this does nothing instead of raising AttributeError.
    """
    sigusr1 = getattr(signal, "SIGUSR1", None)
    if sigusr1 is None:
        # No SIGUSR1 on this platform; there is nothing to hook into.
        return
    signal.signal(sigusr1, dump_info)

View File

@ -17,7 +17,11 @@ import six
import OpenSSL import OpenSSL
from OpenSSL import SSL from OpenSSL import SSL
from netlib import certutils, version_check, basetypes, exceptions from netlib import certutils
from netlib import version_check
from netlib import basetypes
from netlib import exceptions
from netlib import basethread
# This is a rather hackish way to make sure that # This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed. # the latest version of pyOpenSSL is actually installed.
@ -900,12 +904,16 @@ class TCPServer(object):
raise raise
if self.socket in r: if self.socket in r:
connection, client_address = self.socket.accept() connection, client_address = self.socket.accept()
t = threading.Thread( t = basethread.BaseThread(
"TCPConnectionHandler (%s: %s:%s -> %s:%s)" % (
self.__class__.__name__,
client_address[0],
client_address[1],
self.address.host,
self.address.port
),
target=self.connection_thread, target=self.connection_thread,
args=(connection, client_address), args=(connection, client_address),
name="ConnectionThread (%s:%s -> %s:%s)" %
(client_address[0], client_address[1],
self.address.host, self.address.port)
) )
t.setDaemon(1) t.setDaemon(1)
try: try:

View File

@ -8,15 +8,15 @@ from six.moves import queue
import random import random
import select import select
import time import time
import threading
import OpenSSL.crypto import OpenSSL.crypto
import six import six
from netlib import tcp, certutils, websockets, socks from netlib import tcp, certutils, websockets, socks
from netlib.exceptions import HttpException, TcpDisconnect, TcpTimeout, TlsException, TcpException, \ from netlib import exceptions
NetlibException from netlib.http import http1
from netlib.http import http1, http2 from netlib.http import http2
from netlib import basethread
from pathod import log, language from pathod import log, language
@ -77,7 +77,7 @@ class SSLInfo(object):
return "\n".join(parts) return "\n".join(parts)
class WebsocketFrameReader(threading.Thread): class WebsocketFrameReader(basethread.BaseThread):
def __init__( def __init__(
self, self,
@ -88,7 +88,7 @@ class WebsocketFrameReader(threading.Thread):
ws_read_limit, ws_read_limit,
timeout timeout
): ):
threading.Thread.__init__(self) basethread.BaseThread.__init__(self, "WebsocketFrameReader")
self.timeout = timeout self.timeout = timeout
self.ws_read_limit = ws_read_limit self.ws_read_limit = ws_read_limit
self.logfp = logfp self.logfp = logfp
@ -129,7 +129,7 @@ class WebsocketFrameReader(threading.Thread):
with self.logger.ctx() as log: with self.logger.ctx() as log:
try: try:
frm = websockets.Frame.from_file(self.rfile) frm = websockets.Frame.from_file(self.rfile)
except TcpDisconnect: except exceptions.TcpDisconnect:
return return
self.frames_queue.put(frm) self.frames_queue.put(frm)
log("<< %s" % frm.header.human_readable()) log("<< %s" % frm.header.human_readable())
@ -241,8 +241,8 @@ class Pathoc(tcp.TCPClient):
try: try:
resp = self.protocol.read_response(self.rfile, treq(method="CONNECT")) resp = self.protocol.read_response(self.rfile, treq(method="CONNECT"))
if resp.status_code != 200: if resp.status_code != 200:
raise HttpException("Unexpected status code: %s" % resp.status_code) raise exceptions.HttpException("Unexpected status code: %s" % resp.status_code)
except HttpException as e: except exceptions.HttpException as e:
six.reraise(PathocError, PathocError( six.reraise(PathocError, PathocError(
"Proxy CONNECT failed: %s" % repr(e) "Proxy CONNECT failed: %s" % repr(e)
)) ))
@ -280,7 +280,7 @@ class Pathoc(tcp.TCPClient):
connect_reply.msg, connect_reply.msg,
"SOCKS server error" "SOCKS server error"
) )
except (socks.SocksError, TcpDisconnect) as e: except (socks.SocksError, exceptions.TcpDisconnect) as e:
raise PathocError(str(e)) raise PathocError(str(e))
def connect(self, connect_to=None, showssl=False, fp=sys.stdout): def connect(self, connect_to=None, showssl=False, fp=sys.stdout):
@ -310,7 +310,7 @@ class Pathoc(tcp.TCPClient):
cipher_list=self.ciphers, cipher_list=self.ciphers,
alpn_protos=alpn_protos alpn_protos=alpn_protos
) )
except TlsException as v: except exceptions.TlsException as v:
raise PathocError(str(v)) raise PathocError(str(v))
self.sslinfo = SSLInfo( self.sslinfo = SSLInfo(
@ -406,7 +406,7 @@ class Pathoc(tcp.TCPClient):
Returns Response if we have a non-ignored response. Returns Response if we have a non-ignored response.
May raise a NetlibException May raise a exceptions.NetlibException
""" """
logger = log.ConnectionLogger( logger = log.ConnectionLogger(
self.fp, self.fp,
@ -424,10 +424,10 @@ class Pathoc(tcp.TCPClient):
resp = self.protocol.read_response(self.rfile, treq(method=req["method"].encode())) resp = self.protocol.read_response(self.rfile, treq(method=req["method"].encode()))
resp.sslinfo = self.sslinfo resp.sslinfo = self.sslinfo
except HttpException as v: except exceptions.HttpException as v:
lg("Invalid server response: %s" % v) lg("Invalid server response: %s" % v)
raise raise
except TcpTimeout: except exceptions.TcpTimeout:
if self.ignoretimeout: if self.ignoretimeout:
lg("Timeout (ignored)") lg("Timeout (ignored)")
return None return None
@ -451,7 +451,7 @@ class Pathoc(tcp.TCPClient):
Returns Response if we have a non-ignored response. Returns Response if we have a non-ignored response.
May raise a NetlibException May raise a exceptions.NetlibException
""" """
if isinstance(r, basestring): if isinstance(r, basestring):
r = language.parse_pathoc(r, self.use_http2).next() r = language.parse_pathoc(r, self.use_http2).next()
@ -530,11 +530,11 @@ def main(args): # pragma: no cover
# We consume the queue when we can, so it doesn't build up. # We consume the queue when we can, so it doesn't build up.
for i_ in p.wait(timeout=0, finish=False): for i_ in p.wait(timeout=0, finish=False):
pass pass
except NetlibException: except exceptions.NetlibException:
break break
for i_ in p.wait(timeout=0.01, finish=True): for i_ in p.wait(timeout=0.01, finish=True):
pass pass
except TcpException as v: except exceptions.TcpException as v:
print(str(v), file=sys.stderr) print(str(v), file=sys.stderr)
continue continue
except PathocError as v: except PathocError as v:

View File

@ -1,10 +1,10 @@
from six.moves import cStringIO as StringIO from six.moves import cStringIO as StringIO
import threading
import time import time
from six.moves import queue from six.moves import queue
from . import pathod from . import pathod
from netlib import basethread
class TimeoutError(Exception): class TimeoutError(Exception):
@ -95,11 +95,10 @@ class Daemon:
self.thread.join() self.thread.join()
class _PaThread(threading.Thread): class _PaThread(basethread.BaseThread):
def __init__(self, iface, q, ssl, daemonargs): def __init__(self, iface, q, ssl, daemonargs):
threading.Thread.__init__(self) basethread.BaseThread.__init__(self, "PathodThread")
self.name = "PathodThread"
self.iface, self.q, self.ssl = iface, q, ssl self.iface, self.q, self.ssl = iface, q, ssl
self.daemonargs = daemonargs self.daemonargs = daemonargs
self.server = None self.server = None

View File

@ -1,7 +1,6 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
from codecs import open from codecs import open
import os import os
import sys
# Based on https://github.com/pypa/sampleproject/blob/master/setup.py # Based on https://github.com/pypa/sampleproject/blob/master/setup.py
# and https://python-packaging-user-guide.readthedocs.org/ # and https://python-packaging-user-guide.readthedocs.org/
@ -73,6 +72,7 @@ setup(
"lxml>=3.5.0, <3.7", "lxml>=3.5.0, <3.7",
"Pillow>=3.2, <3.3", "Pillow>=3.2, <3.3",
"passlib>=1.6.5, <1.7", "passlib>=1.6.5, <1.7",
"psutil>=4.2.0, <4.3",
"pyasn1>=0.1.9, <0.2", "pyasn1>=0.1.9, <0.2",
"pyOpenSSL>=16.0, <17.0", "pyOpenSSL>=16.0, <17.0",
"pyparsing>=2.1.3, <2.2", "pyparsing>=2.1.3, <2.2",

View File

@ -1,6 +1,14 @@
from __future__ import (absolute_import, print_function, division)
from six.moves import cStringIO as StringIO
from netlib import debug from netlib import debug
def test_dump_info():
    """dump_info must write a non-empty report to the stream it is given."""
    out = StringIO()
    debug.dump_info(None, None, file=out)
    report = out.getvalue()
    assert report
def test_sysinfo(): def test_sysinfo():
assert debug.sysinfo() assert debug.sysinfo()