diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 944c032d8..6a27a4a81 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -87,7 +87,7 @@ class Flow(stateobject.StateObject): type=str, intercepted=bool, marked=bool, - metadata=dict, + metadata=typing.Dict[str, typing.Any], ) def get_state(self): diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py index 007339e8c..ffaf285fa 100644 --- a/mitmproxy/stateobject.py +++ b/mitmproxy/stateobject.py @@ -1,18 +1,12 @@ -from typing import Any -from typing import List +import typing +from typing import Any # noqa from typing import MutableMapping # noqa from mitmproxy.coretypes import serializable - - -def _is_list(cls): - # The typing module is broken on Python 3.5.0, fixed on 3.5.1. - is_list_bugfix = getattr(cls, "__origin__", False) == getattr(List[Any], "__origin__", True) - return issubclass(cls, List) or is_list_bugfix +from mitmproxy.utils import typecheck class StateObject(serializable.Serializable): - """ An object with serializable state. @@ -34,22 +28,7 @@ class StateObject(serializable.Serializable): state = {} for attr, cls in self._stateobject_attributes.items(): val = getattr(self, attr) - if val is None: - state[attr] = None - elif hasattr(val, "get_state"): - state[attr] = val.get_state() - elif _is_list(cls): - state[attr] = [x.get_state() for x in val] - elif isinstance(val, dict): - s = {} - for k, v in val.items(): - if hasattr(v, "get_state"): - s[k] = v.get_state() - else: - s[k] = v - state[attr] = s - else: - state[attr] = val + state[attr] = get_state(cls, val) return state def set_state(self, state): @@ -65,13 +44,51 @@ class StateObject(serializable.Serializable): curr = getattr(self, attr) if hasattr(curr, "set_state"): curr.set_state(val) - elif hasattr(cls, "from_state"): - obj = cls.from_state(val) - setattr(self, attr, obj) - elif _is_list(cls): - cls = cls.__parameters__[0] if cls.__parameters__ else cls.__args__[0] - setattr(self, attr, [cls.from_state(x) for x in val]) - else: # primitive types such as int, str, ... - setattr(self, attr, cls(val)) + else: + setattr(self, attr, make_object(cls, val)) if state: raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state)) + + +def _process(typeinfo: typecheck.Type, val: typing.Any, make: bool) -> typing.Any: + if val is None: + return None + elif make and hasattr(typeinfo, "from_state"): + return typeinfo.from_state(val) + elif not make and hasattr(val, "get_state"): + return val.get_state() + + typename = str(typeinfo) + + if typename.startswith("typing.List"): + T = typecheck.sequence_type(typeinfo) + return [_process(T, x, make) for x in val] + elif typename.startswith("typing.Tuple"): + Ts = typecheck.tuple_types(typeinfo) + if len(Ts) != len(val): + raise ValueError("Invalid data. Expected {}, got {}.".format(Ts, val)) + return tuple( + _process(T, x, make) for T, x in zip(Ts, val) + ) + elif typename.startswith("typing.Dict"): + k_cls, v_cls = typecheck.mapping_types(typeinfo) + return { + _process(k_cls, k, make): _process(v_cls, v, make) + for k, v in val.items() + } + elif typename.startswith("typing.Any"): + # FIXME: Remove this when we remove flow.metadata + assert isinstance(val, (int, str, bool, bytes)) + return val + else: + return typeinfo(val) + + +def make_object(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any: + """Create an object based on the state given in val.""" + return _process(typeinfo, val, True) + + +def get_state(typeinfo: typecheck.Type, val: typing.Any) -> typing.Any: + """Get the state of the object given as val.""" + return _process(typeinfo, val, False) diff --git a/mitmproxy/utils/typecheck.py b/mitmproxy/utils/typecheck.py index 1070fad08..22db68f58 100644 --- a/mitmproxy/utils/typecheck.py +++ b/mitmproxy/utils/typecheck.py @@ -1,7 +1,40 @@ import typing +Type = typing.Union[ + typing.Any # anything more elaborate really fails with mypy at the moment. +] -def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> None: + +def sequence_type(typeinfo: typing.Type[typing.List]) -> Type: + """Return the type of a sequence, e.g. typing.List""" + try: + return typeinfo.__args__[0] # type: ignore + except AttributeError: # Python 3.5.0 + return typeinfo.__parameters__[0] # type: ignore + + +def tuple_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]: + """Return the types of a typing.Tuple""" + try: + return typeinfo.__args__ # type: ignore + except AttributeError: # Python 3.5.x + return typeinfo.__tuple_params__ # type: ignore + + +def union_types(typeinfo: typing.Type[typing.Tuple]) -> typing.Sequence[Type]: + """return the types of a typing.Union""" + try: + return typeinfo.__args__ # type: ignore + except AttributeError: # Python 3.5.x + return typeinfo.__union_params__ # type: ignore + + +def mapping_types(typeinfo: typing.Type[typing.Mapping]) -> typing.Tuple[Type, Type]: + """return the types of a mapping, e.g. typing.Dict""" + return typeinfo.__args__ # type: ignore + + +def check_option_type(name: str, value: typing.Any, typeinfo: Type) -> None: """ Check if the provided value is an instance of typeinfo and raises a TypeError otherwise. This function supports only those types required for @@ -16,13 +49,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non typename = str(typeinfo) if typename.startswith("typing.Union"): - try: - types = typeinfo.__args__ # type: ignore - except AttributeError: - # Python 3.5.x - types = typeinfo.__union_params__ # type: ignore - - for T in types: + for T in union_types(typeinfo): try: check_option_type(name, value, T) except TypeError: @@ -31,12 +58,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non return raise e elif typename.startswith("typing.Tuple"): - try: - types = typeinfo.__args__ # type: ignore - except AttributeError: - # Python 3.5.x - types = typeinfo.__tuple_params__ # type: ignore - + types = tuple_types(typeinfo) if not isinstance(value, (tuple, list)): raise e if len(types) != len(value): @@ -45,11 +67,7 @@ def check_option_type(name: str, value: typing.Any, typeinfo: typing.Any) -> Non check_option_type("{}[{}]".format(name, i), x, T) return elif typename.startswith("typing.Sequence"): - try: - T = typeinfo.__args__[0] # type: ignore - except AttributeError: - # Python 3.5.0 - T = typeinfo.__parameters__[0] # type: ignore + T = sequence_type(typeinfo) if not isinstance(value, (tuple, list)): raise e for v in value: diff --git a/setup.cfg b/setup.cfg index 7c7547228..592cc2e33 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,7 +75,6 @@ exclude = mitmproxy/proxy/protocol/tls.py mitmproxy/proxy/root_context.py mitmproxy/proxy/server.py - mitmproxy/stateobject.py mitmproxy/utils/bits.py pathod/language/actions.py pathod/language/base.py diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py index d8c7a8e9b..bd5d17928 100644 --- a/test/mitmproxy/test_stateobject.py +++ b/test/mitmproxy/test_stateobject.py @@ -1,101 +1,146 @@ -from typing import List +import typing + import pytest from mitmproxy.stateobject import StateObject -class Child(StateObject): +class TObject(StateObject): def __init__(self, x): self.x = x - _stateobject_attributes = dict( - x=int - ) - @classmethod def from_state(cls, state): obj = cls(None) obj.set_state(state) return obj + +class Child(TObject): + _stateobject_attributes = dict( + x=int + ) + def __eq__(self, other): return isinstance(other, Child) and self.x == other.x -class Container(StateObject): - def __init__(self): - self.child = None - self.children = None - self.dictionary = None - +class TTuple(TObject): _stateobject_attributes = dict( - child=Child, - children=List[Child], - dictionary=dict, + x=typing.Tuple[int, Child] ) - @classmethod - def from_state(cls, state): - obj = cls() - obj.set_state(state) - return obj + +class TList(TObject): + _stateobject_attributes = dict( + x=typing.List[Child] + ) + + +class TDict(TObject): + _stateobject_attributes = dict( + x=typing.Dict[str, Child] + ) + + +class TAny(TObject): + _stateobject_attributes = dict( + x=typing.Any + ) + + +class TSerializableChild(TObject): + _stateobject_attributes = dict( + x=Child + ) def test_simple(): a = Child(42) + assert a.get_state() == {"x": 42} b = a.copy() - assert b.get_state() == {"x": 42} a.set_state({"x": 44}) assert a.x == 44 assert b.x == 42 -def test_container(): - a = Container() - a.child = Child(42) +def test_serializable_child(): + child = Child(42) + a = TSerializableChild(child) + assert a.get_state() == { + "x": {"x": 42} + } + a.set_state({ + "x": {"x": 43} + }) + assert a.x.x == 43 + assert a.x is child b = a.copy() - assert a.child.x == b.child.x - b.child.x = 44 - assert a.child.x != b.child.x + assert a.x == b.x + assert a.x is not b.x -def test_container_list(): - a = Container() - a.children = [Child(42), Child(44)] +def test_tuple(): + a = TTuple((42, Child(43))) assert a.get_state() == { - "child": None, - "children": [{"x": 42}, {"x": 44}], - "dictionary": None, + "x": (42, {"x": 43}) + } + b = a.copy() + a.set_state({"x": (44, {"x": 45})}) + assert a.x == (44, Child(45)) + assert b.x == (42, Child(43)) + + +def test_tuple_err(): + a = TTuple(None) + with pytest.raises(ValueError, msg="Invalid data"): + a.set_state({"x": (42,)}) + + +def test_list(): + a = TList([Child(1), Child(2)]) + assert a.get_state() == { + "x": [{"x": 1}, {"x": 2}], } copy = a.copy() - assert len(copy.children) == 2 - assert copy.children is not a.children - assert copy.children[0] is not a.children[0] - assert Container.from_state(a.get_state()) + assert len(copy.x) == 2 + assert copy.x is not a.x + assert copy.x[0] is not a.x[0] -def test_container_dict(): - a = Container() - a.dictionary = dict() - a.dictionary['foo'] = 'bar' - a.dictionary['bar'] = Child(44) +def test_dict(): + a = TDict({"foo": Child(42)}) assert a.get_state() == { - "child": None, - "children": None, - "dictionary": {'bar': {'x': 44}, 'foo': 'bar'}, + "x": {"foo": {"x": 42}} } - copy = a.copy() - assert len(copy.dictionary) == 2 - assert copy.dictionary is not a.dictionary - assert copy.dictionary['bar'] is not a.dictionary['bar'] + b = a.copy() + assert list(a.x.items()) == list(b.x.items()) + assert a.x is not b.x + assert a.x["foo"] is not b.x["foo"] + + +def test_any(): + a = TAny(42) + b = a.copy() + assert a.x == b.x + + a = TAny(object()) + with pytest.raises(AssertionError): + a.get_state() def test_too_much_state(): - a = Container() - a.child = Child(42) + a = Child(42) s = a.get_state() s['foo'] = 'bar' - b = Container() with pytest.raises(RuntimeWarning): - b.set_state(s) + a.set_state(s) + + +def test_none(): + a = Child(None) + assert a.get_state() == {"x": None} + a = Child(42) + a.set_state({"x": None}) + assert a.x is None diff --git a/test/mitmproxy/utils/test_typecheck.py b/test/mitmproxy/utils/test_typecheck.py index 5295fff55..9cb4334e0 100644 --- a/test/mitmproxy/utils/test_typecheck.py +++ b/test/mitmproxy/utils/test_typecheck.py @@ -93,3 +93,8 @@ def test_typesec_to_str(): assert(typecheck.typespec_to_str(typing.Optional[str])) == "optional str" with pytest.raises(NotImplementedError): typecheck.typespec_to_str(dict) + + +def test_mapping_types(): + # this is not covered by check_option_type, but still belongs in this module + assert (str, int) == typecheck.mapping_types(typing.Mapping[str, int])