From 09edbd9492e59c0c8dcae69b4b1f4b745867abe4 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 11 Jun 2016 19:52:24 +1200 Subject: [PATCH] Improve debugging of thread and other leaks - Add basethread.BaseThread that all threads outside of test suites should use - Add a signal handler to mitmproxy, mitmdump and mitmweb that dumps resource information to screen when SIGUSR1 is received. - Improve thread naming throughout to make thread dumps understandable --- mitmproxy/controller.py | 9 ++-- mitmproxy/main.py | 1 + mitmproxy/protocol/http2.py | 7 ++- mitmproxy/protocol/http_replay.py | 8 ++-- mitmproxy/script/concurrent.py | 9 ++-- netlib/basethread.py | 14 ++++++ netlib/debug.py | 71 +++++++++++++++++++++++++------ netlib/tcp.py | 18 +++++--- pathod/pathoc.py | 34 +++++++-------- pathod/test.py | 7 ++- setup.py | 2 +- test/netlib/test_debug.py | 8 ++++ 12 files changed, 138 insertions(+), 50 deletions(-) create mode 100644 netlib/basethread.py diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 084702a65..898be3bc2 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -5,8 +5,10 @@ import threading from six.moves import queue +from netlib import basethread from mitmproxy import exceptions + Events = frozenset([ "clientconnect", "clientdisconnect", @@ -95,12 +97,13 @@ class Master(object): self.should_exit.set() -class ServerThread(threading.Thread): +class ServerThread(basethread.BaseThread): def __init__(self, server): self.server = server - super(ServerThread, self).__init__() address = getattr(self.server, "address", None) - self.name = "ServerThread ({})".format(repr(address)) + super(ServerThread, self).__init__( + "ServerThread ({})".format(repr(address)) + ) def run(self): self.server.serve_forever() diff --git a/mitmproxy/main.py b/mitmproxy/main.py index 34d4aa6b1..53417fe86 100644 --- a/mitmproxy/main.py +++ b/mitmproxy/main.py @@ -47,6 +47,7 @@ def process_options(parser, options): sys.exit(0) if options.quiet: options.verbose = 0 + debug.register_info_dumper() return config.process_proxy_options(parser, options) diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index 9247e6577..957b8d64a 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -18,6 +18,7 @@ from mitmproxy.protocol import base from mitmproxy.protocol import http import netlib.http from netlib import tcp +from netlib import basethread from netlib.http import http2 @@ -261,10 +262,12 @@ class Http2Layer(base.Layer): self._cleanup_streams() -class Http2SingleStreamLayer(http._HttpTransmissionLayer, threading.Thread): +class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread): def __init__(self, ctx, stream_id, request_headers): - super(Http2SingleStreamLayer, self).__init__(ctx, name="Thread-Http2SingleStreamLayer-{}".format(stream_id)) + super(Http2SingleStreamLayer, self).__init__( + ctx, name="Http2SingleStreamLayer-{}".format(stream_id) + ) self.zombie = None self.client_stream_id = stream_id self.server_stream_id = None diff --git a/mitmproxy/protocol/http_replay.py b/mitmproxy/protocol/http_replay.py index 5928c0afe..e804eba9b 100644 --- a/mitmproxy/protocol/http_replay.py +++ b/mitmproxy/protocol/http_replay.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, print_function, division -import threading import traceback import netlib.exceptions @@ -8,12 +7,13 @@ from mitmproxy import controller from mitmproxy import exceptions from mitmproxy import models from netlib.http import http1 +from netlib import basethread # TODO: Doesn't really belong into mitmproxy.protocol... -class RequestReplayThread(threading.Thread): +class RequestReplayThread(basethread.BaseThread): name = "RequestReplayThread" def __init__(self, config, flow, event_queue, should_exit): @@ -26,7 +26,9 @@ class RequestReplayThread(threading.Thread): self.channel = controller.Channel(event_queue, should_exit) else: self.channel = None - super(RequestReplayThread, self).__init__() + super(RequestReplayThread, self).__init__( + "RequestReplay (%s)" % flow.request.url + ) def run(self): r = self.flow.request diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index 89c835f61..56d39d0bd 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -5,10 +5,10 @@ offload computations from mitmproxy's main master thread. from __future__ import absolute_import, print_function, division from mitmproxy import controller -import threading +from netlib import basethread -class ScriptThread(threading.Thread): +class ScriptThread(basethread.BaseThread): name = "ScriptThread" @@ -24,5 +24,8 @@ def concurrent(fn): if not obj.reply.acked: obj.reply.ack() obj.reply.take() - ScriptThread(target=run).start() + ScriptThread( + "script.concurrent (%s)" % fn.__name__, + target=run + ).start() return _concurrent diff --git a/netlib/basethread.py b/netlib/basethread.py new file mode 100644 index 000000000..7963eb7e7 --- /dev/null +++ b/netlib/basethread.py @@ -0,0 +1,14 @@ +import time +import threading + + +class BaseThread(threading.Thread): + def __init__(self, name, *args, **kwargs): + super(BaseThread, self).__init__(name=name, *args, **kwargs) + self._thread_started = time.time() + + def _threadinfo(self): + return "%s - age: %is" % ( + self.name, + int(time.time() - self._thread_started) + ) diff --git a/netlib/debug.py b/netlib/debug.py index bf446eb03..b48cb1220 100644 --- a/netlib/debug.py +++ b/netlib/debug.py @@ -1,29 +1,76 @@ +from __future__ import (absolute_import, print_function, division) + +import sys +import threading +import signal import platform + +import psutil + from netlib import version -""" - Some utilities to help with debugging. -""" def sysinfo(): data = [ - "Mitmproxy verison: %s"%version.VERSION, - "Python version: %s"%platform.python_version(), - "Platform: %s"%platform.platform(), + "Mitmproxy verison: %s" % version.VERSION, + "Python version: %s" % platform.python_version(), + "Platform: %s" % platform.platform(), ] d = platform.linux_distribution() - t = "Linux distro: %s %s %s"%d - if d[0]: # pragma: no-cover + t = "Linux distro: %s %s %s" % d + if d[0]: # pragma: no-cover data.append(t) d = platform.mac_ver() - t = "Mac version: %s %s %s"%d - if d[0]: # pragma: no-cover + t = "Mac version: %s %s %s" % d + if d[0]: # pragma: no-cover data.append(t) d = platform.win32_ver() - t = "Windows version: %s %s %s %s"%d - if d[0]: # pragma: no-cover + t = "Windows version: %s %s %s %s" % d + if d[0]: # pragma: no-cover data.append(t) return "\n".join(data) + + +def dump_info(sig, frm, file=sys.stdout): # pragma: no cover + p = psutil.Process() + + print("****************************************************", file=file) + print("Summary", file=file) + print("=======", file=file) + print("num threads: ", p.num_threads(), file=file) + print("num fds: ", p.num_fds(), file=file) + print("memory: ", p.memory_info(), file=file) + + print(file=file) + print("Threads", file=file) + print("=======", file=file) + bthreads = [] + for i in threading.enumerate(): + if hasattr(i, "_threadinfo"): + bthreads.append(i) + else: + print(i.name, file=file) + bthreads.sort(key=lambda x: x._thread_started) + for i in bthreads: + print(i._threadinfo(), file=file) + + print(file=file) + print("Files", file=file) + print("=====", file=file) + for i in p.open_files(): + print(i, file=file) + + print(file=file) + print("Connections", file=file) + print("===========", file=file) + for i in p.connections(): + print(i, file=file) + + print("****************************************************", file=file) + + +def register_info_dumper(): # pragma: no cover + signal.signal(signal.SIGUSR1, dump_info) diff --git a/netlib/tcp.py b/netlib/tcp.py index 0eec326bb..acd67cad6 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -17,7 +17,11 @@ import six import OpenSSL from OpenSSL import SSL -from netlib import certutils, version_check, basetypes, exceptions +from netlib import certutils +from netlib import version_check +from netlib import basetypes +from netlib import exceptions +from netlib import basethread # This is a rather hackish way to make sure that # the latest version of pyOpenSSL is actually installed. @@ -900,12 +904,16 @@ class TCPServer(object): raise if self.socket in r: connection, client_address = self.socket.accept() - t = threading.Thread( + t = basethread.BaseThread( + "TCPConnectionHandler (%s: %s:%s -> %s:%s)" % ( + self.__class__.__name__, + client_address[0], + client_address[1], + self.address.host, + self.address.port + ), target=self.connection_thread, args=(connection, client_address), - name="ConnectionThread (%s:%s -> %s:%s)" % - (client_address[0], client_address[1], - self.address.host, self.address.port) ) t.setDaemon(1) try: diff --git a/pathod/pathoc.py b/pathod/pathoc.py index def6cfcfe..b25639887 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -8,15 +8,15 @@ from six.moves import queue import random import select import time -import threading import OpenSSL.crypto import six from netlib import tcp, certutils, websockets, socks -from netlib.exceptions import HttpException, TcpDisconnect, TcpTimeout, TlsException, TcpException, \ - NetlibException -from netlib.http import http1, http2 +from netlib import exceptions +from netlib.http import http1 +from netlib.http import http2 +from netlib import basethread from pathod import log, language @@ -77,7 +77,7 @@ class SSLInfo(object): return "\n".join(parts) -class WebsocketFrameReader(threading.Thread): +class WebsocketFrameReader(basethread.BaseThread): def __init__( self, @@ -88,7 +88,7 @@ class WebsocketFrameReader(threading.Thread): ws_read_limit, timeout ): - threading.Thread.__init__(self) + basethread.BaseThread.__init__(self, "WebsocketFrameReader") self.timeout = timeout self.ws_read_limit = ws_read_limit self.logfp = logfp @@ -129,7 +129,7 @@ class WebsocketFrameReader(threading.Thread): with self.logger.ctx() as log: try: frm = websockets.Frame.from_file(self.rfile) - except TcpDisconnect: + except exceptions.TcpDisconnect: return self.frames_queue.put(frm) log("<< %s" % frm.header.human_readable()) @@ -241,8 +241,8 @@ class Pathoc(tcp.TCPClient): try: resp = self.protocol.read_response(self.rfile, treq(method="CONNECT")) if resp.status_code != 200: - raise HttpException("Unexpected status code: %s" % resp.status_code) - except HttpException as e: + raise exceptions.HttpException("Unexpected status code: %s" % resp.status_code) + except exceptions.HttpException as e: six.reraise(PathocError, PathocError( "Proxy CONNECT failed: %s" % repr(e) )) @@ -280,7 +280,7 @@ class Pathoc(tcp.TCPClient): connect_reply.msg, "SOCKS server error" ) - except (socks.SocksError, TcpDisconnect) as e: + except (socks.SocksError, exceptions.TcpDisconnect) as e: raise PathocError(str(e)) def connect(self, connect_to=None, showssl=False, fp=sys.stdout): @@ -310,7 +310,7 @@ class Pathoc(tcp.TCPClient): cipher_list=self.ciphers, alpn_protos=alpn_protos ) - except TlsException as v: + except exceptions.TlsException as v: raise PathocError(str(v)) self.sslinfo = SSLInfo( @@ -406,7 +406,7 @@ class Pathoc(tcp.TCPClient): Returns Response if we have a non-ignored response. - May raise a NetlibException + May raise a exceptions.NetlibException """ logger = log.ConnectionLogger( self.fp, @@ -424,10 +424,10 @@ class Pathoc(tcp.TCPClient): resp = self.protocol.read_response(self.rfile, treq(method=req["method"].encode())) resp.sslinfo = self.sslinfo - except HttpException as v: + except exceptions.HttpException as v: lg("Invalid server response: %s" % v) raise - except TcpTimeout: + except exceptions.TcpTimeout: if self.ignoretimeout: lg("Timeout (ignored)") return None @@ -451,7 +451,7 @@ class Pathoc(tcp.TCPClient): Returns Response if we have a non-ignored response. - May raise a NetlibException + May raise a exceptions.NetlibException """ if isinstance(r, basestring): r = language.parse_pathoc(r, self.use_http2).next() @@ -530,11 +530,11 @@ def main(args): # pragma: no cover # We consume the queue when we can, so it doesn't build up. for i_ in p.wait(timeout=0, finish=False): pass - except NetlibException: + except exceptions.NetlibException: break for i_ in p.wait(timeout=0.01, finish=True): pass - except TcpException as v: + except exceptions.TcpException as v: print(str(v), file=sys.stderr) continue except PathocError as v: diff --git a/pathod/test.py b/pathod/test.py index 114627296..3ba541b13 100644 --- a/pathod/test.py +++ b/pathod/test.py @@ -1,10 +1,10 @@ from six.moves import cStringIO as StringIO -import threading import time from six.moves import queue from . import pathod +from netlib import basethread class TimeoutError(Exception): @@ -95,11 +95,10 @@ class Daemon: self.thread.join() -class _PaThread(threading.Thread): +class _PaThread(basethread.BaseThread): def __init__(self, iface, q, ssl, daemonargs): - threading.Thread.__init__(self) - self.name = "PathodThread" + basethread.BaseThread.__init__(self, "PathodThread") self.iface, self.q, self.ssl = iface, q, ssl self.daemonargs = daemonargs self.server = None diff --git a/setup.py b/setup.py index 050043b36..cd1230448 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,6 @@ from setuptools import setup, find_packages from codecs import open import os -import sys # Based on https://github.com/pypa/sampleproject/blob/master/setup.py # and https://python-packaging-user-guide.readthedocs.org/ @@ -73,6 +72,7 @@ setup( "lxml>=3.5.0, <3.7", "Pillow>=3.2, <3.3", "passlib>=1.6.5, <1.7", + "psutil>=4.2.0, <4.3", "pyasn1>=0.1.9, <0.2", "pyOpenSSL>=16.0, <17.0", "pyparsing>=2.1.3, <2.2", diff --git a/test/netlib/test_debug.py b/test/netlib/test_debug.py index d174bb5fd..c39d37528 100644 --- a/test/netlib/test_debug.py +++ b/test/netlib/test_debug.py @@ -1,6 +1,14 @@ +from __future__ import (absolute_import, print_function, division) +from six.moves import cStringIO as StringIO from netlib import debug +def test_dump_info(): + cs = StringIO() + debug.dump_info(None, None, file=cs) + assert cs.getvalue() + + def test_sysinfo(): assert debug.sysinfo()