Improve debugging of thread and other leaks
- Add basethread.BaseThread that all threads outside of test suites should use - Add a signal handler to mitmproxy, mitmdump and mitmweb that dumps resource information to screen when SIGUSR1 is received. - Improve thread naming throughout to make thread dumps understandable
This commit is contained in:
parent
5b9f07c81c
commit
09edbd9492
|
@ -5,8 +5,10 @@ import threading
|
|||
|
||||
from six.moves import queue
|
||||
|
||||
from netlib import basethread
|
||||
from mitmproxy import exceptions
|
||||
|
||||
|
||||
Events = frozenset([
|
||||
"clientconnect",
|
||||
"clientdisconnect",
|
||||
|
@ -95,12 +97,13 @@ class Master(object):
|
|||
self.should_exit.set()
|
||||
|
||||
|
||||
class ServerThread(threading.Thread):
|
||||
class ServerThread(basethread.BaseThread):
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
super(ServerThread, self).__init__()
|
||||
address = getattr(self.server, "address", None)
|
||||
self.name = "ServerThread ({})".format(repr(address))
|
||||
super(ServerThread, self).__init__(
|
||||
"ServerThread ({})".format(repr(address))
|
||||
)
|
||||
|
||||
def run(self):
|
||||
self.server.serve_forever()
|
||||
|
|
|
@ -47,6 +47,7 @@ def process_options(parser, options):
|
|||
sys.exit(0)
|
||||
if options.quiet:
|
||||
options.verbose = 0
|
||||
debug.register_info_dumper()
|
||||
return config.process_proxy_options(parser, options)
|
||||
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ from mitmproxy.protocol import base
|
|||
from mitmproxy.protocol import http
|
||||
import netlib.http
|
||||
from netlib import tcp
|
||||
from netlib import basethread
|
||||
from netlib.http import http2
|
||||
|
||||
|
||||
|
@ -261,10 +262,12 @@ class Http2Layer(base.Layer):
|
|||
self._cleanup_streams()
|
||||
|
||||
|
||||
class Http2SingleStreamLayer(http._HttpTransmissionLayer, threading.Thread):
|
||||
class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
|
||||
|
||||
def __init__(self, ctx, stream_id, request_headers):
|
||||
super(Http2SingleStreamLayer, self).__init__(ctx, name="Thread-Http2SingleStreamLayer-{}".format(stream_id))
|
||||
super(Http2SingleStreamLayer, self).__init__(
|
||||
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
|
||||
)
|
||||
self.zombie = None
|
||||
self.client_stream_id = stream_id
|
||||
self.server_stream_id = None
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import absolute_import, print_function, division
|
||||
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
import netlib.exceptions
|
||||
|
@ -8,12 +7,13 @@ from mitmproxy import controller
|
|||
from mitmproxy import exceptions
|
||||
from mitmproxy import models
|
||||
from netlib.http import http1
|
||||
from netlib import basethread
|
||||
|
||||
|
||||
# TODO: Doesn't really belong into mitmproxy.protocol...
|
||||
|
||||
|
||||
class RequestReplayThread(threading.Thread):
|
||||
class RequestReplayThread(basethread.BaseThread):
|
||||
name = "RequestReplayThread"
|
||||
|
||||
def __init__(self, config, flow, event_queue, should_exit):
|
||||
|
@ -26,7 +26,9 @@ class RequestReplayThread(threading.Thread):
|
|||
self.channel = controller.Channel(event_queue, should_exit)
|
||||
else:
|
||||
self.channel = None
|
||||
super(RequestReplayThread, self).__init__()
|
||||
super(RequestReplayThread, self).__init__(
|
||||
"RequestReplay (%s)" % flow.request.url
|
||||
)
|
||||
|
||||
def run(self):
|
||||
r = self.flow.request
|
||||
|
|
|
@ -5,10 +5,10 @@ offload computations from mitmproxy's main master thread.
|
|||
from __future__ import absolute_import, print_function, division
|
||||
|
||||
from mitmproxy import controller
|
||||
import threading
|
||||
from netlib import basethread
|
||||
|
||||
|
||||
class ScriptThread(threading.Thread):
|
||||
class ScriptThread(basethread.BaseThread):
|
||||
name = "ScriptThread"
|
||||
|
||||
|
||||
|
@ -24,5 +24,8 @@ def concurrent(fn):
|
|||
if not obj.reply.acked:
|
||||
obj.reply.ack()
|
||||
obj.reply.take()
|
||||
ScriptThread(target=run).start()
|
||||
ScriptThread(
|
||||
"script.concurrent (%s)" % fn.__name__,
|
||||
target=run
|
||||
).start()
|
||||
return _concurrent
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
import time
|
||||
import threading
|
||||
|
||||
|
||||
class BaseThread(threading.Thread):
|
||||
def __init__(self, name, *args, **kwargs):
|
||||
super(BaseThread, self).__init__(name=name, *args, **kwargs)
|
||||
self._thread_started = time.time()
|
||||
|
||||
def _threadinfo(self):
|
||||
return "%s - age: %is" % (
|
||||
self.name,
|
||||
int(time.time() - self._thread_started)
|
||||
)
|
|
@ -1,29 +1,76 @@
|
|||
from __future__ import (absolute_import, print_function, division)
|
||||
|
||||
import sys
|
||||
import threading
|
||||
import signal
|
||||
import platform
|
||||
|
||||
import psutil
|
||||
|
||||
from netlib import version
|
||||
|
||||
"""
|
||||
Some utilities to help with debugging.
|
||||
"""
|
||||
|
||||
def sysinfo():
|
||||
data = [
|
||||
"Mitmproxy verison: %s"%version.VERSION,
|
||||
"Python version: %s"%platform.python_version(),
|
||||
"Platform: %s"%platform.platform(),
|
||||
"Mitmproxy verison: %s" % version.VERSION,
|
||||
"Python version: %s" % platform.python_version(),
|
||||
"Platform: %s" % platform.platform(),
|
||||
]
|
||||
d = platform.linux_distribution()
|
||||
t = "Linux distro: %s %s %s"%d
|
||||
if d[0]: # pragma: no-cover
|
||||
t = "Linux distro: %s %s %s" % d
|
||||
if d[0]: # pragma: no-cover
|
||||
data.append(t)
|
||||
|
||||
d = platform.mac_ver()
|
||||
t = "Mac version: %s %s %s"%d
|
||||
if d[0]: # pragma: no-cover
|
||||
t = "Mac version: %s %s %s" % d
|
||||
if d[0]: # pragma: no-cover
|
||||
data.append(t)
|
||||
|
||||
d = platform.win32_ver()
|
||||
t = "Windows version: %s %s %s %s"%d
|
||||
if d[0]: # pragma: no-cover
|
||||
t = "Windows version: %s %s %s %s" % d
|
||||
if d[0]: # pragma: no-cover
|
||||
data.append(t)
|
||||
|
||||
return "\n".join(data)
|
||||
|
||||
|
||||
def dump_info(sig, frm, file=sys.stdout): # pragma: no cover
|
||||
p = psutil.Process()
|
||||
|
||||
print("****************************************************", file=file)
|
||||
print("Summary", file=file)
|
||||
print("=======", file=file)
|
||||
print("num threads: ", p.num_threads(), file=file)
|
||||
print("num fds: ", p.num_fds(), file=file)
|
||||
print("memory: ", p.memory_info(), file=file)
|
||||
|
||||
print(file=file)
|
||||
print("Threads", file=file)
|
||||
print("=======", file=file)
|
||||
bthreads = []
|
||||
for i in threading.enumerate():
|
||||
if hasattr(i, "_threadinfo"):
|
||||
bthreads.append(i)
|
||||
else:
|
||||
print(i.name, file=file)
|
||||
bthreads.sort(key=lambda x: x._thread_started)
|
||||
for i in bthreads:
|
||||
print(i._threadinfo(), file=file)
|
||||
|
||||
print(file=file)
|
||||
print("Files", file=file)
|
||||
print("=====", file=file)
|
||||
for i in p.open_files():
|
||||
print(i, file=file)
|
||||
|
||||
print(file=file)
|
||||
print("Connections", file=file)
|
||||
print("===========", file=file)
|
||||
for i in p.connections():
|
||||
print(i, file=file)
|
||||
|
||||
print("****************************************************", file=file)
|
||||
|
||||
|
||||
def register_info_dumper(): # pragma: no cover
|
||||
signal.signal(signal.SIGUSR1, dump_info)
|
||||
|
|
|
@ -17,7 +17,11 @@ import six
|
|||
import OpenSSL
|
||||
from OpenSSL import SSL
|
||||
|
||||
from netlib import certutils, version_check, basetypes, exceptions
|
||||
from netlib import certutils
|
||||
from netlib import version_check
|
||||
from netlib import basetypes
|
||||
from netlib import exceptions
|
||||
from netlib import basethread
|
||||
|
||||
# This is a rather hackish way to make sure that
|
||||
# the latest version of pyOpenSSL is actually installed.
|
||||
|
@ -900,12 +904,16 @@ class TCPServer(object):
|
|||
raise
|
||||
if self.socket in r:
|
||||
connection, client_address = self.socket.accept()
|
||||
t = threading.Thread(
|
||||
t = basethread.BaseThread(
|
||||
"TCPConnectionHandler (%s: %s:%s -> %s:%s)" % (
|
||||
self.__class__.__name__,
|
||||
client_address[0],
|
||||
client_address[1],
|
||||
self.address.host,
|
||||
self.address.port
|
||||
),
|
||||
target=self.connection_thread,
|
||||
args=(connection, client_address),
|
||||
name="ConnectionThread (%s:%s -> %s:%s)" %
|
||||
(client_address[0], client_address[1],
|
||||
self.address.host, self.address.port)
|
||||
)
|
||||
t.setDaemon(1)
|
||||
try:
|
||||
|
|
|
@ -8,15 +8,15 @@ from six.moves import queue
|
|||
import random
|
||||
import select
|
||||
import time
|
||||
import threading
|
||||
|
||||
import OpenSSL.crypto
|
||||
import six
|
||||
|
||||
from netlib import tcp, certutils, websockets, socks
|
||||
from netlib.exceptions import HttpException, TcpDisconnect, TcpTimeout, TlsException, TcpException, \
|
||||
NetlibException
|
||||
from netlib.http import http1, http2
|
||||
from netlib import exceptions
|
||||
from netlib.http import http1
|
||||
from netlib.http import http2
|
||||
from netlib import basethread
|
||||
|
||||
from pathod import log, language
|
||||
|
||||
|
@ -77,7 +77,7 @@ class SSLInfo(object):
|
|||
return "\n".join(parts)
|
||||
|
||||
|
||||
class WebsocketFrameReader(threading.Thread):
|
||||
class WebsocketFrameReader(basethread.BaseThread):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -88,7 +88,7 @@ class WebsocketFrameReader(threading.Thread):
|
|||
ws_read_limit,
|
||||
timeout
|
||||
):
|
||||
threading.Thread.__init__(self)
|
||||
basethread.BaseThread.__init__(self, "WebsocketFrameReader")
|
||||
self.timeout = timeout
|
||||
self.ws_read_limit = ws_read_limit
|
||||
self.logfp = logfp
|
||||
|
@ -129,7 +129,7 @@ class WebsocketFrameReader(threading.Thread):
|
|||
with self.logger.ctx() as log:
|
||||
try:
|
||||
frm = websockets.Frame.from_file(self.rfile)
|
||||
except TcpDisconnect:
|
||||
except exceptions.TcpDisconnect:
|
||||
return
|
||||
self.frames_queue.put(frm)
|
||||
log("<< %s" % frm.header.human_readable())
|
||||
|
@ -241,8 +241,8 @@ class Pathoc(tcp.TCPClient):
|
|||
try:
|
||||
resp = self.protocol.read_response(self.rfile, treq(method="CONNECT"))
|
||||
if resp.status_code != 200:
|
||||
raise HttpException("Unexpected status code: %s" % resp.status_code)
|
||||
except HttpException as e:
|
||||
raise exceptions.HttpException("Unexpected status code: %s" % resp.status_code)
|
||||
except exceptions.HttpException as e:
|
||||
six.reraise(PathocError, PathocError(
|
||||
"Proxy CONNECT failed: %s" % repr(e)
|
||||
))
|
||||
|
@ -280,7 +280,7 @@ class Pathoc(tcp.TCPClient):
|
|||
connect_reply.msg,
|
||||
"SOCKS server error"
|
||||
)
|
||||
except (socks.SocksError, TcpDisconnect) as e:
|
||||
except (socks.SocksError, exceptions.TcpDisconnect) as e:
|
||||
raise PathocError(str(e))
|
||||
|
||||
def connect(self, connect_to=None, showssl=False, fp=sys.stdout):
|
||||
|
@ -310,7 +310,7 @@ class Pathoc(tcp.TCPClient):
|
|||
cipher_list=self.ciphers,
|
||||
alpn_protos=alpn_protos
|
||||
)
|
||||
except TlsException as v:
|
||||
except exceptions.TlsException as v:
|
||||
raise PathocError(str(v))
|
||||
|
||||
self.sslinfo = SSLInfo(
|
||||
|
@ -406,7 +406,7 @@ class Pathoc(tcp.TCPClient):
|
|||
|
||||
Returns Response if we have a non-ignored response.
|
||||
|
||||
May raise a NetlibException
|
||||
May raise a exceptions.NetlibException
|
||||
"""
|
||||
logger = log.ConnectionLogger(
|
||||
self.fp,
|
||||
|
@ -424,10 +424,10 @@ class Pathoc(tcp.TCPClient):
|
|||
|
||||
resp = self.protocol.read_response(self.rfile, treq(method=req["method"].encode()))
|
||||
resp.sslinfo = self.sslinfo
|
||||
except HttpException as v:
|
||||
except exceptions.HttpException as v:
|
||||
lg("Invalid server response: %s" % v)
|
||||
raise
|
||||
except TcpTimeout:
|
||||
except exceptions.TcpTimeout:
|
||||
if self.ignoretimeout:
|
||||
lg("Timeout (ignored)")
|
||||
return None
|
||||
|
@ -451,7 +451,7 @@ class Pathoc(tcp.TCPClient):
|
|||
|
||||
Returns Response if we have a non-ignored response.
|
||||
|
||||
May raise a NetlibException
|
||||
May raise a exceptions.NetlibException
|
||||
"""
|
||||
if isinstance(r, basestring):
|
||||
r = language.parse_pathoc(r, self.use_http2).next()
|
||||
|
@ -530,11 +530,11 @@ def main(args): # pragma: no cover
|
|||
# We consume the queue when we can, so it doesn't build up.
|
||||
for i_ in p.wait(timeout=0, finish=False):
|
||||
pass
|
||||
except NetlibException:
|
||||
except exceptions.NetlibException:
|
||||
break
|
||||
for i_ in p.wait(timeout=0.01, finish=True):
|
||||
pass
|
||||
except TcpException as v:
|
||||
except exceptions.TcpException as v:
|
||||
print(str(v), file=sys.stderr)
|
||||
continue
|
||||
except PathocError as v:
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from six.moves import cStringIO as StringIO
|
||||
import threading
|
||||
import time
|
||||
|
||||
from six.moves import queue
|
||||
|
||||
from . import pathod
|
||||
from netlib import basethread
|
||||
|
||||
|
||||
class TimeoutError(Exception):
|
||||
|
@ -95,11 +95,10 @@ class Daemon:
|
|||
self.thread.join()
|
||||
|
||||
|
||||
class _PaThread(threading.Thread):
|
||||
class _PaThread(basethread.BaseThread):
|
||||
|
||||
def __init__(self, iface, q, ssl, daemonargs):
|
||||
threading.Thread.__init__(self)
|
||||
self.name = "PathodThread"
|
||||
basethread.BaseThread.__init__(self, "PathodThread")
|
||||
self.iface, self.q, self.ssl = iface, q, ssl
|
||||
self.daemonargs = daemonargs
|
||||
self.server = None
|
||||
|
|
2
setup.py
2
setup.py
|
@ -1,7 +1,6 @@
|
|||
from setuptools import setup, find_packages
|
||||
from codecs import open
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Based on https://github.com/pypa/sampleproject/blob/master/setup.py
|
||||
# and https://python-packaging-user-guide.readthedocs.org/
|
||||
|
@ -73,6 +72,7 @@ setup(
|
|||
"lxml>=3.5.0, <3.7",
|
||||
"Pillow>=3.2, <3.3",
|
||||
"passlib>=1.6.5, <1.7",
|
||||
"psutil>=4.2.0, <4.3",
|
||||
"pyasn1>=0.1.9, <0.2",
|
||||
"pyOpenSSL>=16.0, <17.0",
|
||||
"pyparsing>=2.1.3, <2.2",
|
||||
|
|
|
@ -1,6 +1,14 @@
|
|||
from __future__ import (absolute_import, print_function, division)
|
||||
from six.moves import cStringIO as StringIO
|
||||
|
||||
from netlib import debug
|
||||
|
||||
|
||||
def test_dump_info():
|
||||
cs = StringIO()
|
||||
debug.dump_info(None, None, file=cs)
|
||||
assert cs.getvalue()
|
||||
|
||||
|
||||
def test_sysinfo():
|
||||
assert debug.sysinfo()
|
||||
|
|
Loading…
Reference in New Issue