Improve debugging of thread and other leaks
- Add basethread.BaseThread that all threads outside of test suites should use - Add a signal handler to mitmproxy, mitmdump and mitmweb that dumps resource information to screen when SIGUSR1 is received. - Improve thread naming throughout to make thread dumps understandable
This commit is contained in:
parent
5b9f07c81c
commit
09edbd9492
|
@ -5,8 +5,10 @@ import threading
|
||||||
|
|
||||||
from six.moves import queue
|
from six.moves import queue
|
||||||
|
|
||||||
|
from netlib import basethread
|
||||||
from mitmproxy import exceptions
|
from mitmproxy import exceptions
|
||||||
|
|
||||||
|
|
||||||
Events = frozenset([
|
Events = frozenset([
|
||||||
"clientconnect",
|
"clientconnect",
|
||||||
"clientdisconnect",
|
"clientdisconnect",
|
||||||
|
@ -95,12 +97,13 @@ class Master(object):
|
||||||
self.should_exit.set()
|
self.should_exit.set()
|
||||||
|
|
||||||
|
|
||||||
class ServerThread(threading.Thread):
|
class ServerThread(basethread.BaseThread):
|
||||||
def __init__(self, server):
|
def __init__(self, server):
|
||||||
self.server = server
|
self.server = server
|
||||||
super(ServerThread, self).__init__()
|
|
||||||
address = getattr(self.server, "address", None)
|
address = getattr(self.server, "address", None)
|
||||||
self.name = "ServerThread ({})".format(repr(address))
|
super(ServerThread, self).__init__(
|
||||||
|
"ServerThread ({})".format(repr(address))
|
||||||
|
)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.server.serve_forever()
|
self.server.serve_forever()
|
||||||
|
|
|
@ -47,6 +47,7 @@ def process_options(parser, options):
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
if options.quiet:
|
if options.quiet:
|
||||||
options.verbose = 0
|
options.verbose = 0
|
||||||
|
debug.register_info_dumper()
|
||||||
return config.process_proxy_options(parser, options)
|
return config.process_proxy_options(parser, options)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ from mitmproxy.protocol import base
|
||||||
from mitmproxy.protocol import http
|
from mitmproxy.protocol import http
|
||||||
import netlib.http
|
import netlib.http
|
||||||
from netlib import tcp
|
from netlib import tcp
|
||||||
|
from netlib import basethread
|
||||||
from netlib.http import http2
|
from netlib.http import http2
|
||||||
|
|
||||||
|
|
||||||
|
@ -261,10 +262,12 @@ class Http2Layer(base.Layer):
|
||||||
self._cleanup_streams()
|
self._cleanup_streams()
|
||||||
|
|
||||||
|
|
||||||
class Http2SingleStreamLayer(http._HttpTransmissionLayer, threading.Thread):
|
class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
|
||||||
|
|
||||||
def __init__(self, ctx, stream_id, request_headers):
|
def __init__(self, ctx, stream_id, request_headers):
|
||||||
super(Http2SingleStreamLayer, self).__init__(ctx, name="Thread-Http2SingleStreamLayer-{}".format(stream_id))
|
super(Http2SingleStreamLayer, self).__init__(
|
||||||
|
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
|
||||||
|
)
|
||||||
self.zombie = None
|
self.zombie = None
|
||||||
self.client_stream_id = stream_id
|
self.client_stream_id = stream_id
|
||||||
self.server_stream_id = None
|
self.server_stream_id = None
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from __future__ import absolute_import, print_function, division
|
from __future__ import absolute_import, print_function, division
|
||||||
|
|
||||||
import threading
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import netlib.exceptions
|
import netlib.exceptions
|
||||||
|
@ -8,12 +7,13 @@ from mitmproxy import controller
|
||||||
from mitmproxy import exceptions
|
from mitmproxy import exceptions
|
||||||
from mitmproxy import models
|
from mitmproxy import models
|
||||||
from netlib.http import http1
|
from netlib.http import http1
|
||||||
|
from netlib import basethread
|
||||||
|
|
||||||
|
|
||||||
# TODO: Doesn't really belong into mitmproxy.protocol...
|
# TODO: Doesn't really belong into mitmproxy.protocol...
|
||||||
|
|
||||||
|
|
||||||
class RequestReplayThread(threading.Thread):
|
class RequestReplayThread(basethread.BaseThread):
|
||||||
name = "RequestReplayThread"
|
name = "RequestReplayThread"
|
||||||
|
|
||||||
def __init__(self, config, flow, event_queue, should_exit):
|
def __init__(self, config, flow, event_queue, should_exit):
|
||||||
|
@ -26,7 +26,9 @@ class RequestReplayThread(threading.Thread):
|
||||||
self.channel = controller.Channel(event_queue, should_exit)
|
self.channel = controller.Channel(event_queue, should_exit)
|
||||||
else:
|
else:
|
||||||
self.channel = None
|
self.channel = None
|
||||||
super(RequestReplayThread, self).__init__()
|
super(RequestReplayThread, self).__init__(
|
||||||
|
"RequestReplay (%s)" % flow.request.url
|
||||||
|
)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
r = self.flow.request
|
r = self.flow.request
|
||||||
|
|
|
@ -5,10 +5,10 @@ offload computations from mitmproxy's main master thread.
|
||||||
from __future__ import absolute_import, print_function, division
|
from __future__ import absolute_import, print_function, division
|
||||||
|
|
||||||
from mitmproxy import controller
|
from mitmproxy import controller
|
||||||
import threading
|
from netlib import basethread
|
||||||
|
|
||||||
|
|
||||||
class ScriptThread(threading.Thread):
|
class ScriptThread(basethread.BaseThread):
|
||||||
name = "ScriptThread"
|
name = "ScriptThread"
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,5 +24,8 @@ def concurrent(fn):
|
||||||
if not obj.reply.acked:
|
if not obj.reply.acked:
|
||||||
obj.reply.ack()
|
obj.reply.ack()
|
||||||
obj.reply.take()
|
obj.reply.take()
|
||||||
ScriptThread(target=run).start()
|
ScriptThread(
|
||||||
|
"script.concurrent (%s)" % fn.__name__,
|
||||||
|
target=run
|
||||||
|
).start()
|
||||||
return _concurrent
|
return _concurrent
|
||||||
|
|
|
@ -0,0 +1,14 @@
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class BaseThread(threading.Thread):
|
||||||
|
def __init__(self, name, *args, **kwargs):
|
||||||
|
super(BaseThread, self).__init__(name=name, *args, **kwargs)
|
||||||
|
self._thread_started = time.time()
|
||||||
|
|
||||||
|
def _threadinfo(self):
|
||||||
|
return "%s - age: %is" % (
|
||||||
|
self.name,
|
||||||
|
int(time.time() - self._thread_started)
|
||||||
|
)
|
|
@ -1,29 +1,76 @@
|
||||||
|
from __future__ import (absolute_import, print_function, division)
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import signal
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
from netlib import version
|
from netlib import version
|
||||||
|
|
||||||
"""
|
|
||||||
Some utilities to help with debugging.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def sysinfo():
|
def sysinfo():
|
||||||
data = [
|
data = [
|
||||||
"Mitmproxy verison: %s"%version.VERSION,
|
"Mitmproxy verison: %s" % version.VERSION,
|
||||||
"Python version: %s"%platform.python_version(),
|
"Python version: %s" % platform.python_version(),
|
||||||
"Platform: %s"%platform.platform(),
|
"Platform: %s" % platform.platform(),
|
||||||
]
|
]
|
||||||
d = platform.linux_distribution()
|
d = platform.linux_distribution()
|
||||||
t = "Linux distro: %s %s %s"%d
|
t = "Linux distro: %s %s %s" % d
|
||||||
if d[0]: # pragma: no-cover
|
if d[0]: # pragma: no-cover
|
||||||
data.append(t)
|
data.append(t)
|
||||||
|
|
||||||
d = platform.mac_ver()
|
d = platform.mac_ver()
|
||||||
t = "Mac version: %s %s %s"%d
|
t = "Mac version: %s %s %s" % d
|
||||||
if d[0]: # pragma: no-cover
|
if d[0]: # pragma: no-cover
|
||||||
data.append(t)
|
data.append(t)
|
||||||
|
|
||||||
d = platform.win32_ver()
|
d = platform.win32_ver()
|
||||||
t = "Windows version: %s %s %s %s"%d
|
t = "Windows version: %s %s %s %s" % d
|
||||||
if d[0]: # pragma: no-cover
|
if d[0]: # pragma: no-cover
|
||||||
data.append(t)
|
data.append(t)
|
||||||
|
|
||||||
return "\n".join(data)
|
return "\n".join(data)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_info(sig, frm, file=sys.stdout): # pragma: no cover
|
||||||
|
p = psutil.Process()
|
||||||
|
|
||||||
|
print("****************************************************", file=file)
|
||||||
|
print("Summary", file=file)
|
||||||
|
print("=======", file=file)
|
||||||
|
print("num threads: ", p.num_threads(), file=file)
|
||||||
|
print("num fds: ", p.num_fds(), file=file)
|
||||||
|
print("memory: ", p.memory_info(), file=file)
|
||||||
|
|
||||||
|
print(file=file)
|
||||||
|
print("Threads", file=file)
|
||||||
|
print("=======", file=file)
|
||||||
|
bthreads = []
|
||||||
|
for i in threading.enumerate():
|
||||||
|
if hasattr(i, "_threadinfo"):
|
||||||
|
bthreads.append(i)
|
||||||
|
else:
|
||||||
|
print(i.name, file=file)
|
||||||
|
bthreads.sort(key=lambda x: x._thread_started)
|
||||||
|
for i in bthreads:
|
||||||
|
print(i._threadinfo(), file=file)
|
||||||
|
|
||||||
|
print(file=file)
|
||||||
|
print("Files", file=file)
|
||||||
|
print("=====", file=file)
|
||||||
|
for i in p.open_files():
|
||||||
|
print(i, file=file)
|
||||||
|
|
||||||
|
print(file=file)
|
||||||
|
print("Connections", file=file)
|
||||||
|
print("===========", file=file)
|
||||||
|
for i in p.connections():
|
||||||
|
print(i, file=file)
|
||||||
|
|
||||||
|
print("****************************************************", file=file)
|
||||||
|
|
||||||
|
|
||||||
|
def register_info_dumper(): # pragma: no cover
|
||||||
|
signal.signal(signal.SIGUSR1, dump_info)
|
||||||
|
|
|
@ -17,7 +17,11 @@ import six
|
||||||
import OpenSSL
|
import OpenSSL
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
from netlib import certutils, version_check, basetypes, exceptions
|
from netlib import certutils
|
||||||
|
from netlib import version_check
|
||||||
|
from netlib import basetypes
|
||||||
|
from netlib import exceptions
|
||||||
|
from netlib import basethread
|
||||||
|
|
||||||
# This is a rather hackish way to make sure that
|
# This is a rather hackish way to make sure that
|
||||||
# the latest version of pyOpenSSL is actually installed.
|
# the latest version of pyOpenSSL is actually installed.
|
||||||
|
@ -900,12 +904,16 @@ class TCPServer(object):
|
||||||
raise
|
raise
|
||||||
if self.socket in r:
|
if self.socket in r:
|
||||||
connection, client_address = self.socket.accept()
|
connection, client_address = self.socket.accept()
|
||||||
t = threading.Thread(
|
t = basethread.BaseThread(
|
||||||
|
"TCPConnectionHandler (%s: %s:%s -> %s:%s)" % (
|
||||||
|
self.__class__.__name__,
|
||||||
|
client_address[0],
|
||||||
|
client_address[1],
|
||||||
|
self.address.host,
|
||||||
|
self.address.port
|
||||||
|
),
|
||||||
target=self.connection_thread,
|
target=self.connection_thread,
|
||||||
args=(connection, client_address),
|
args=(connection, client_address),
|
||||||
name="ConnectionThread (%s:%s -> %s:%s)" %
|
|
||||||
(client_address[0], client_address[1],
|
|
||||||
self.address.host, self.address.port)
|
|
||||||
)
|
)
|
||||||
t.setDaemon(1)
|
t.setDaemon(1)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -8,15 +8,15 @@ from six.moves import queue
|
||||||
import random
|
import random
|
||||||
import select
|
import select
|
||||||
import time
|
import time
|
||||||
import threading
|
|
||||||
|
|
||||||
import OpenSSL.crypto
|
import OpenSSL.crypto
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from netlib import tcp, certutils, websockets, socks
|
from netlib import tcp, certutils, websockets, socks
|
||||||
from netlib.exceptions import HttpException, TcpDisconnect, TcpTimeout, TlsException, TcpException, \
|
from netlib import exceptions
|
||||||
NetlibException
|
from netlib.http import http1
|
||||||
from netlib.http import http1, http2
|
from netlib.http import http2
|
||||||
|
from netlib import basethread
|
||||||
|
|
||||||
from pathod import log, language
|
from pathod import log, language
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ class SSLInfo(object):
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
class WebsocketFrameReader(threading.Thread):
|
class WebsocketFrameReader(basethread.BaseThread):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -88,7 +88,7 @@ class WebsocketFrameReader(threading.Thread):
|
||||||
ws_read_limit,
|
ws_read_limit,
|
||||||
timeout
|
timeout
|
||||||
):
|
):
|
||||||
threading.Thread.__init__(self)
|
basethread.BaseThread.__init__(self, "WebsocketFrameReader")
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.ws_read_limit = ws_read_limit
|
self.ws_read_limit = ws_read_limit
|
||||||
self.logfp = logfp
|
self.logfp = logfp
|
||||||
|
@ -129,7 +129,7 @@ class WebsocketFrameReader(threading.Thread):
|
||||||
with self.logger.ctx() as log:
|
with self.logger.ctx() as log:
|
||||||
try:
|
try:
|
||||||
frm = websockets.Frame.from_file(self.rfile)
|
frm = websockets.Frame.from_file(self.rfile)
|
||||||
except TcpDisconnect:
|
except exceptions.TcpDisconnect:
|
||||||
return
|
return
|
||||||
self.frames_queue.put(frm)
|
self.frames_queue.put(frm)
|
||||||
log("<< %s" % frm.header.human_readable())
|
log("<< %s" % frm.header.human_readable())
|
||||||
|
@ -241,8 +241,8 @@ class Pathoc(tcp.TCPClient):
|
||||||
try:
|
try:
|
||||||
resp = self.protocol.read_response(self.rfile, treq(method="CONNECT"))
|
resp = self.protocol.read_response(self.rfile, treq(method="CONNECT"))
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
raise HttpException("Unexpected status code: %s" % resp.status_code)
|
raise exceptions.HttpException("Unexpected status code: %s" % resp.status_code)
|
||||||
except HttpException as e:
|
except exceptions.HttpException as e:
|
||||||
six.reraise(PathocError, PathocError(
|
six.reraise(PathocError, PathocError(
|
||||||
"Proxy CONNECT failed: %s" % repr(e)
|
"Proxy CONNECT failed: %s" % repr(e)
|
||||||
))
|
))
|
||||||
|
@ -280,7 +280,7 @@ class Pathoc(tcp.TCPClient):
|
||||||
connect_reply.msg,
|
connect_reply.msg,
|
||||||
"SOCKS server error"
|
"SOCKS server error"
|
||||||
)
|
)
|
||||||
except (socks.SocksError, TcpDisconnect) as e:
|
except (socks.SocksError, exceptions.TcpDisconnect) as e:
|
||||||
raise PathocError(str(e))
|
raise PathocError(str(e))
|
||||||
|
|
||||||
def connect(self, connect_to=None, showssl=False, fp=sys.stdout):
|
def connect(self, connect_to=None, showssl=False, fp=sys.stdout):
|
||||||
|
@ -310,7 +310,7 @@ class Pathoc(tcp.TCPClient):
|
||||||
cipher_list=self.ciphers,
|
cipher_list=self.ciphers,
|
||||||
alpn_protos=alpn_protos
|
alpn_protos=alpn_protos
|
||||||
)
|
)
|
||||||
except TlsException as v:
|
except exceptions.TlsException as v:
|
||||||
raise PathocError(str(v))
|
raise PathocError(str(v))
|
||||||
|
|
||||||
self.sslinfo = SSLInfo(
|
self.sslinfo = SSLInfo(
|
||||||
|
@ -406,7 +406,7 @@ class Pathoc(tcp.TCPClient):
|
||||||
|
|
||||||
Returns Response if we have a non-ignored response.
|
Returns Response if we have a non-ignored response.
|
||||||
|
|
||||||
May raise a NetlibException
|
May raise a exceptions.NetlibException
|
||||||
"""
|
"""
|
||||||
logger = log.ConnectionLogger(
|
logger = log.ConnectionLogger(
|
||||||
self.fp,
|
self.fp,
|
||||||
|
@ -424,10 +424,10 @@ class Pathoc(tcp.TCPClient):
|
||||||
|
|
||||||
resp = self.protocol.read_response(self.rfile, treq(method=req["method"].encode()))
|
resp = self.protocol.read_response(self.rfile, treq(method=req["method"].encode()))
|
||||||
resp.sslinfo = self.sslinfo
|
resp.sslinfo = self.sslinfo
|
||||||
except HttpException as v:
|
except exceptions.HttpException as v:
|
||||||
lg("Invalid server response: %s" % v)
|
lg("Invalid server response: %s" % v)
|
||||||
raise
|
raise
|
||||||
except TcpTimeout:
|
except exceptions.TcpTimeout:
|
||||||
if self.ignoretimeout:
|
if self.ignoretimeout:
|
||||||
lg("Timeout (ignored)")
|
lg("Timeout (ignored)")
|
||||||
return None
|
return None
|
||||||
|
@ -451,7 +451,7 @@ class Pathoc(tcp.TCPClient):
|
||||||
|
|
||||||
Returns Response if we have a non-ignored response.
|
Returns Response if we have a non-ignored response.
|
||||||
|
|
||||||
May raise a NetlibException
|
May raise a exceptions.NetlibException
|
||||||
"""
|
"""
|
||||||
if isinstance(r, basestring):
|
if isinstance(r, basestring):
|
||||||
r = language.parse_pathoc(r, self.use_http2).next()
|
r = language.parse_pathoc(r, self.use_http2).next()
|
||||||
|
@ -530,11 +530,11 @@ def main(args): # pragma: no cover
|
||||||
# We consume the queue when we can, so it doesn't build up.
|
# We consume the queue when we can, so it doesn't build up.
|
||||||
for i_ in p.wait(timeout=0, finish=False):
|
for i_ in p.wait(timeout=0, finish=False):
|
||||||
pass
|
pass
|
||||||
except NetlibException:
|
except exceptions.NetlibException:
|
||||||
break
|
break
|
||||||
for i_ in p.wait(timeout=0.01, finish=True):
|
for i_ in p.wait(timeout=0.01, finish=True):
|
||||||
pass
|
pass
|
||||||
except TcpException as v:
|
except exceptions.TcpException as v:
|
||||||
print(str(v), file=sys.stderr)
|
print(str(v), file=sys.stderr)
|
||||||
continue
|
continue
|
||||||
except PathocError as v:
|
except PathocError as v:
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from six.moves import cStringIO as StringIO
|
from six.moves import cStringIO as StringIO
|
||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from six.moves import queue
|
from six.moves import queue
|
||||||
|
|
||||||
from . import pathod
|
from . import pathod
|
||||||
|
from netlib import basethread
|
||||||
|
|
||||||
|
|
||||||
class TimeoutError(Exception):
|
class TimeoutError(Exception):
|
||||||
|
@ -95,11 +95,10 @@ class Daemon:
|
||||||
self.thread.join()
|
self.thread.join()
|
||||||
|
|
||||||
|
|
||||||
class _PaThread(threading.Thread):
|
class _PaThread(basethread.BaseThread):
|
||||||
|
|
||||||
def __init__(self, iface, q, ssl, daemonargs):
|
def __init__(self, iface, q, ssl, daemonargs):
|
||||||
threading.Thread.__init__(self)
|
basethread.BaseThread.__init__(self, "PathodThread")
|
||||||
self.name = "PathodThread"
|
|
||||||
self.iface, self.q, self.ssl = iface, q, ssl
|
self.iface, self.q, self.ssl = iface, q, ssl
|
||||||
self.daemonargs = daemonargs
|
self.daemonargs = daemonargs
|
||||||
self.server = None
|
self.server = None
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -1,7 +1,6 @@
|
||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
from codecs import open
|
from codecs import open
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
|
|
||||||
# Based on https://github.com/pypa/sampleproject/blob/master/setup.py
|
# Based on https://github.com/pypa/sampleproject/blob/master/setup.py
|
||||||
# and https://python-packaging-user-guide.readthedocs.org/
|
# and https://python-packaging-user-guide.readthedocs.org/
|
||||||
|
@ -73,6 +72,7 @@ setup(
|
||||||
"lxml>=3.5.0, <3.7",
|
"lxml>=3.5.0, <3.7",
|
||||||
"Pillow>=3.2, <3.3",
|
"Pillow>=3.2, <3.3",
|
||||||
"passlib>=1.6.5, <1.7",
|
"passlib>=1.6.5, <1.7",
|
||||||
|
"psutil>=4.2.0, <4.3",
|
||||||
"pyasn1>=0.1.9, <0.2",
|
"pyasn1>=0.1.9, <0.2",
|
||||||
"pyOpenSSL>=16.0, <17.0",
|
"pyOpenSSL>=16.0, <17.0",
|
||||||
"pyparsing>=2.1.3, <2.2",
|
"pyparsing>=2.1.3, <2.2",
|
||||||
|
|
|
@ -1,6 +1,14 @@
|
||||||
|
from __future__ import (absolute_import, print_function, division)
|
||||||
|
from six.moves import cStringIO as StringIO
|
||||||
|
|
||||||
from netlib import debug
|
from netlib import debug
|
||||||
|
|
||||||
|
|
||||||
|
def test_dump_info():
|
||||||
|
cs = StringIO()
|
||||||
|
debug.dump_info(None, None, file=cs)
|
||||||
|
assert cs.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def test_sysinfo():
|
def test_sysinfo():
|
||||||
assert debug.sysinfo()
|
assert debug.sysinfo()
|
||||||
|
|
Loading…
Reference in New Issue