Merge pull request #3671 from mhils/add-types
Add HTTP Message Type Hints
This commit is contained in:
commit
eb7ed1dc40
|
@ -1,14 +1,18 @@
|
||||||
import re
|
import re
|
||||||
from typing import Optional, Union # noqa
|
from typing import Optional # noqa
|
||||||
|
|
||||||
from mitmproxy.utils import strutils
|
from mitmproxy.utils import strutils
|
||||||
from mitmproxy.net.http import encoding
|
from mitmproxy.net.http import encoding
|
||||||
from mitmproxy.coretypes import serializable
|
from mitmproxy.coretypes import serializable
|
||||||
from mitmproxy.net.http import headers
|
from mitmproxy.net.http import headers as mheaders
|
||||||
|
|
||||||
|
|
||||||
class MessageData(serializable.Serializable):
|
class MessageData(serializable.Serializable):
|
||||||
content: bytes = None
|
headers: mheaders.Headers
|
||||||
|
content: bytes
|
||||||
|
http_version: bytes
|
||||||
|
timestamp_start: float
|
||||||
|
timestamp_end: float
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, MessageData):
|
if isinstance(other, MessageData):
|
||||||
|
@ -18,7 +22,7 @@ class MessageData(serializable.Serializable):
|
||||||
def set_state(self, state):
|
def set_state(self, state):
|
||||||
for k, v in state.items():
|
for k, v in state.items():
|
||||||
if k == "headers":
|
if k == "headers":
|
||||||
v = headers.Headers.from_state(v)
|
v = mheaders.Headers.from_state(v)
|
||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
|
@ -28,12 +32,12 @@ class MessageData(serializable.Serializable):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state(cls, state):
|
def from_state(cls, state):
|
||||||
state["headers"] = headers.Headers.from_state(state["headers"])
|
state["headers"] = mheaders.Headers.from_state(state["headers"])
|
||||||
return cls(**state)
|
return cls(**state)
|
||||||
|
|
||||||
|
|
||||||
class Message(serializable.Serializable):
|
class Message(serializable.Serializable):
|
||||||
data: MessageData = None
|
data: MessageData
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, Message):
|
if isinstance(other, Message):
|
||||||
|
@ -48,7 +52,7 @@ class Message(serializable.Serializable):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state(cls, state):
|
def from_state(cls, state):
|
||||||
state["headers"] = headers.Headers.from_state(state["headers"])
|
state["headers"] = mheaders.Headers.from_state(state["headers"])
|
||||||
return cls(**state)
|
return cls(**state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -160,7 +164,7 @@ class Message(serializable.Serializable):
|
||||||
self.data.timestamp_end = timestamp_end
|
self.data.timestamp_end = timestamp_end
|
||||||
|
|
||||||
def _get_content_type_charset(self) -> Optional[str]:
|
def _get_content_type_charset(self) -> Optional[str]:
|
||||||
ct = headers.parse_content_type(self.headers.get("content-type", ""))
|
ct = mheaders.parse_content_type(self.headers.get("content-type", ""))
|
||||||
if ct:
|
if ct:
|
||||||
return ct[2].get("charset")
|
return ct[2].get("charset")
|
||||||
return None
|
return None
|
||||||
|
@ -213,9 +217,9 @@ class Message(serializable.Serializable):
|
||||||
self.content = encoding.encode(text, enc)
|
self.content = encoding.encode(text, enc)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Fall back to UTF-8 and update the content-type header.
|
# Fall back to UTF-8 and update the content-type header.
|
||||||
ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
|
ct = mheaders.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
|
||||||
ct[2]["charset"] = "utf-8"
|
ct[2]["charset"] = "utf-8"
|
||||||
self.headers["content-type"] = headers.assemble_content_type(*ct)
|
self.headers["content-type"] = mheaders.assemble_content_type(*ct)
|
||||||
enc = "utf8"
|
enc = "utf8"
|
||||||
self.content = text.encode(enc, "surrogateescape")
|
self.content = text.encode(enc, "surrogateescape")
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,8 @@ class Request(message.Message):
|
||||||
"""
|
"""
|
||||||
An HTTP request.
|
An HTTP request.
|
||||||
"""
|
"""
|
||||||
|
data: RequestData
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.data = RequestData(*args, **kwargs)
|
self.data = RequestData(*args, **kwargs)
|
||||||
|
|
|
@ -47,6 +47,8 @@ class Response(message.Message):
|
||||||
"""
|
"""
|
||||||
An HTTP response.
|
An HTTP response.
|
||||||
"""
|
"""
|
||||||
|
data: ResponseData
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.data = ResponseData(*args, **kwargs)
|
self.data = ResponseData(*args, **kwargs)
|
||||||
|
|
Loading…
Reference in New Issue