stateobject: use typing, enable tuples and more complex datatypes
This commit is contained in:
parent
b7db304dde
commit
69726f180a
|
@ -87,7 +87,7 @@ class Flow(stateobject.StateObject):
|
|||
type=str,
|
||||
intercepted=bool,
|
||||
marked=bool,
|
||||
metadata=dict,
|
||||
metadata=typing.Dict[str, typing.Any],
|
||||
)
|
||||
|
||||
def get_state(self):
|
||||
|
|
|
@ -1,18 +1,12 @@
|
|||
from typing import Any
|
||||
from typing import List
|
||||
import typing
|
||||
from typing import Any # noqa
|
||||
from typing import MutableMapping # noqa
|
||||
|
||||
from mitmproxy.coretypes import serializable
|
||||
|
||||
|
||||
def _is_list(cls):
|
||||
# The typing module is broken on Python 3.5.0, fixed on 3.5.1.
|
||||
is_list_bugfix = getattr(cls, "__origin__", False) == getattr(List[Any], "__origin__", True)
|
||||
return issubclass(cls, List) or is_list_bugfix
|
||||
from mitmproxy.utils import typecheck
|
||||
|
||||
|
||||
class StateObject(serializable.Serializable):
|
||||
|
||||
"""
|
||||
An object with serializable state.
|
||||
|
||||
|
@ -34,22 +28,7 @@ class StateObject(serializable.Serializable):
|
|||
state = {}
|
||||
for attr, cls in self._stateobject_attributes.items():
|
||||
val = getattr(self, attr)
|
||||
if val is None:
|
||||
state[attr] = None
|
||||
elif hasattr(val, "get_state"):
|
||||
state[attr] = val.get_state()
|
||||
elif _is_list(cls):
|
||||
state[attr] = [x.get_state() for x in val]
|
||||
elif isinstance(val, dict):
|
||||
s = {}
|
||||
for k, v in val.items():
|
||||
if hasattr(v, "get_state"):
|
||||
s[k] = v.get_state()
|
||||
else:
|
||||
s[k] = v
|
||||
state[attr] = s
|
||||
else:
|
||||
state[attr] = val
|
||||
state[attr] = get_state(cls, val)
|
||||
return state
|
||||
|
||||
def set_state(self, state):
|
||||
|
@ -65,13 +44,51 @@ class StateObject(serializable.Serializable):
|
|||
curr = getattr(self, attr)
|
||||
if hasattr(curr, "set_state"):
|
||||
curr.set_state(val)
|
||||
elif hasattr(cls, "from_state"):
|
||||
obj = cls.from_state(val)
|
||||
setattr(self, attr, obj)
|
||||
elif _is_list(cls):
|
||||
cls = cls.__parameters__[0] if cls.__parameters__ else cls.__args__[0]
|
||||
setattr(self, attr, [cls.from_state(x) for x in val])
|
||||
else: # primitive types such as int, str, ...
|
||||
setattr(self, attr, cls(val))
|
||||
else:
|
||||
setattr(self, attr, make_object(cls, val))
|
||||
if state:
|
||||
raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state))
|
||||
|
||||
|
||||
def _process(typeinfo: typecheck.Type, val: typing.Any, make: bool) -> typing.Any:
|
||||
if val is None:
|
||||
return None
|
||||
elif make and hasattr(typeinfo, "from_state"):
|
||||
return typeinfo.from_state(val)
|
||||
elif not make and hasattr(val, "get_state"):
|
||||
return val.get_state()
|
||||
|
||||
typename = str(typeinfo)
|
||||
|
||||
if typename.startswith("typing.List"):
|
||||
T = typecheck.sequence_type(typeinfo)
|
||||
return [_process(T, x, make) for x in val]
|
||||
elif typename.startswith("typing.Tuple"):
|
||||
Ts = typecheck.tuple_types(typeinfo)
|
||||
if len(Ts) != len(val):
|
||||
raise ValueError("Invalid data. Expected {}, got {}.".format(Ts, val))
|
||||
return tuple(
|
||||
_process(T, x, make) for T, x in zip(Ts, val)
|
||||
)
|
||||
elif typename.startswith("typing.Dict"):
|
||||
k_cls, v_cls = typecheck.mapping_types(typeinfo)
|
||||
return {
|
||||
_process(k_cls, k, make): _process(v_cls, v, make)
|
||||
for k, v in val.items()
|
||||
}
|
||||
elif typename.startswith("typing.Any"):
|
||||
# FIXME: Remove this when we remove flow.metadata
|
||||
assert isinstance(val, (int, str, bool, bytes))
|
||||
return val
|
||||
else:
|
||||
return typeinfo(val)
|
||||
|
||||
|
||||
def make_object(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any:
|
||||
"""Create an object based on the state given in val."""
|
||||
return _process(typeinfo, val, True)
|
||||
|
||||
|
||||
def get_state(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any:
|
||||
"""Get the state of the object given as val."""
|
||||
return _process(typeinfo, val, False)
|
||||
|
|
|
@ -1,7 +1,40 @@
|
|||
import typing
|
||||
|
||||
Type = typing.Union[
|
||||
typing.Any # anything more elaborate really fails with mypy at the moment.
|
||||
]
|
||||
|
||||
def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> None:
|
||||
|
||||
def sequence_type(typeinfo: typing.Type[typing.List]) -> Type:
|
||||
"""Return the type of a sequence, e.g. typing.List"""
|
||||
try:
|
||||
return typeinfo.__args__[0] # type: ignore
|
||||
except AttributeError: # Python 3.5.0
|
||||
return typeinfo.__parameters__[0] # type: ignore
|
||||
|
||||
|
||||
def tuple_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]:
|
||||
"""Return the types of a typing.Tuple"""
|
||||
try:
|
||||
return typeinfo.__args__ # type: ignore
|
||||
except AttributeError: # Python 3.5.x
|
||||
return typeinfo.__tuple_params__ # type: ignore
|
||||
|
||||
|
||||
def union_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]:
|
||||
"""return the types of a typing.Union"""
|
||||
try:
|
||||
return typeinfo.__args__ # type: ignore
|
||||
except AttributeError: # Python 3.5.x
|
||||
return typeinfo.__union_params__ # type: ignore
|
||||
|
||||
|
||||
def mapping_types(typeinfo: typing.Type[typing.Mapping]) -> typing.Tuple[Type, Type]:
|
||||
"""return the types of a mapping, e.g. typing.Dict"""
|
||||
return typeinfo.__args__ # type: ignore
|
||||
|
||||
|
||||
def check_option_type(name: str, value: typing.Any, typeinfo: Type) -> None:
|
||||
"""
|
||||
Check if the provided value is an instance of typeinfo and raises a
|
||||
TypeError otherwise. This function supports only those types required for
|
||||
|
@ -16,13 +49,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non
|
|||
typename = str(typeinfo)
|
||||
|
||||
if typename.startswith("typing.Union"):
|
||||
try:
|
||||
types = typeinfo.__args__ # type: ignore
|
||||
except AttributeError:
|
||||
# Python 3.5.x
|
||||
types = typeinfo.__union_params__ # type: ignore
|
||||
|
||||
for T in types:
|
||||
for T in union_types(typeinfo):
|
||||
try:
|
||||
check_option_type(name, value, T)
|
||||
except TypeError:
|
||||
|
@ -31,12 +58,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non
|
|||
return
|
||||
raise e
|
||||
elif typename.startswith("typing.Tuple"):
|
||||
try:
|
||||
types = typeinfo.__args__ # type: ignore
|
||||
except AttributeError:
|
||||
# Python 3.5.x
|
||||
types = typeinfo.__tuple_params__ # type: ignore
|
||||
|
||||
types = tuple_types(typeinfo)
|
||||
if not isinstance(value, (tuple, list)):
|
||||
raise e
|
||||
if len(types) != len(value):
|
||||
|
@ -45,11 +67,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non
|
|||
check_option_type("{}[{}]".format(name, i), x, T)
|
||||
return
|
||||
elif typename.startswith("typing.Sequence"):
|
||||
try:
|
||||
T = typeinfo.__args__[0] # type: ignore
|
||||
except AttributeError:
|
||||
# Python 3.5.0
|
||||
T = typeinfo.__parameters__[0] # type: ignore
|
||||
T = sequence_type(typeinfo)
|
||||
if not isinstance(value, (tuple, list)):
|
||||
raise e
|
||||
for v in value:
|
||||
|
|
|
@ -75,7 +75,6 @@ exclude =
|
|||
mitmproxy/proxy/protocol/tls.py
|
||||
mitmproxy/proxy/root_context.py
|
||||
mitmproxy/proxy/server.py
|
||||
mitmproxy/stateobject.py
|
||||
mitmproxy/utils/bits.py
|
||||
pathod/language/actions.py
|
||||
pathod/language/base.py
|
||||
|
|
|
@ -1,101 +1,146 @@
|
|||
from typing import List
|
||||
import typing
|
||||
|
||||
import pytest
|
||||
|
||||
from mitmproxy.stateobject import StateObject
|
||||
|
||||
|
||||
class Child(StateObject):
|
||||
class TObject(StateObject):
|
||||
def __init__(self, x):
|
||||
self.x = x
|
||||
|
||||
_stateobject_attributes = dict(
|
||||
x=int
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
obj = cls(None)
|
||||
obj.set_state(state)
|
||||
return obj
|
||||
|
||||
|
||||
class Child(TObject):
|
||||
_stateobject_attributes = dict(
|
||||
x=int
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Child) and self.x == other.x
|
||||
|
||||
|
||||
class Container(StateObject):
|
||||
def __init__(self):
|
||||
self.child = None
|
||||
self.children = None
|
||||
self.dictionary = None
|
||||
|
||||
class TTuple(TObject):
|
||||
_stateobject_attributes = dict(
|
||||
child=Child,
|
||||
children=List[Child],
|
||||
dictionary=dict,
|
||||
x=typing.Tuple[int, Child]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
obj = cls()
|
||||
obj.set_state(state)
|
||||
return obj
|
||||
|
||||
class TList(TObject):
|
||||
_stateobject_attributes = dict(
|
||||
x=typing.List[Child]
|
||||
)
|
||||
|
||||
|
||||
class TDict(TObject):
|
||||
_stateobject_attributes = dict(
|
||||
x=typing.Dict[str, Child]
|
||||
)
|
||||
|
||||
|
||||
class TAny(TObject):
|
||||
_stateobject_attributes = dict(
|
||||
x=typing.Any
|
||||
)
|
||||
|
||||
|
||||
class TSerializableChild(TObject):
|
||||
_stateobject_attributes = dict(
|
||||
x=Child
|
||||
)
|
||||
|
||||
|
||||
def test_simple():
|
||||
a = Child(42)
|
||||
assert a.get_state() == {"x": 42}
|
||||
b = a.copy()
|
||||
assert b.get_state() == {"x": 42}
|
||||
a.set_state({"x": 44})
|
||||
assert a.x == 44
|
||||
assert b.x == 42
|
||||
|
||||
|
||||
def test_container():
|
||||
a = Container()
|
||||
a.child = Child(42)
|
||||
def test_serializable_child():
|
||||
child = Child(42)
|
||||
a = TSerializableChild(child)
|
||||
assert a.get_state() == {
|
||||
"x": {"x": 42}
|
||||
}
|
||||
a.set_state({
|
||||
"x": {"x": 43}
|
||||
})
|
||||
assert a.x.x == 43
|
||||
assert a.x is child
|
||||
b = a.copy()
|
||||
assert a.child.x == b.child.x
|
||||
b.child.x = 44
|
||||
assert a.child.x != b.child.x
|
||||
assert a.x == b.x
|
||||
assert a.x is not b.x
|
||||
|
||||
|
||||
def test_container_list():
|
||||
a = Container()
|
||||
a.children = [Child(42), Child(44)]
|
||||
def test_tuple():
|
||||
a = TTuple((42, Child(43)))
|
||||
assert a.get_state() == {
|
||||
"child": None,
|
||||
"children": [{"x": 42}, {"x": 44}],
|
||||
"dictionary": None,
|
||||
"x": (42, {"x": 43})
|
||||
}
|
||||
b = a.copy()
|
||||
a.set_state({"x": (44, {"x": 45})})
|
||||
assert a.x == (44, Child(45))
|
||||
assert b.x == (42, Child(43))
|
||||
|
||||
|
||||
def test_tuple_err():
|
||||
a = TTuple(None)
|
||||
with pytest.raises(ValueError, msg="Invalid data"):
|
||||
a.set_state({"x": (42,)})
|
||||
|
||||
|
||||
def test_list():
|
||||
a = TList([Child(1), Child(2)])
|
||||
assert a.get_state() == {
|
||||
"x": [{"x": 1}, {"x": 2}],
|
||||
}
|
||||
copy = a.copy()
|
||||
assert len(copy.children) == 2
|
||||
assert copy.children is not a.children
|
||||
assert copy.children[0] is not a.children[0]
|
||||
assert Container.from_state(a.get_state())
|
||||
assert len(copy.x) == 2
|
||||
assert copy.x is not a.x
|
||||
assert copy.x[0] is not a.x[0]
|
||||
|
||||
|
||||
def test_container_dict():
|
||||
a = Container()
|
||||
a.dictionary = dict()
|
||||
a.dictionary['foo'] = 'bar'
|
||||
a.dictionary['bar'] = Child(44)
|
||||
def test_dict():
|
||||
a = TDict({"foo": Child(42)})
|
||||
assert a.get_state() == {
|
||||
"child": None,
|
||||
"children": None,
|
||||
"dictionary": {'bar': {'x': 44}, 'foo': 'bar'},
|
||||
"x": {"foo": {"x": 42}}
|
||||
}
|
||||
copy = a.copy()
|
||||
assert len(copy.dictionary) == 2
|
||||
assert copy.dictionary is not a.dictionary
|
||||
assert copy.dictionary['bar'] is not a.dictionary['bar']
|
||||
b = a.copy()
|
||||
assert list(a.x.items()) == list(b.x.items())
|
||||
assert a.x is not b.x
|
||||
assert a.x["foo"] is not b.x["foo"]
|
||||
|
||||
|
||||
def test_any():
|
||||
a = TAny(42)
|
||||
b = a.copy()
|
||||
assert a.x == b.x
|
||||
|
||||
a = TAny(object())
|
||||
with pytest.raises(AssertionError):
|
||||
a.get_state()
|
||||
|
||||
|
||||
def test_too_much_state():
|
||||
a = Container()
|
||||
a.child = Child(42)
|
||||
a = Child(42)
|
||||
s = a.get_state()
|
||||
s['foo'] = 'bar'
|
||||
b = Container()
|
||||
|
||||
with pytest.raises(RuntimeWarning):
|
||||
b.set_state(s)
|
||||
a.set_state(s)
|
||||
|
||||
|
||||
def test_none():
|
||||
a = Child(None)
|
||||
assert a.get_state() == {"x": None}
|
||||
a = Child(42)
|
||||
a.set_state({"x": None})
|
||||
assert a.x is None
|
||||
|
|
|
@ -93,3 +93,8 @@ def test_typesec_to_str():
|
|||
assert(typecheck.typespec_to_str(typing.Optional[str])) == "optional str"
|
||||
with pytest.raises(NotImplementedError):
|
||||
typecheck.typespec_to_str(dict)
|
||||
|
||||
|
||||
def test_mapping_types():
|
||||
# this is not covered by check_option_type, but still belongs in this module
|
||||
assert (str, int) == typecheck.mapping_types(typing.Mapping[str, int])
|
||||
|
|
Loading…
Reference in New Issue