Flags: allow operations (| & ==) with strings: pkt[TCP].flags |= "SA"

This commit is contained in:
Pierre LALET 2017-03-06 07:24:33 +01:00
parent 0d92fc9fbb
commit a6b12fc6eb
2 changed files with 22 additions and 17 deletions

View File

@ -1015,40 +1015,33 @@ class LEFieldLenField(FieldLenField):
class FlagValue(object):
__slots__ = ["value", "names", "multi"]
@staticmethod
def __fixvalue(value, names):
def _fixvalue(self, value):
if isinstance(value, basestring):
if isinstance(names, list):
value = value.split('+')
else:
value = list(value)
value = value.split('+') if self.multi else list(value)
if isinstance(value, list):
y = 0
for i in value:
y |= 1 << names.index(i)
y |= 1 << self.names.index(i)
value = y
return value
return None if value is None else int(value)
def __init__(self, value, names):
self.value = (value.value if isinstance(value, self.__class__)
else self.__fixvalue(value, names))
self.multi = isinstance(names, list)
self.names = names
self.value = self._fixvalue(value)
def __int__(self):
return self.value
def __cmp__(self, other):
if isinstance(other, self.__class__):
return cmp(self.value, other.value)
return cmp(self.value, other)
return cmp(self.value, self._fixvalue(other))
def __and__(self, other):
return self.__class__(self.value & int(other), self.names)
return self.__class__(self.value & self._fixvalue(other), self.names)
__rand__ = __and__
def __or__(self, other):
return self.__class__(self.value | int(other), self.names)
return self.__class__(self.value | self._fixvalue(other), self.names)
__ror__ = __or__
def __lshift__(self, other):
return self.value << int(other)
return self.value << self._fixvalue(other)
def __rshift__(self, other):
return self.value >> int(other)
return self.value >> self._fixvalue(other)
def __nonzero__(self):
return bool(self.value)
def flagrepr(self):

View File

@ -7184,6 +7184,12 @@ assert not pkt.flags.MF
assert pkt.flags.DF
assert not pkt.flags.evil
assert repr(pkt.flags) == '<Flag 2 (DF)>'
pkt.flags |= 'evil+MF'
pkt.flags &= 'DF+MF'
assert pkt.flags.MF
assert pkt.flags.DF
assert not pkt.flags.evil
assert repr(pkt.flags) == '<Flag 3 (MF+DF)>'
pkt = IP(flags=3)
assert pkt.flags.MF
@ -7213,6 +7219,12 @@ assert pkt.flags.U
assert pkt.flags.AU
assert not any(getattr(pkt.flags, f) for f in 'FSRPECN')
assert repr(pkt.flags) == '<Flag 48 (AU)>'
pkt.flags &= 'SFA'
pkt.flags |= 'P'
assert pkt.flags.P
assert pkt.flags.A
assert pkt.flags.PA
assert not any(getattr(pkt.flags, f) for f in 'FSRUECN')
pkt = TCP(flags=56)
assert all(getattr(pkt.flags, f) for f in 'PAU')