Improve debugging of thread and other leaks

- Add basethread.BaseThread that all threads outside of test suites should use
- Add a signal handler to mitmproxy, mitmdump and mitmweb that dumps resource
information to screen when SIGUSR1 is received.
- Improve thread naming throughout to make thread dumps understandable
This commit is contained in:
Aldo Cortesi 2016-06-11 19:52:24 +12:00
parent 5b9f07c81c
commit 09edbd9492
12 changed files with 138 additions and 50 deletions

View File

@ -5,8 +5,10 @@ import threading
from six.moves import queue
from netlib import basethread
from mitmproxy import exceptions
Events = frozenset([
"clientconnect",
"clientdisconnect",
@ -95,12 +97,13 @@ class Master(object):
self.should_exit.set()
class ServerThread(threading.Thread):
class ServerThread(basethread.BaseThread):
def __init__(self, server):
self.server = server
super(ServerThread, self).__init__()
address = getattr(self.server, "address", None)
self.name = "ServerThread ({})".format(repr(address))
super(ServerThread, self).__init__(
"ServerThread ({})".format(repr(address))
)
def run(self):
self.server.serve_forever()

View File

@ -47,6 +47,7 @@ def process_options(parser, options):
sys.exit(0)
if options.quiet:
options.verbose = 0
debug.register_info_dumper()
return config.process_proxy_options(parser, options)

View File

@ -18,6 +18,7 @@ from mitmproxy.protocol import base
from mitmproxy.protocol import http
import netlib.http
from netlib import tcp
from netlib import basethread
from netlib.http import http2
@ -261,10 +262,12 @@ class Http2Layer(base.Layer):
self._cleanup_streams()
class Http2SingleStreamLayer(http._HttpTransmissionLayer, threading.Thread):
class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
def __init__(self, ctx, stream_id, request_headers):
super(Http2SingleStreamLayer, self).__init__(ctx, name="Thread-Http2SingleStreamLayer-{}".format(stream_id))
super(Http2SingleStreamLayer, self).__init__(
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
)
self.zombie = None
self.client_stream_id = stream_id
self.server_stream_id = None

View File

@ -1,6 +1,5 @@
from __future__ import absolute_import, print_function, division
import threading
import traceback
import netlib.exceptions
@ -8,12 +7,13 @@ from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import models
from netlib.http import http1
from netlib import basethread
# TODO: Doesn't really belong into mitmproxy.protocol...
class RequestReplayThread(threading.Thread):
class RequestReplayThread(basethread.BaseThread):
name = "RequestReplayThread"
def __init__(self, config, flow, event_queue, should_exit):
@ -26,7 +26,9 @@ class RequestReplayThread(threading.Thread):
self.channel = controller.Channel(event_queue, should_exit)
else:
self.channel = None
super(RequestReplayThread, self).__init__()
super(RequestReplayThread, self).__init__(
"RequestReplay (%s)" % flow.request.url
)
def run(self):
r = self.flow.request

View File

@ -5,10 +5,10 @@ offload computations from mitmproxy's main master thread.
from __future__ import absolute_import, print_function, division
from mitmproxy import controller
import threading
from netlib import basethread
class ScriptThread(threading.Thread):
class ScriptThread(basethread.BaseThread):
name = "ScriptThread"
@ -24,5 +24,8 @@ def concurrent(fn):
if not obj.reply.acked:
obj.reply.ack()
obj.reply.take()
ScriptThread(target=run).start()
ScriptThread(
"script.concurrent (%s)" % fn.__name__,
target=run
).start()
return _concurrent

14
netlib/basethread.py Normal file
View File

@ -0,0 +1,14 @@
import time
import threading
class BaseThread(threading.Thread):
    """
    Thread base class that remembers when it was created, so that a thread
    dump can show a descriptive name and an age for each thread.
    """

    def __init__(self, name, *args, **kwargs):
        """
        name: the thread name, forwarded to threading.Thread as ``name``.

        All remaining positional and keyword arguments are passed through
        to threading.Thread unchanged.
        """
        super(BaseThread, self).__init__(*args, name=name, **kwargs)
        # Recorded at construction time, not when start() is called.
        self._thread_started = time.time()

    def _threadinfo(self):
        """
        Return a one-line description: "<name> - age: <whole seconds>s".
        """
        return "%s - age: %is" % (
            self.name,
            int(time.time() - self._thread_started)
        )

View File

@ -1,29 +1,76 @@
from __future__ import (absolute_import, print_function, division)
import sys
import threading
import signal
import platform
import psutil
from netlib import version
"""
Some utilities to help with debugging.
"""
def sysinfo():
    """
    Return a newline-separated string with the mitmproxy version, the
    Python version and the platform, followed by one line of Linux, Mac or
    Windows version details for whichever of those the platform module
    reports.
    """
    data = [
        "Mitmproxy version: %s" % version.VERSION,
        "Python version: %s" % platform.python_version(),
        "Platform: %s" % platform.platform(),
    ]
    # platform.linux_distribution() was removed in Python 3.8; leave the
    # distro line out when it is unavailable instead of raising
    # AttributeError.
    linux_distribution = getattr(platform, "linux_distribution", None)
    if linux_distribution is not None:
        d = linux_distribution()
        if d[0]:  # pragma: no cover
            data.append("Linux distro: %s %s %s" % d)
    d = platform.mac_ver()
    if d[0]:  # pragma: no cover
        data.append("Mac version: %s %s %s" % d)
    d = platform.win32_ver()
    if d[0]:  # pragma: no cover
        data.append("Windows version: %s %s %s %s" % d)
    return "\n".join(data)
def dump_info(sig, frm, file=sys.stdout):  # pragma: no cover
    """
    Print a resource report for the current process to ``file``: a summary
    (thread count, fd count, memory), the live threads, the open files and
    the connections.

    sig, frm: signal number and stack frame, accepted so the function can
    be installed as a signal handler; neither is used.
    file: file-like object the report is written to.
    """
    p = psutil.Process()
    banner = "****************************************************"

    print(banner, file=file)
    print("Summary", file=file)
    print("=======", file=file)
    print("num threads: ", p.num_threads(), file=file)
    # psutil only provides Process.num_fds() on UNIX; skip the line
    # elsewhere rather than failing the whole dump.
    if hasattr(p, "num_fds"):
        print("num fds: ", p.num_fds(), file=file)
    print("memory: ", p.memory_info(), file=file)

    print(file=file)
    print("Threads", file=file)
    print("=======", file=file)
    bthreads = []
    for i in threading.enumerate():
        if hasattr(i, "_threadinfo"):
            bthreads.append(i)
        else:
            # Threads without _threadinfo are listed by name only.
            print(i.name, file=file)
    # Threads exposing _threadinfo are printed after the others, ordered
    # by their recorded _thread_started timestamp (oldest first).
    bthreads.sort(key=lambda x: x._thread_started)
    for i in bthreads:
        print(i._threadinfo(), file=file)

    print(file=file)
    print("Files", file=file)
    print("=====", file=file)
    for i in p.open_files():
        print(i, file=file)

    print(file=file)
    print("Connections", file=file)
    print("===========", file=file)
    for i in p.connections():
        print(i, file=file)

    print(banner, file=file)
def register_info_dumper():  # pragma: no cover
    """
    Install dump_info as the handler for SIGUSR1 on platforms that define
    that signal; do nothing on platforms that do not.
    """
    # signal.SIGUSR1 is not defined on Windows; referencing it
    # unconditionally would raise AttributeError at startup there.
    if hasattr(signal, "SIGUSR1"):
        signal.signal(signal.SIGUSR1, dump_info)

View File

@ -17,7 +17,11 @@ import six
import OpenSSL
from OpenSSL import SSL
from netlib import certutils, version_check, basetypes, exceptions
from netlib import certutils
from netlib import version_check
from netlib import basetypes
from netlib import exceptions
from netlib import basethread
# This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed.
@ -900,12 +904,16 @@ class TCPServer(object):
raise
if self.socket in r:
connection, client_address = self.socket.accept()
t = threading.Thread(
t = basethread.BaseThread(
"TCPConnectionHandler (%s: %s:%s -> %s:%s)" % (
self.__class__.__name__,
client_address[0],
client_address[1],
self.address.host,
self.address.port
),
target=self.connection_thread,
args=(connection, client_address),
name="ConnectionThread (%s:%s -> %s:%s)" %
(client_address[0], client_address[1],
self.address.host, self.address.port)
)
t.setDaemon(1)
try:

View File

@ -8,15 +8,15 @@ from six.moves import queue
import random
import select
import time
import threading
import OpenSSL.crypto
import six
from netlib import tcp, certutils, websockets, socks
from netlib.exceptions import HttpException, TcpDisconnect, TcpTimeout, TlsException, TcpException, \
NetlibException
from netlib.http import http1, http2
from netlib import exceptions
from netlib.http import http1
from netlib.http import http2
from netlib import basethread
from pathod import log, language
@ -77,7 +77,7 @@ class SSLInfo(object):
return "\n".join(parts)
class WebsocketFrameReader(threading.Thread):
class WebsocketFrameReader(basethread.BaseThread):
def __init__(
self,
@ -88,7 +88,7 @@ class WebsocketFrameReader(threading.Thread):
ws_read_limit,
timeout
):
threading.Thread.__init__(self)
basethread.BaseThread.__init__(self, "WebsocketFrameReader")
self.timeout = timeout
self.ws_read_limit = ws_read_limit
self.logfp = logfp
@ -129,7 +129,7 @@ class WebsocketFrameReader(threading.Thread):
with self.logger.ctx() as log:
try:
frm = websockets.Frame.from_file(self.rfile)
except TcpDisconnect:
except exceptions.TcpDisconnect:
return
self.frames_queue.put(frm)
log("<< %s" % frm.header.human_readable())
@ -241,8 +241,8 @@ class Pathoc(tcp.TCPClient):
try:
resp = self.protocol.read_response(self.rfile, treq(method="CONNECT"))
if resp.status_code != 200:
raise HttpException("Unexpected status code: %s" % resp.status_code)
except HttpException as e:
raise exceptions.HttpException("Unexpected status code: %s" % resp.status_code)
except exceptions.HttpException as e:
six.reraise(PathocError, PathocError(
"Proxy CONNECT failed: %s" % repr(e)
))
@ -280,7 +280,7 @@ class Pathoc(tcp.TCPClient):
connect_reply.msg,
"SOCKS server error"
)
except (socks.SocksError, TcpDisconnect) as e:
except (socks.SocksError, exceptions.TcpDisconnect) as e:
raise PathocError(str(e))
def connect(self, connect_to=None, showssl=False, fp=sys.stdout):
@ -310,7 +310,7 @@ class Pathoc(tcp.TCPClient):
cipher_list=self.ciphers,
alpn_protos=alpn_protos
)
except TlsException as v:
except exceptions.TlsException as v:
raise PathocError(str(v))
self.sslinfo = SSLInfo(
@ -406,7 +406,7 @@ class Pathoc(tcp.TCPClient):
Returns Response if we have a non-ignored response.
May raise a NetlibException
May raise an exceptions.NetlibException
"""
logger = log.ConnectionLogger(
self.fp,
@ -424,10 +424,10 @@ class Pathoc(tcp.TCPClient):
resp = self.protocol.read_response(self.rfile, treq(method=req["method"].encode()))
resp.sslinfo = self.sslinfo
except HttpException as v:
except exceptions.HttpException as v:
lg("Invalid server response: %s" % v)
raise
except TcpTimeout:
except exceptions.TcpTimeout:
if self.ignoretimeout:
lg("Timeout (ignored)")
return None
@ -451,7 +451,7 @@ class Pathoc(tcp.TCPClient):
Returns Response if we have a non-ignored response.
May raise a NetlibException
May raise an exceptions.NetlibException
"""
if isinstance(r, basestring):
r = language.parse_pathoc(r, self.use_http2).next()
@ -530,11 +530,11 @@ def main(args): # pragma: no cover
# We consume the queue when we can, so it doesn't build up.
for i_ in p.wait(timeout=0, finish=False):
pass
except NetlibException:
except exceptions.NetlibException:
break
for i_ in p.wait(timeout=0.01, finish=True):
pass
except TcpException as v:
except exceptions.TcpException as v:
print(str(v), file=sys.stderr)
continue
except PathocError as v:

View File

@ -1,10 +1,10 @@
from six.moves import cStringIO as StringIO
import threading
import time
from six.moves import queue
from . import pathod
from netlib import basethread
class TimeoutError(Exception):
@ -95,11 +95,10 @@ class Daemon:
self.thread.join()
class _PaThread(threading.Thread):
class _PaThread(basethread.BaseThread):
def __init__(self, iface, q, ssl, daemonargs):
threading.Thread.__init__(self)
self.name = "PathodThread"
basethread.BaseThread.__init__(self, "PathodThread")
self.iface, self.q, self.ssl = iface, q, ssl
self.daemonargs = daemonargs
self.server = None

View File

@ -1,7 +1,6 @@
from setuptools import setup, find_packages
from codecs import open
import os
import sys
# Based on https://github.com/pypa/sampleproject/blob/master/setup.py
# and https://python-packaging-user-guide.readthedocs.org/
@ -73,6 +72,7 @@ setup(
"lxml>=3.5.0, <3.7",
"Pillow>=3.2, <3.3",
"passlib>=1.6.5, <1.7",
"psutil>=4.2.0, <4.3",
"pyasn1>=0.1.9, <0.2",
"pyOpenSSL>=16.0, <17.0",
"pyparsing>=2.1.3, <2.2",

View File

@ -1,6 +1,14 @@
from __future__ import (absolute_import, print_function, division)
from six.moves import cStringIO as StringIO
from netlib import debug
def test_dump_info():
    """dump_info must write something to the file object it is given."""
    out = StringIO()
    debug.dump_info(None, None, file=out)
    report = out.getvalue()
    assert report
def test_sysinfo():
    """sysinfo must return a truthy value."""
    info = debug.sysinfo()
    assert info