Merge pull request #3671 from mhils/add-types

Add HTTP Message Type Hints
This commit is contained in:
Maximilian Hils 2019-10-16 21:42:37 +02:00 committed by GitHub
commit eb7ed1dc40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 10 deletions

View File

@ -1,14 +1,18 @@
import re import re
from typing import Optional, Union # noqa from typing import Optional # noqa
from mitmproxy.utils import strutils from mitmproxy.utils import strutils
from mitmproxy.net.http import encoding from mitmproxy.net.http import encoding
from mitmproxy.coretypes import serializable from mitmproxy.coretypes import serializable
from mitmproxy.net.http import headers from mitmproxy.net.http import headers as mheaders
class MessageData(serializable.Serializable): class MessageData(serializable.Serializable):
content: bytes = None headers: mheaders.Headers
content: bytes
http_version: bytes
timestamp_start: float
timestamp_end: float
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, MessageData): if isinstance(other, MessageData):
@ -18,7 +22,7 @@ class MessageData(serializable.Serializable):
def set_state(self, state): def set_state(self, state):
for k, v in state.items(): for k, v in state.items():
if k == "headers": if k == "headers":
v = headers.Headers.from_state(v) v = mheaders.Headers.from_state(v)
setattr(self, k, v) setattr(self, k, v)
def get_state(self): def get_state(self):
@ -28,12 +32,12 @@ class MessageData(serializable.Serializable):
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
state["headers"] = headers.Headers.from_state(state["headers"]) state["headers"] = mheaders.Headers.from_state(state["headers"])
return cls(**state) return cls(**state)
class Message(serializable.Serializable): class Message(serializable.Serializable):
data: MessageData = None data: MessageData
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Message): if isinstance(other, Message):
@ -48,7 +52,7 @@ class Message(serializable.Serializable):
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
state["headers"] = headers.Headers.from_state(state["headers"]) state["headers"] = mheaders.Headers.from_state(state["headers"])
return cls(**state) return cls(**state)
@property @property
@ -160,7 +164,7 @@ class Message(serializable.Serializable):
self.data.timestamp_end = timestamp_end self.data.timestamp_end = timestamp_end
def _get_content_type_charset(self) -> Optional[str]: def _get_content_type_charset(self) -> Optional[str]:
ct = headers.parse_content_type(self.headers.get("content-type", "")) ct = mheaders.parse_content_type(self.headers.get("content-type", ""))
if ct: if ct:
return ct[2].get("charset") return ct[2].get("charset")
return None return None
@ -213,9 +217,9 @@ class Message(serializable.Serializable):
self.content = encoding.encode(text, enc) self.content = encoding.encode(text, enc)
except ValueError: except ValueError:
# Fall back to UTF-8 and update the content-type header. # Fall back to UTF-8 and update the content-type header.
ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) ct = mheaders.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
ct[2]["charset"] = "utf-8" ct[2]["charset"] = "utf-8"
self.headers["content-type"] = headers.assemble_content_type(*ct) self.headers["content-type"] = mheaders.assemble_content_type(*ct)
enc = "utf8" enc = "utf8"
self.content = text.encode(enc, "surrogateescape") self.content = text.encode(enc, "surrogateescape")

View File

@ -64,6 +64,8 @@ class Request(message.Message):
""" """
An HTTP request. An HTTP request.
""" """
data: RequestData
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__() super().__init__()
self.data = RequestData(*args, **kwargs) self.data = RequestData(*args, **kwargs)

View File

@ -47,6 +47,8 @@ class Response(message.Message):
""" """
An HTTP response. An HTTP response.
""" """
data: ResponseData
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__() super().__init__()
self.data = ResponseData(*args, **kwargs) self.data = ResponseData(*args, **kwargs)