add http message type hints
This commit is contained in:
parent
45e3ae0f9c
commit
063ff41858
|
@ -1,14 +1,18 @@
|
|||
import re
|
||||
from typing import Optional, Union # noqa
|
||||
from typing import Optional # noqa
|
||||
|
||||
from mitmproxy.utils import strutils
|
||||
from mitmproxy.net.http import encoding
|
||||
from mitmproxy.coretypes import serializable
|
||||
from mitmproxy.net.http import headers
|
||||
from mitmproxy.net.http import headers as mheaders
|
||||
|
||||
|
||||
class MessageData(serializable.Serializable):
|
||||
content: bytes = None
|
||||
headers: mheaders.Headers
|
||||
content: bytes
|
||||
http_version: bytes
|
||||
timestamp_start: float
|
||||
timestamp_end: float
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, MessageData):
|
||||
|
@ -18,7 +22,7 @@ class MessageData(serializable.Serializable):
|
|||
def set_state(self, state):
|
||||
for k, v in state.items():
|
||||
if k == "headers":
|
||||
v = headers.Headers.from_state(v)
|
||||
v = mheaders.Headers.from_state(v)
|
||||
setattr(self, k, v)
|
||||
|
||||
def get_state(self):
|
||||
|
@ -28,12 +32,12 @@ class MessageData(serializable.Serializable):
|
|||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
state["headers"] = headers.Headers.from_state(state["headers"])
|
||||
state["headers"] = mheaders.Headers.from_state(state["headers"])
|
||||
return cls(**state)
|
||||
|
||||
|
||||
class Message(serializable.Serializable):
|
||||
data: MessageData = None
|
||||
data: MessageData
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Message):
|
||||
|
@ -48,7 +52,7 @@ class Message(serializable.Serializable):
|
|||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
state["headers"] = headers.Headers.from_state(state["headers"])
|
||||
state["headers"] = mheaders.Headers.from_state(state["headers"])
|
||||
return cls(**state)
|
||||
|
||||
@property
|
||||
|
@ -160,7 +164,7 @@ class Message(serializable.Serializable):
|
|||
self.data.timestamp_end = timestamp_end
|
||||
|
||||
def _get_content_type_charset(self) -> Optional[str]:
|
||||
ct = headers.parse_content_type(self.headers.get("content-type", ""))
|
||||
ct = mheaders.parse_content_type(self.headers.get("content-type", ""))
|
||||
if ct:
|
||||
return ct[2].get("charset")
|
||||
return None
|
||||
|
@ -213,9 +217,9 @@ class Message(serializable.Serializable):
|
|||
self.content = encoding.encode(text, enc)
|
||||
except ValueError:
|
||||
# Fall back to UTF-8 and update the content-type header.
|
||||
ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
|
||||
ct = mheaders.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
|
||||
ct[2]["charset"] = "utf-8"
|
||||
self.headers["content-type"] = headers.assemble_content_type(*ct)
|
||||
self.headers["content-type"] = mheaders.assemble_content_type(*ct)
|
||||
enc = "utf8"
|
||||
self.content = text.encode(enc, "surrogateescape")
|
||||
|
||||
|
|
|
@ -64,6 +64,8 @@ class Request(message.Message):
|
|||
"""
|
||||
An HTTP request.
|
||||
"""
|
||||
data: RequestData
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.data = RequestData(*args, **kwargs)
|
||||
|
|
|
@ -47,6 +47,8 @@ class Response(message.Message):
|
|||
"""
|
||||
An HTTP response.
|
||||
"""
|
||||
data: ResponseData
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.data = ResponseData(*args, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue