Merge remote-tracking branch 'upstream/main' into quic

This commit is contained in:
Manuel Meitinger 2022-11-18 21:44:18 +01:00
commit 75504e38a2
71 changed files with 308 additions and 190 deletions

View File

@ -2,7 +2,13 @@
## Unreleased: mitmproxy next
* ASGI/WSGI apps can now listen on all ports for a specific hostname.
This makes it simpler to accept both HTTP and HTTPS.
### Breaking Changes
* The `onboarding_port` option has been removed. The onboarding app now responds
to all requests for the hostname specified in `onboarding_host`.
## 02 November 2022: mitmproxy 9.0.1

View File

@ -1,19 +1,19 @@
"""
Use mitmproxy's filter pattern in scripts.
"""
from __future__ import annotations
import logging
from mitmproxy import ctx, http
from mitmproxy import http
from mitmproxy import flowfilter
class Filter:
def __init__(self):
self.filter: flowfilter.TFilter = None
filter: flowfilter.TFilter
def configure(self, updated):
if "flowfilter" in updated:
self.filter = flowfilter.parse(ctx.options.flowfilter)
self.filter = flowfilter.parse(".")
def load(self, l):
l.add_option("flowfilter", str, "", "Check that flow matches filter.")

View File

@ -2,6 +2,7 @@ import asyncio
import logging
import traceback
import urllib.parse
from typing import Optional
import asgiref.compatibility
import asgiref.wsgi
@ -20,7 +21,7 @@ class ASGIApp:
- It currently only implements the HTTP protocol (Lifespan and WebSocket are unimplemented).
"""
def __init__(self, asgi_app, host: str, port: int):
def __init__(self, asgi_app, host: str, port: Optional[int]):
asgi_app = asgiref.compatibility.guarantee_single_callable(asgi_app)
self.asgi_app, self.host, self.port = asgi_app, host, port
@ -30,7 +31,8 @@ class ASGIApp:
def should_serve(self, flow: http.HTTPFlow) -> bool:
return bool(
(flow.request.pretty_host, flow.request.port) == (self.host, self.port)
flow.request.pretty_host == self.host
and (self.port is None or flow.request.port == self.port)
and flow.live
and not flow.error
and not flow.response
@ -42,7 +44,7 @@ class ASGIApp:
class WSGIApp(ASGIApp):
def __init__(self, wsgi_app, host: str, port: int):
def __init__(self, wsgi_app, host: str, port: Optional[int]):
asgi_app = asgiref.wsgi.WsgiToAsgi(wsgi_app)
super().__init__(asgi_app, host, port)
@ -124,6 +126,7 @@ async def serve(app, flow: http.HTTPFlow):
)
flow.response.decode()
elif event["type"] == "http.response.body":
assert flow.response
flow.response.content += event.get("body", b"")
if not event.get("more_body", False):
nonlocal sent_response

View File

@ -36,7 +36,7 @@ def parse_spec(option: str) -> BlockSpec:
class BlockList:
def __init__(self):
def __init__(self) -> None:
self.items: list[BlockSpec] = []
def load(self, loader):

View File

@ -152,6 +152,7 @@ class ClientPlayback:
while True:
self.inflight = await self.queue.get()
try:
assert self.inflight
h = ReplayHandler(self.inflight, self.options)
if ctx.options.client_replay_concurrency == -1:
asyncio_utils.create_task(

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import logging
import itertools
@ -29,7 +30,7 @@ def indent(n: int, text: str) -> str:
return "\n".join(pad + i for i in l)
CONTENTVIEW_STYLES = {
CONTENTVIEW_STYLES: dict[str, dict[str, str | bool]] = {
"highlight": dict(bold=True),
"offset": dict(fg="blue"),
"header": dict(fg="green", bold=True),

View File

@ -9,7 +9,7 @@ from mitmproxy import log
class ErrorCheck:
"""Monitor startup for error log entries, and terminate immediately if there are some."""
def __init__(self, log_to_stderr: bool = False):
def __init__(self, log_to_stderr: bool = False) -> None:
self.log_to_stderr = log_to_stderr
self.logger = ErrorCheckHandler()
@ -31,7 +31,7 @@ class ErrorCheck:
class ErrorCheckHandler(log.MitmLogHandler):
def __init__(self):
def __init__(self) -> None:
super().__init__(logging.ERROR)
self.has_errored: list[logging.LogRecord] = []

View File

@ -10,7 +10,7 @@ from mitmproxy.utils import signals
class EventStore:
def __init__(self, size=10000):
def __init__(self, size: int = 10000) -> None:
self.data: collections.deque[LogEntry] = collections.deque(maxlen=size)
self.sig_add = signals.SyncSignal(lambda entry: None)
self.sig_refresh = signals.SyncSignal(lambda: None)

View File

@ -76,7 +76,7 @@ def file_candidates(url: str, spec: MapLocalSpec) -> list[Path]:
class MapLocal:
def __init__(self):
def __init__(self) -> None:
self.replacements: list[MapLocalSpec] = []
def load(self, loader):

View File

@ -24,7 +24,7 @@ def parse_map_remote_spec(option: str) -> MapRemoteSpec:
class MapRemote:
def __init__(self):
def __init__(self) -> None:
self.replacements: list[MapRemoteSpec] = []
def load(self, loader):

View File

@ -7,7 +7,7 @@ from mitmproxy.addons.modifyheaders import parse_modify_spec, ModifySpec
class ModifyBody:
def __init__(self):
def __init__(self) -> None:
self.replacements: list[ModifySpec] = []
def load(self, loader):

View File

@ -51,7 +51,7 @@ def parse_modify_spec(option: str, subject_is_regex: bool) -> ModifySpec:
class ModifyHeaders:
def __init__(self):
def __init__(self) -> None:
self.replacements: list[ModifySpec] = []
def load(self, loader):

View File

@ -3,14 +3,13 @@ from mitmproxy.addons.onboardingapp import app
from mitmproxy import ctx
APP_HOST = "mitm.it"
APP_PORT = 80
class Onboarding(asgiapp.WSGIApp):
name = "onboarding"
def __init__(self):
super().__init__(app, APP_HOST, APP_PORT)
super().__init__(app, APP_HOST, None)
def load(self, loader):
loader.add_option(
@ -25,13 +24,9 @@ class Onboarding(asgiapp.WSGIApp):
entry for the app domain is not present.
""",
)
loader.add_option(
"onboarding_port", int, APP_PORT, "Port to serve the onboarding app from."
)
def configure(self, updated):
self.host = ctx.options.onboarding_host
self.port = ctx.options.onboarding_port
app.config["CONFDIR"] = ctx.options.confdir
async def request(self, f):

View File

@ -22,7 +22,7 @@ REALM = "mitmproxy"
class ProxyAuth:
validator: Validator | None = None
def __init__(self):
def __init__(self) -> None:
self.authenticated: MutableMapping[
connection.Client, tuple[str, str]
] = weakref.WeakKeyDictionary()

View File

@ -197,7 +197,7 @@ class Proxyserver(ServerManager):
def running(self):
self.is_running = True
def configure(self, updated):
def configure(self, updated) -> None:
if "stream_large_bodies" in updated:
try:
human.parse_size(ctx.options.stream_large_bodies)

View File

@ -77,6 +77,7 @@ class Save:
self.maybe_rotate_to_new_file()
except OSError as e:
raise exceptions.OptionsError(str(e)) from e
assert self.stream
self.stream.flt = self.filt
else:
self.done()

View File

@ -39,6 +39,17 @@ def load_script(path: str) -> Optional[types.ModuleType]:
loader.exec_module(m)
if not getattr(m, "name", None):
m.name = path # type: ignore
except ImportError as e:
err_msg = str(e)
if getattr(sys, "frozen", False):
err_msg = (
f"{err_msg}. \n"
f"Note that mitmproxy's binaries include their own Python environment. "
f"If your addon requires the installation of additional dependencies, "
f"please install mitmproxy from PyPI "
f"(https://docs.mitmproxy.org/stable/overview-installation/#installation-from-the-python-package-index-pypi)."
)
script_error_handler(path, e, msg=err_msg)
except Exception as e:
script_error_handler(path, e, msg=str(e))
finally:
@ -79,7 +90,7 @@ class Script:
self.name = "scriptmanager:" + path
self.path = path
self.fullpath = os.path.expanduser(path.strip("'\" "))
self.ns = None
self.ns: types.ModuleType | None = None
self.is_running = False
if not os.path.isfile(self.fullpath):
@ -126,7 +137,7 @@ class Script:
ctx.master.addons.invoke_addon_sync(self.ns, hooks.RunningHook())
async def watcher(self):
last_mtime = 0
last_mtime = 0.0
while True:
try:
mtime = os.stat(self.fullpath).st_mtime

View File

@ -30,8 +30,8 @@ def domain_match(a: str, b: str) -> bool:
class StickyCookie:
def __init__(self):
self.jar: dict[TOrigin, dict[str, str]] = collections.defaultdict(dict)
def __init__(self) -> None:
self.jar: collections.defaultdict[TOrigin, dict[str, str]] = collections.defaultdict(dict)
self.flt: Optional[flowfilter.TFilter] = None
def load(self, loader):

View File

@ -142,9 +142,9 @@ def _sig_view_remove(flow: mitmproxy.flow.Flow, index: int) -> None:
class View(collections.abc.Sequence):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self._store = collections.OrderedDict()
self._store: collections.OrderedDict[str, mitmproxy.flow.Flow] = collections.OrderedDict()
self.filter = flowfilter.match_all
# Should we show only marked flows?
self.show_marked = False
@ -156,7 +156,7 @@ class View(collections.abc.Sequence):
url=OrderRequestURL(self),
size=OrderKeySize(self),
)
self.order_key = self.default_order
self.order_key: _OrderKey = self.default_order
self.order_reversed = False
self.focus_follow = False
@ -316,9 +316,9 @@ class View(collections.abc.Sequence):
"""
if order_key not in self.orders:
raise exceptions.CommandError("Unknown flow order: %s" % order_key)
order_key = self.orders[order_key]
self.order_key = order_key
newview = sortedcontainers.SortedListWithKey(key=order_key)
key = self.orders[order_key]
self.order_key = key
newview = sortedcontainers.SortedListWithKey(key=key)
newview.update(self._view)
self._view = newview

View File

@ -227,7 +227,7 @@ class Client(Connection):
return client
def set_state(self, state):
self.peername = tuple(state["address"]) if state["address"] else None
self.peername = tuple(state["address"]) if state["address"] else None # type: ignore
self.alpn = state["alpn"]
self.cipher = state["cipher_name"]
self.id = state["id"]
@ -238,7 +238,7 @@ class Client(Connection):
self.tls_version = state["tls_version"]
# only used in sans-io
self.state = ConnectionState(state["state"])
self.sockname = tuple(state["sockname"]) if state["sockname"] else None
self.sockname = tuple(state["sockname"]) if state["sockname"] else None # type: ignore
self.error = state["error"]
self.tls = state["tls"]
self.certificate_list = [
@ -394,13 +394,13 @@ class Server(Connection):
return server
def set_state(self, state):
self.address = tuple(state["address"]) if state["address"] else None
self.address = tuple(state["address"]) if state["address"] else None # type: ignore
self.alpn = state["alpn"]
self.id = state["id"]
self.peername = tuple(state["ip_address"]) if state["ip_address"] else None
self.peername = tuple(state["ip_address"]) if state["ip_address"] else None # type: ignore
self.sni = state["sni"]
self.sockname = (
tuple(state["source_address"]) if state["source_address"] else None
tuple(state["source_address"]) if state["source_address"] else None # type: ignore
)
self.timestamp_end = state["timestamp_end"]
self.timestamp_start = state["timestamp_start"]

View File

@ -40,9 +40,8 @@ from . import (
)
from .base import View, KEY_MAX, format_text, format_dict, TViewResult
from ..http import HTTPFlow
from ..tcp import TCPMessage, TCPFlow
from ..udp import UDPMessage, UDPFlow
from ..tcp import TCPMessage
from ..udp import UDPMessage
from ..websocket import WebSocketMessage
views: list[View] = []
@ -97,7 +96,7 @@ def safe_to_print(lines, encoding="utf8"):
def get_message_content_view(
viewname: str,
message: Union[http.Message, TCPMessage, UDPMessage, WebSocketMessage],
flow: Union[HTTPFlow, TCPFlow, UDPFlow],
flow: flow.Flow,
):
"""
Like get_content_view, but also handles message encoding.

View File

@ -504,9 +504,11 @@ class ProtoParser:
if match:
if only_first_hit:
# only first match
self.name = fd.name
self.preferred_decoding = fd.intended_decoding
self.try_unpack = fd.as_packed
if fd.name is not None:
self.name = fd.name
if fd.intended_decoding is not None:
self.preferred_decoding = fd.intended_decoding
self.try_unpack = bool(fd.as_packed)
return
else:
# overwrite matches till last rule was inspected
@ -773,8 +775,8 @@ class ProtoParser:
def __init__(
self,
data: bytes,
rules: list[ProtoParser.ParserRule] = None,
parser_options: ParserOptions = None,
rules: list[ProtoParser.ParserRule] | None = None,
parser_options: ParserOptions | None = None,
) -> None:
self.data: bytes = data
if parser_options is None:
@ -979,7 +981,7 @@ class ViewGrpcProtobuf(base.View):
]
# allows to take external ParserOptions object. goes with defaults otherwise
def __init__(self, config: ViewConfig = None) -> None:
def __init__(self, config: ViewConfig | None = None) -> None:
super().__init__()
if config is None:
config = ViewConfig()

View File

@ -84,7 +84,7 @@ class ConnectionState:
class ViewHttp3(base.View):
name = "HTTP/3 Frames"
def __init__(self):
def __init__(self) -> None:
self.connections: defaultdict[tcp.TCPFlow, ConnectionState] = defaultdict(ConnectionState)
def __call__(

View File

@ -89,6 +89,7 @@ class MQTTControlPacket:
s = f"[{self.Names[self.packet_type]}]"
if self.packet_type == self.CONNECT:
assert self.payload
s += f"""
Client Id: {self.payload['ClientId']}
@ -101,6 +102,7 @@ Password: {strutils.bytes_to_escaped_str(self.payload.get('Password', b'None'))}
s += " sent topic filters: "
s += ", ".join([f"'{tf}'" for tf in self.topic_filters])
elif self.packet_type == self.PUBLISH:
assert self.payload
topic_name = strutils.bytes_to_escaped_str(self.topic_name)
payload = strutils.bytes_to_escaped_str(self.payload)

View File

@ -57,7 +57,7 @@ def format_pbuf(raw):
body = pair.value
try:
pairs = _parse_proto(body)
pairs = _parse_proto(body) # type: ignore
stack.extend([(pair, indent_level + 2) for pair in pairs[::-1]])
write_buf(out, pair.field_tag, None, indent_level)
except:

View File

@ -148,7 +148,7 @@ class MultiDict(_MultiDict[KT, VT], serializable.Serializable):
def __init__(self, fields=()):
super().__init__()
self.fields = tuple(tuple(i) for i in fields)
self.fields = tuple(tuple(i) for i in fields) # type: ignore
@staticmethod
def _reduce_values(values):
@ -162,7 +162,7 @@ class MultiDict(_MultiDict[KT, VT], serializable.Serializable):
return self.fields
def set_state(self, state):
self.fields = tuple(tuple(x) for x in state)
self.fields = tuple(tuple(x) for x in state) # type: ignore
@classmethod
def from_state(cls, state):

View File

@ -44,7 +44,7 @@ class Error(stateobject.StateObject):
def from_state(cls, state):
# the default implementation assumes an empty constructor. Override
# accordingly.
f = cls(None)
f = cls("")
f.set_state(state)
return f
@ -180,7 +180,7 @@ class Flow(stateobject.StateObject):
flow_cls = Flow.__types[state["type"]]
except KeyError:
raise ValueError(f"Unknown flow type: {state['type']}")
f = flow_cls(None, None) # noqa
f = flow_cls(None, None) # type: ignore
f.set_state(state)
return f

View File

@ -288,19 +288,19 @@ class FBod(_Rex):
@only(http.HTTPFlow, tcp.TCPFlow, udp.UDPFlow, dns.DNSFlow)
def __call__(self, f):
if isinstance(f, http.HTTPFlow):
if f.request and f.request.raw_content:
if self.re.search(f.request.get_content(strict=False)):
if f.request and (content := f.request.get_content(strict=False)) is not None:
if self.re.search(content):
return True
if f.response and f.response.raw_content:
if self.re.search(f.response.get_content(strict=False)):
if f.response and (content := f.response.get_content(strict=False)) is not None:
if self.re.search(content):
return True
if f.websocket:
for msg in f.websocket.messages:
if self.re.search(msg.content):
for wmsg in f.websocket.messages:
if wmsg.content is not None and self.re.search(wmsg.content):
return True
elif isinstance(f, (tcp.TCPFlow, udp.UDPFlow)):
for msg in f.messages:
if self.re.search(msg.content):
if msg.content is not None and self.re.search(msg.content):
return True
elif isinstance(f, dns.DNSFlow):
if f.request and self.re.search(f.request.content):
@ -318,12 +318,12 @@ class FBodRequest(_Rex):
@only(http.HTTPFlow, tcp.TCPFlow, udp.UDPFlow, dns.DNSFlow)
def __call__(self, f):
if isinstance(f, http.HTTPFlow):
if f.request and f.request.raw_content:
if self.re.search(f.request.get_content(strict=False)):
if f.request and (content := f.request.get_content(strict=False)) is not None:
if self.re.search(content):
return True
if f.websocket:
for msg in f.websocket.messages:
if msg.from_client and self.re.search(msg.content):
for wmsg in f.websocket.messages:
if wmsg.from_client and self.re.search(wmsg.content):
return True
elif isinstance(f, (tcp.TCPFlow, udp.UDPFlow)):
for msg in f.messages:
@ -342,12 +342,12 @@ class FBodResponse(_Rex):
@only(http.HTTPFlow, tcp.TCPFlow, udp.UDPFlow, dns.DNSFlow)
def __call__(self, f):
if isinstance(f, http.HTTPFlow):
if f.response and f.response.raw_content:
if self.re.search(f.response.get_content(strict=False)):
if f.response and (content := f.response.get_content(strict=False)) is not None:
if self.re.search(content):
return True
if f.websocket:
for msg in f.websocket.messages:
if not msg.from_client and self.re.search(msg.content):
for wmsg in f.websocket.messages:
if not wmsg.from_client and self.re.search(wmsg.content):
return True
elif isinstance(f, (tcp.TCPFlow, udp.UDPFlow)):
for msg in f.messages:

View File

@ -43,8 +43,8 @@ class Hook:
all_hooks[cls.name] = cls
# define a custom hash and __eq__ function so that events are hashable and not comparable.
cls.__hash__ = object.__hash__
cls.__eq__ = object.__eq__
cls.__hash__ = object.__hash__ # type: ignore
cls.__eq__ = object.__eq__ # type: ignore
all_hooks: dict[str, type[Hook]] = {}

View File

@ -156,7 +156,7 @@ class Headers(multidict.MultiDict): # type: ignore
name = _always_bytes(name)
return [_native(x) for x in super().get_all(name)]
def set_all(self, name: Union[str, bytes], values: list[Union[str, bytes]]):
def set_all(self, name: Union[str, bytes], values: Iterable[Union[str, bytes]]):
"""
Explicitly set multiple headers for the given key.
See `Headers.get_all`.
@ -985,7 +985,7 @@ class Request(Message):
is_valid_content_type = (
"multipart/form-data" in self.headers.get("content-type", "").lower()
)
if is_valid_content_type:
if is_valid_content_type and self.content is not None:
try:
return multipart.decode(self.headers.get("content-type"), self.content)
except ValueError:

View File

@ -6,7 +6,7 @@ v3.0.0dev) and versioning. Every change or migration gets a new flow file
version number, this prevents issues with developer builds and snapshots.
"""
import uuid
from typing import Any, Mapping, Union
from typing import Any, Union
from mitmproxy import version
from mitmproxy.utils import strutils
@ -139,8 +139,8 @@ def convert_300_4(data):
return data
client_connections: Mapping[str, str] = {}
server_connections: Mapping[str, str] = {}
client_connections: dict[tuple[str, ...], str] = {}
server_connections: dict[tuple[str, ...], str] = {}
def convert_4_5(data):

View File

@ -22,7 +22,7 @@ class Master:
event_loop: asyncio.AbstractEventLoop
def __init__(self, opts, event_loop: Optional[asyncio.AbstractEventLoop] = None):
def __init__(self, opts: options.Options, event_loop: Optional[asyncio.AbstractEventLoop] = None):
self.options: options.Options = opts or options.Options()
self.commands = command.CommandManager(self)
self.addons = addonmanager.AddonManager(self)
@ -79,7 +79,7 @@ class Master:
await self.addons.trigger_event(hooks.DoneHook())
self._legacy_log_events.uninstall()
def _asyncio_exception_handler(self, loop, context):
def _asyncio_exception_handler(self, loop, context) -> None:
try:
exc: Exception = context["exception"]
except KeyError:
@ -108,6 +108,7 @@ class Master:
# easy to replay saved flows against a different host.
# We may change this in the future so that clientplayback always replays to the first mode.
mode = ReverseMode.parse(self.options.mode[0])
assert isinstance(mode, ReverseMode)
f.request.host, f.request.port, *_ = mode.address
f.request.scheme = mode.scheme

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import re
import urllib.parse
from collections.abc import Sequence
@ -85,7 +86,7 @@ def unparse(scheme: str, host: str, port: int, path: str = "") -> str:
return f"{scheme}://{authority}{path}"
def encode(s: Sequence[tuple[str, str]], similar_to: str = None) -> str:
def encode(s: Sequence[tuple[str, str]], similar_to: str | None = None) -> str:
"""
Takes a list of (key, value) tuples and returns a urlencoded string.
If similar_to is passed, the output is formatted similar to the provided urlencoded string.

View File

@ -103,7 +103,7 @@ class OptManager:
mutation doesn't change the option state inadvertently.
"""
def __init__(self):
def __init__(self) -> None:
self.deferred: dict[str, Any] = {}
self.changed = signals.SyncSignal(_sig_changed_spec)
self.changed.connect(self._notify_subscribers)
@ -526,7 +526,7 @@ def parse(text):
snip = v.problem_mark.get_snippet()
raise exceptions.OptionsError(
"Config error at line %s:\n%s\n%s"
% (v.problem_mark.line + 1, snip, v.problem)
% (v.problem_mark.line + 1, snip, getattr(v, 'problem', ''))
)
else:
raise exceptions.OptionsError("Could not parse options.")

View File

@ -1,9 +1,9 @@
from __future__ import annotations
import collections
import collections.abc
import contextlib
import ctypes
import ctypes.wintypes
import io
import json
import os
import re
@ -12,7 +12,7 @@ import socketserver
import threading
import time
from collections.abc import Callable
from typing import Any, ClassVar, Optional
from typing import Any, ClassVar, IO, Optional, cast
import pydivert
import pydivert.consts
@ -27,20 +27,20 @@ REDIRECT_API_PORT = 8085
# Resolver
def read(rfile: io.BufferedReader) -> Any:
def read(rfile: IO[bytes]) -> Any:
x = rfile.readline().strip()
if not x:
return None
return json.loads(x)
def write(data, wfile: io.BufferedWriter) -> None:
def write(data, wfile: IO[bytes]) -> None:
wfile.write(json.dumps(data).encode() + b"\n")
wfile.flush()
class Resolver:
sock: socket.socket
sock: socket.socket | None
lock: threading.RLock
def __init__(self):
@ -84,7 +84,9 @@ class APIRequestHandler(socketserver.StreamRequestHandler):
for each received pickled client address, port tuple.
"""
def handle(self):
server: APIServer
def handle(self) -> None:
proxifier: TransparentProxy = self.server.proxifier
try:
pid: int = read(self.rfile)
@ -96,7 +98,7 @@ class APIRequestHandler(socketserver.StreamRequestHandler):
if c is None:
return
try:
server = proxifier.client_server_map[tuple(c)]
server = proxifier.client_server_map[cast(tuple[str, int], tuple(c))]
except KeyError:
server = None
write(server, self.wfile)
@ -205,7 +207,7 @@ class TcpConnectionTable(collections.abc.Mapping):
self._refresh_ipv6()
def _refresh_ipv4(self):
ret = ctypes.windll.iphlpapi.GetExtendedTcpTable(
ret = ctypes.windll.iphlpapi.GetExtendedTcpTable( # type: ignore
ctypes.byref(self._tcp),
ctypes.byref(self._tcp_size),
False,
@ -228,7 +230,7 @@ class TcpConnectionTable(collections.abc.Mapping):
)
def _refresh_ipv6(self):
ret = ctypes.windll.iphlpapi.GetExtendedTcpTable(
ret = ctypes.windll.iphlpapi.GetExtendedTcpTable( # type: ignore
ctypes.byref(self._tcp6),
ctypes.byref(self._tcp6_size),
False,
@ -275,7 +277,7 @@ class Redirect(threading.Thread):
try:
packet = self.windivert.recv()
except OSError as e:
if e.winerror == 995:
if getattr(e, "winerror", None) == 995:
return
else:
raise

View File

@ -76,7 +76,7 @@ class SendData(ConnectionCommand):
def __repr__(self):
target = str(self.connection).split("(", 1)[0].lower()
return f"SendData({target}, {self.data})"
return f"SendData({target}, {self.data!r})"
class OpenConnection(ConnectionCommand):

View File

@ -48,7 +48,7 @@ class DataReceived(ConnectionEvent):
def __repr__(self):
target = type(self.connection).__name__.lower()
return f"DataReceived({target}, {self.data})"
return f"DataReceived({target}, {self.data!r})"
class ConnectionClosed(ConnectionEvent):

View File

@ -4,6 +4,7 @@ Base class for protocol layers.
import collections
import textwrap
from abc import abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from logging import DEBUG
from typing import Any, ClassVar, Generator, NamedTuple, Optional, TypeVar
@ -98,6 +99,7 @@ class Layer:
message = message[:256] + ""
else:
Layer.__last_debug_message = message
assert self.debug is not None
return commands.Log(textwrap.indent(message, self.debug), DEBUG)
@property
@ -247,7 +249,7 @@ class NextLayer(Layer):
self.layer = None
self.events = []
self._ask_on_start = ask_on_start
self._handle = None
self._handle: Callable[[mevents.Event], CommandGenerator[None]] | None = None
def __repr__(self):
return f"NextLayer:{repr(self.layer)}"
@ -296,8 +298,8 @@ class NextLayer(Layer):
# 2. This layer is not needed anymore, so we directly reassign .handle_event.
# 3. Some layers may however still have a reference to the old .handle_event.
# ._handle is just an optimization to reduce the callstack in these cases.
self.handle_event = self.layer.handle_event
self._handle_event = self.layer.handle_event
self.handle_event = self.layer.handle_event # type: ignore
self._handle_event = self.layer.handle_event # type: ignore
self._handle = self.layer.handle_event
# Utility methods for whoever decides what the next layer is going to be.

View File

@ -137,12 +137,13 @@ class HttpStream(layer.Layer):
child_layer: Optional[layer.Layer] = None
@cached_property
def mode(self):
def mode(self) -> HTTPMode:
i = self.context.layers.index(self)
parent: HttpLayer = self.context.layers[i - 1]
parent = self.context.layers[i - 1]
assert isinstance(parent, HttpLayer)
return parent.mode
def __init__(self, context: Context, stream_id: int):
def __init__(self, context: Context, stream_id: int) -> None:
super().__init__(context)
self.request_body_buf = b""
self.response_body_buf = b""
@ -490,10 +491,11 @@ class HttpStream(layer.Layer):
if self.client_state == self.state_done:
yield from self.flow_done()
def flow_done(self):
def flow_done(self) -> layer.CommandGenerator[None]:
if not self.flow.websocket:
self.flow.live = False
assert self.flow.response
if self.flow.response.status_code == 101:
if self.flow.websocket:
self.child_layer = websocket.WebsocketLayer(self.context, self.flow)

View File

@ -1,8 +1,9 @@
from __future__ import annotations
import socket
import struct
from abc import ABCMeta
from dataclasses import dataclass
from typing import Optional
from typing import Callable, Optional
from mitmproxy import connection
from mitmproxy.proxy import commands, events, layer
@ -156,7 +157,7 @@ class Socks5Proxy(DestinationKnown):
else:
raise AssertionError(f"Unknown event: {event}")
def state_greet(self):
def state_greet(self) -> layer.CommandGenerator[None]:
if len(self.buf) < 2:
return
@ -196,9 +197,9 @@ class Socks5Proxy(DestinationKnown):
self.buf = self.buf[2 + n_methods :]
yield from self.state()
state = state_greet
state: Callable[..., layer.CommandGenerator[None]] = state_greet
def state_auth(self):
def state_auth(self) -> layer.CommandGenerator[None]:
if len(self.buf) < 3:
return
@ -227,7 +228,7 @@ class Socks5Proxy(DestinationKnown):
self.state = self.state_connect
yield from self.state()
def state_connect(self):
def state_connect(self) -> layer.CommandGenerator[None]:
# Parse Connect Request
if len(self.buf) < 5:
return

View File

@ -257,7 +257,7 @@ class TLSLayer(tunnel.TunnelLayer):
conn.tls = True
def __repr__(self):
return super().__repr__().replace(")", f" {self.conn.sni} {self.conn.alpn})")
return super().__repr__().replace(")", f" {self.conn.sni!r} {self.conn.alpn!r})")
@property
def is_dtls(self):

View File

@ -81,7 +81,7 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
def __init_subclass__(cls, **kwargs):
"""Register all subclasses so that make() finds them."""
# extract mode from Generic[Mode].
mode = get_args(cls.__orig_bases__[0])[0]
mode = get_args(cls.__orig_bases__[0])[0] # type: ignore
if not isinstance(mode, TypeVar):
assert issubclass(mode, mode_specs.ProxyMode)
assert mode.type_name not in ServerInstance.__modes

View File

@ -41,7 +41,7 @@ class StateObject(serializable.Serializable):
setattr(self, attr, val)
else:
curr = getattr(self, attr, None)
if hasattr(curr, "set_state"):
if curr is not None and hasattr(curr, "set_state"):
curr.set_state(val)
else:
setattr(self, attr, make_object(cls, val))

View File

@ -125,6 +125,8 @@ class Commands(urwid.Pile, layoutwidget.LayoutWidget):
title = "Command Reference"
keyctx = "commands"
focus_position: int
def __init__(self, master):
oh = CommandHelp(master)
super().__init__(

View File

@ -242,11 +242,11 @@ def rle_append_beginning_modify(rle, a_r):
rle[0:0] = [(a, r)]
def colorize_host(host):
def colorize_host(host: str):
tld = get_tld(host)
sld = get_sld(host)
attr = []
attr: list = []
tld_size = len(tld)
sld_size = len(sld) - tld_size
@ -268,14 +268,14 @@ def colorize_host(host):
return attr
def colorize_req(s):
def colorize_req(s: str):
path = s.split("?", 2)[0]
i_query = len(path)
i_last_slash = path.rfind("/")
i_ext = path[i_last_slash + 1 :].rfind(".")
i_ext = i_last_slash + i_ext if i_ext >= 0 else len(s)
in_val = False
attr = []
attr: list = []
for i in range(len(s)):
c = s[i]
if (

View File

@ -1,8 +1,8 @@
import abc
import copy
import os
from collections.abc import Callable, Container, Iterable, Sequence
from typing import Any, AnyStr, Optional
from collections.abc import Callable, Container, Iterable, MutableSequence, Sequence
from typing import Any, AnyStr, ClassVar, Optional
import urwid
@ -117,7 +117,7 @@ class GridWalker(urwid.ListWalker):
"""
def __init__(self, lst: Iterable[list], editor: "GridEditor") -> None:
self.lst: Sequence[tuple[Any, set]] = [(i, set()) for i in lst]
self.lst: MutableSequence[tuple[Any, set]] = [(i, set()) for i in lst]
self.editor = editor
self.focus = 0
self.focus_col = 0
@ -150,7 +150,7 @@ class GridWalker(urwid.ListWalker):
errors = set()
row = list(self.lst[focus][0])
row[focus_col] = val
self.lst[focus] = [tuple(row), errors]
self.lst[focus] = [tuple(row), errors] # type: ignore
self._modified()
def delete_focus(self):
@ -180,7 +180,7 @@ class GridWalker(urwid.ListWalker):
self._modified()
def stop_edit(self):
if self.edit_row:
if self.edit_row and self.edit_row.edit_col:
try:
val = self.edit_row.edit_col.get_data()
except ValueError:
@ -242,7 +242,7 @@ FIRST_WIDTH_MAX = 40
class BaseGridEditor(urwid.WidgetWrap):
title: str = ""
keyctx = "grideditor"
keyctx: ClassVar[str] = "grideditor"
def __init__(
self,
@ -388,7 +388,7 @@ class BaseGridEditor(urwid.WidgetWrap):
class GridEditor(BaseGridEditor):
title = ""
columns: Sequence[Column] = ()
keyctx = "grideditor"
keyctx: ClassVar[str] = "grideditor"
def __init__(
self,
@ -408,7 +408,7 @@ class FocusEditor(urwid.WidgetWrap, layoutwidget.LayoutWidget):
A specialised GridEditor that edits the current focused flow.
"""
keyctx = "grideditor"
keyctx: ClassVar[str] = "grideditor"
def __init__(self, master):
self.master = master

View File

@ -28,10 +28,10 @@ class Column(col_bytes.Column):
class EncodingMixin:
def __init__(self, data, encoding_args):
self.encoding_args = encoding_args
super().__init__(data.__str__().encode(*self.encoding_args))
super().__init__(str(data).encode(*self.encoding_args)) # type: ignore
def get_data(self):
data = super().get_data()
data = super().get_data() # type: ignore
try:
return data.decode(*self.encoding_args)
except ValueError:

View File

@ -130,6 +130,7 @@ class KeyHelp(urwid.Frame):
class KeyBindings(urwid.Pile, layoutwidget.LayoutWidget):
title = "Key Bindings"
keyctx = "keybindings"
focus_position: int
def __init__(self, master):
oh = KeyHelp(master)

View File

@ -71,7 +71,7 @@ class Binding:
class Keymap:
def __init__(self, master):
self.executor = commandexecutor.CommandExecutor(master)
self.keys = {}
self.keys: dict[str, dict[str, Binding]] = {}
for c in Contexts:
self.keys[c] = {}
self.bindings = []
@ -161,7 +161,8 @@ class Keymap:
"""
b = self.get(context, key) or self.get("global", key)
if b:
return self.executor(b.command)
self.executor(b.command)
return None
return key
def handle_only(self, context: str, key: str) -> Optional[str]:
@ -171,7 +172,8 @@ class Keymap:
"""
b = self.get(context, key)
if b:
return self.executor(b.command)
self.executor(b.command)
return None
return key
@ -187,10 +189,13 @@ requiredKeyAttrs = {"key", "cmd"}
class KeymapConfig:
defaultFile = "keys.yaml"
def __init__(self, master):
self.master = master
@command.command("console.keymap.load")
def keymap_load_path(self, path: mitmproxy.types.Path) -> None:
try:
self.load_path(ctx.master.keymap, path) # type: ignore
self.load_path(self.master.keymap, path) # type: ignore
except (OSError, KeyBindingError) as e:
raise exceptions.CommandError("Could not load key bindings - %s" % e) from e
@ -198,7 +203,7 @@ class KeymapConfig:
p = os.path.join(os.path.expanduser(ctx.options.confdir), self.defaultFile)
if os.path.exists(p):
try:
self.load_path(ctx.master.keymap, p)
self.load_path(self.master.keymap, p)
except KeyBindingError as e:
logging.error(e)

View File

@ -1,3 +1,6 @@
from typing import ClassVar
class LayoutWidget:
"""
All top-level layout widgets and all widgets that may be set in an
@ -6,7 +9,7 @@ class LayoutWidget:
# Title is only required for windows, not overlay components
title = ""
keyctx = ""
keyctx: ClassVar[str] = ""
def key_responder(self):
"""

View File

@ -18,6 +18,7 @@ import urwid
from mitmproxy import addons
from mitmproxy import master
from mitmproxy import options
from mitmproxy import log
from mitmproxy.addons import errorcheck, intercept
from mitmproxy.addons import eventstore
@ -36,7 +37,7 @@ T = TypeVar("T", str, bytes)
class ConsoleMaster(master.Master):
def __init__(self, opts):
def __init__(self, opts: options.Options) -> None:
super().__init__(opts)
self.view: view.View = view.View()
@ -48,8 +49,6 @@ class ConsoleMaster(master.Master):
defaultkeys.map(self.keymap)
self.options.errored.connect(self.options_error)
self.view_stack = []
self.addons.add(*addons.default_addons())
self.addons.add(
intercept.Intercept(),
@ -57,11 +56,11 @@ class ConsoleMaster(master.Master):
self.events,
readfile.ReadFile(),
consoleaddons.ConsoleAddon(self),
keymap.KeymapConfig(),
keymap.KeymapConfig(self),
errorcheck.ErrorCheck(log_to_stderr=True),
)
self.window = None
self.window: window.Window | None = None
def __setattr__(self, name, value):
super().__setattr__(name, value)
@ -241,9 +240,11 @@ class ConsoleMaster(master.Master):
await super().done()
def overlay(self, widget, **kwargs):
assert self.window
self.window.set_overlay(widget, **kwargs)
def switch_view(self, name):
assert self.window
self.window.push(name)
def quit(self, a):

View File

@ -244,6 +244,8 @@ class Options(urwid.Pile, layoutwidget.LayoutWidget):
title = "Options"
keyctx = "options"
focus_position: int
def __init__(self, master):
oh = OptionHelp(master)
self.optionslist = OptionsList(master, oh)

View File

@ -3,6 +3,7 @@
#
# http://urwid.org/manual/displayattributes.html
#
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Optional
@ -89,9 +90,10 @@ class Palette:
]
_fields.extend(["gradient_%02d" % i for i in range(100)])
high: Optional[Mapping[str, Sequence[str]]] = None
low: Mapping[str, Sequence[str]]
def palette(self, transparent):
l = []
def palette(self, transparent: bool):
l: list[Sequence[str | None]] = []
highback, lowback = None, None
if not transparent:
if self.high and self.high.get("background"):
@ -102,14 +104,14 @@ class Palette:
if transparent and i == "background":
l.append(["background", "default", "default"])
else:
v = [i]
v: list[str | None] = [i]
low = list(self.low[i])
if lowback and low[1] == "default":
low[1] = lowback
v.extend(low)
if self.high and i in self.high:
v.append(None)
high = list(self.high[i])
high: list[str | None] = list(self.high[i])
if highback and high[1] == "default":
high[1] = highback
v.extend(high)

View File

@ -67,10 +67,11 @@ def make(
"Export": "Export this flow to file",
"Delete": "Delete flow from view",
}
if focused_flow.marked:
top_items["Unmark"] = "Toggle mark on this flow"
else:
top_items["Mark"] = "Toggle mark on this flow"
if widget == FlowListBox:
if focused_flow.marked:
top_items["Unmark"] = "Toggle mark on this flow"
else:
top_items["Mark"] = "Toggle mark on this flow"
if focused_flow.intercepted:
top_items["Resume"] = "Resume this intercepted flow"
if focused_flow.modified():

View File

@ -150,11 +150,11 @@ class ActionBar(urwid.WidgetWrap):
return k
def show_quickhelp(self) -> None:
try:
s = self.master.window.focus_stack()
if w := self.master.window:
s = w.focus_stack()
focused_widget = type(s.top_widget())
is_top_widget = len(s.stack) == 1
except AttributeError: # on startup
else: # on startup
focused_widget = flowlist.FlowListBox
is_top_widget = True
focused_flow = self.master.view.focus.flow
@ -196,7 +196,7 @@ class StatusBar(urwid.WidgetWrap):
self.redraw()
signals.call_in.send(seconds=self.REFRESHTIME, callback=self.refresh)
def sig_update(self, flow=None, updated=None):
def sig_update(self, *args, **kwargs) -> None:
self.redraw()
def keypress(self, *args, **kwargs):
@ -297,7 +297,7 @@ class StatusBar(urwid.WidgetWrap):
def redraw(self) -> None:
fc = self.master.commands.execute("view.properties.length")
if self.master.view.focus.flow is None:
if self.master.view.focus.index is None:
offset = 0
else:
offset = self.master.view.focus.index + 1

View File

@ -1,3 +1,4 @@
from __future__ import annotations
import argparse
import asyncio
import logging
@ -42,7 +43,7 @@ def run(
master_cls: type[T],
make_parser: Callable[[options.Options], argparse.ArgumentParser],
arguments: Sequence[str],
extra: Callable[[Any], dict] = None,
extra: Callable[[Any], dict] | None = None,
) -> T: # pragma: no cover
"""
extra: Extra argument processing callable which returns a dict of

View File

@ -5,7 +5,7 @@ import json
import logging
import os.path
import re
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from io import BytesIO
from itertools import islice
from typing import ClassVar, Optional, Union
@ -313,16 +313,18 @@ class Flows(RequestHandler):
class DumpFlows(RequestHandler):
def get(self):
def get(self) -> None:
self.set_header("Content-Disposition", "attachment; filename=flows")
self.set_header("Content-Type", "application/octet-stream")
match: Callable[[mitmproxy.flow.Flow], bool]
try:
match = flowfilter.parse(self.request.arguments["filter"][0].decode())
except ValueError: # thrown py flowfilter.parse if filter is invalid
raise APIError(400, f"Invalid filter argument / regex")
except (KeyError, IndexError): # Key+Index: ["filter"][0] can fail, if it's not set
match = bool # returns always true
def match(_) -> bool:
return True
with BytesIO() as bio:
fw = io.FlowWriter(bio)
@ -381,7 +383,7 @@ class FlowHandler(RequestHandler):
self.flow.kill()
self.view.remove([self.flow])
def put(self, flow_id):
def put(self, flow_id) -> None:
flow: mitmproxy.flow.Flow = self.flow
flow.backup()
try:
@ -469,13 +471,13 @@ class FlowContent(RequestHandler):
def get(self, flow_id, message):
message = getattr(self.flow, message)
assert isinstance(self.flow, HTTPFlow)
original_cd = message.headers.get("Content-Disposition", None)
filename = None
if original_cd:
filename = re.search(r'filename=([-\w" .()]+)', original_cd)
if filename:
filename = filename.group(1)
if m := re.search(r'filename=([-\w" .()]+)', original_cd):
filename = m.group(1)
if not filename:
filename = self.flow.request.path.split("?")[0].split("/")[-1]
@ -509,7 +511,7 @@ class FlowContentView(RequestHandler):
description=description,
)
def get(self, flow_id, message, content_view):
def get(self, flow_id, message, content_view) -> None:
flow = self.flow
assert isinstance(flow, (HTTPFlow, TCPFlow, UDPFlow))
@ -519,6 +521,7 @@ class FlowContentView(RequestHandler):
max_lines = None
if message == "messages":
messages: list[TCPMessage] | list[UDPMessage] | list[WebSocketMessage]
if isinstance(flow, HTTPFlow) and flow.websocket:
messages = flow.websocket.messages
elif isinstance(flow, (TCPFlow, UDPFlow)):

View File

@ -9,6 +9,7 @@ from mitmproxy import flow
from mitmproxy import log
from mitmproxy import master
from mitmproxy import optmanager
from mitmproxy import options
from mitmproxy.addons import errorcheck, eventstore
from mitmproxy.addons import intercept
from mitmproxy.addons import readfile
@ -22,8 +23,8 @@ logger = logging.getLogger(__name__)
class WebMaster(master.Master):
def __init__(self, options, with_termlog=True):
super().__init__(options)
def __init__(self, opts: options.Options, with_termlog: bool = True):
super().__init__(opts)
self.view = view.View()
self.view.sig_view_add.connect(self._sig_view_add)
self.view.sig_view_remove.connect(self._sig_view_remove)

View File

@ -131,10 +131,11 @@ def check():
for option in REPLACED.splitlines():
if option in args:
if isinstance(REPLACEMENTS.get(option), list):
new_options = REPLACEMENTS.get(option)
r = REPLACEMENTS.get(option)
if isinstance(r, list):
new_options = r
else:
new_options = [REPLACEMENTS.get(option)]
new_options = [r]
print(
"{} is deprecated.\n"
"Please use `{}` instead.".format(option, "` or `".join(new_options))

View File

@ -7,7 +7,9 @@ class Data:
def __init__(self, name):
self.name = name
m = importlib.import_module(name)
dirname = os.path.dirname(inspect.getsourcefile(m))
f = inspect.getsourcefile(m)
assert f is not None
dirname = os.path.dirname(f)
self.dirname = os.path.abspath(dirname)
def push(self, subpath):

View File

@ -18,11 +18,14 @@ from mitmproxy.utils import asyncio_utils
def dump_system_info():
mitmproxy_version = version.get_dev_version()
openssl_version = SSL.SSLeay_version(SSL.SSLEAY_VERSION)
if isinstance(openssl_version, bytes):
openssl_version = openssl_version.decode()
data = [
f"Mitmproxy: {mitmproxy_version}",
f"Python: {platform.python_version()}",
f"OpenSSL: {SSL.SSLeay_version(SSL.SSLEAY_VERSION).decode()}",
f"OpenSSL: {openssl_version}",
f"Platform: {platform.platform()}",
]
return "\n".join(data)

View File

@ -18,7 +18,7 @@ try:
from typing import ParamSpec
except ImportError: # pragma: no cover
# Python 3.9
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec # type: ignore
P = ParamSpec("P")
R = TypeVar("R")
@ -37,7 +37,7 @@ def make_weak_ref(obj: Any) -> weakref.ReferenceType:
# We're running into https://github.com/python/mypy/issues/6073 here,
# which is why the base class is a mixin and not a generic superclass.
class _SignalMixin:
def __init__(self):
def __init__(self) -> None:
self.receivers: list[weakref.ref[Callable]] = []
def connect(self, receiver: Callable) -> None:

View File

@ -21,7 +21,7 @@ if ref.startswith("refs/heads/"):
elif ref.startswith("refs/tags/"):
tag = ref.replace("refs/tags/", "")
else:
raise AssertionError
raise AssertionError("Failed to parse $GITHUB_REF")
(whl,) = root.glob("release/dist/mitmproxy-*-py3-none-any.whl")
docker_build_dir = root / "release/docker"
@ -47,15 +47,17 @@ r = subprocess.run(
"docker",
"run",
"--rm",
"-v",
f"{root / 'release'}:/release",
"localtesting",
"mitmdump",
"--version",
"-s", "/release/selftest.py",
],
check=True,
capture_output=True,
)
print(r.stdout.decode())
assert "Mitmproxy: " in r.stdout.decode()
assert "Self-test successful" in r.stdout.decode()
assert r.returncode == 0
# Now we can deploy.
subprocess.check_call(

View File

@ -121,28 +121,39 @@ def standalone_binaries():
with archive(DIST_DIR / f"mitmproxy-{version()}-{operating_system()}") as f:
_pyinstaller("standalone.spec")
_test_binaries(TEMP_DIR / "pyinstaller/dist")
for tool in ["mitmproxy", "mitmdump", "mitmweb"]:
executable = TEMP_DIR / "pyinstaller/dist" / tool
if platform.system() == "Windows":
executable = executable.with_suffix(".exe")
# Test if it works at all O:-)
print(f"> {executable} --version")
subprocess.check_call([executable, "--version"])
f.add(str(executable), str(executable.name))
print(f"Packed {f.name}.")
print(f"Packed {f.name!r}.")
def _ensure_pyinstaller_onedir():
if not (TEMP_DIR / "pyinstaller/dist/onedir").exists():
_pyinstaller("windows-dir.spec")
_test_binaries(TEMP_DIR / "pyinstaller/dist/onedir")
def _test_binaries(binary_directory: Path) -> None:
for tool in ["mitmproxy", "mitmdump", "mitmweb"]:
executable = binary_directory / tool
if platform.system() == "Windows":
executable = executable.with_suffix(".exe")
print(f"> {tool} --version")
executable = (TEMP_DIR / "pyinstaller/dist/onedir" / tool).with_suffix(".exe")
subprocess.check_call([executable, "--version"])
if tool == "mitmproxy":
continue # requires a TTY, which we don't have here.
print(f"> {tool} -s selftest.py")
subprocess.check_call([executable, "-s", here / "selftest.py"])
@cli.command()
def msix_installer():
@ -256,11 +267,7 @@ def installbuilder_installer():
subprocess.run(
[installer, "--mode", "unattended", "--unattendedmodeui", "none"], check=True
)
MITMPROXY_INSTALL_DIR = Path(rf"C:\Program Files\mitmproxy\bin")
for tool in ["mitmproxy", "mitmdump", "mitmweb"]:
executable = (MITMPROXY_INSTALL_DIR / tool).with_suffix(".exe")
print(f"> {executable} --version")
subprocess.check_call([executable, "--version"])
_test_binaries(Path(r"C:\Program Files\mitmproxy\bin"))
if __name__ == "__main__":

43
release/selftest.py Normal file
View File

@ -0,0 +1,43 @@
"""
This addons is used for binaries to perform a minimal selftest. Use like so:
mitmdump -s selftest.py -p 0
"""
import asyncio
import logging
import ssl
import sys
from pathlib import Path
from mitmproxy import ctx
def load(_):
# force a random port
ctx.options.listen_port = 0
def running():
# attach is somewhere so that it's not collected.
ctx.task = asyncio.create_task(make_request()) # type: ignore
async def make_request():
try:
cafile = Path(ctx.options.confdir).expanduser() / "mitmproxy-ca.pem"
ssl_ctx = ssl.create_default_context(cafile=cafile)
port = ctx.master.addons.get("proxyserver").listen_addrs()[0][1]
reader, writer = await asyncio.open_connection(
"127.0.0.1", port,
ssl=ssl_ctx
)
writer.write(b"GET / HTTP/1.1\r\nHost: mitm.it\r\nConnection: close\r\n\r\n")
await writer.drain()
resp = await reader.read()
if b"This page is served by your local mitmproxy instance" not in resp:
raise RuntimeError(resp)
logging.info("Self-test successful.")
ctx.master.shutdown()
except Exception as e:
print(f"{e!r}")
sys.exit(1)

View File

@ -29,6 +29,7 @@ exclude_lines =
\.\.\.
[mypy]
check_untyped_defs = True
ignore_missing_imports = True
files = mitmproxy,examples/addons,release

View File

@ -123,6 +123,14 @@ class TestScript:
await caplog_async.await_log("error.py")
sc.done()
async def test_import_error(self, monkeypatch, tdata, caplog):
monkeypatch.setattr(sys, "frozen", True, raising=False)
script.Script(
tdata.path("mitmproxy/data/addonscripts/import_error.py"),
False,
)
assert "Note that mitmproxy's binaries include their own Python environment" in caplog.text
async def test_optionexceptions(self, tdata, caplog_async):
with taddons.context() as tctx:
sc = script.Script(

View File

@ -0,0 +1 @@
import nonexistent

View File

@ -75,8 +75,8 @@ def test_remove():
def test_load_path(tmpdir):
dst = str(tmpdir.join("conf"))
kmc = keymap.KeymapConfig()
with taddons.context(kmc) as tctx:
with taddons.context() as tctx:
kmc = keymap.KeymapConfig(tctx.master)
km = keymap.Keymap(tctx.master)
tctx.master.keymap = km
@ -148,8 +148,8 @@ def test_load_path(tmpdir):
def test_parse():
kmc = keymap.KeymapConfig()
with taddons.context(kmc):
with taddons.context() as tctx:
kmc = keymap.KeymapConfig(tctx.master)
assert kmc.parse("") == []
assert kmc.parse("\n\n\n \n") == []
with pytest.raises(keymap.KeyBindingError, match="expected a list of keys"):

View File

@ -29,13 +29,13 @@ commands =
[testenv:mypy]
deps =
mypy==0.982
mypy==0.990
types-certifi==2021.10.8.3
types-Flask==1.1.6
types-Werkzeug==1.0.9
types-requests==2.28.11.2
types-cryptography==3.3.23.1
types-pyOpenSSL==22.1.0.1
types-requests==2.28.11.4
types-cryptography==3.3.23.2
types-pyOpenSSL==22.1.0.2
-e .[dev]
commands =

View File

@ -39,7 +39,6 @@ export interface OptionsState {
normalize_outbound_headers: boolean
onboarding: boolean
onboarding_host: string
onboarding_port: number
proxy_debug: boolean
proxyauth: string | undefined
rawtcp: boolean
@ -130,7 +129,6 @@ export const defaultState: OptionsState = {
normalize_outbound_headers: true,
onboarding: true,
onboarding_host: "mitm.it",
onboarding_port: 80,
proxy_debug: false,
proxyauth: undefined,
rawtcp: true,