From a6b12fc6ebf85c6fcddf02cd1c499486f52209a8 Mon Sep 17 00:00:00 2001 From: Pierre LALET Date: Mon, 6 Mar 2017 07:24:33 +0100 Subject: [PATCH] Flags: allow operations (| & ==) with strings: pkt[TCP].flags |= "SA" --- scapy/fields.py | 27 ++++++++++----------------- test/regression.uts | 12 ++++++++++++ 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/scapy/fields.py b/scapy/fields.py index a4b45c671..25dbc2ecf 100644 --- a/scapy/fields.py +++ b/scapy/fields.py @@ -1015,40 +1015,33 @@ class LEFieldLenField(FieldLenField): class FlagValue(object): __slots__ = ["value", "names", "multi"] - @staticmethod - def __fixvalue(value, names): + def _fixvalue(self, value): if isinstance(value, basestring): - if isinstance(names, list): - value = value.split('+') - else: - value = list(value) + value = value.split('+') if self.multi else list(value) if isinstance(value, list): y = 0 for i in value: - y |= 1 << names.index(i) + y |= 1 << self.names.index(i) value = y - return value + return None if value is None else int(value) def __init__(self, value, names): - self.value = (value.value if isinstance(value, self.__class__) - else self.__fixvalue(value, names)) self.multi = isinstance(names, list) self.names = names + self.value = self._fixvalue(value) def __int__(self): return self.value def __cmp__(self, other): - if isinstance(other, self.__class__): - return cmp(self.value, other.value) - return cmp(self.value, other) + return cmp(self.value, self._fixvalue(other)) def __and__(self, other): - return self.__class__(self.value & int(other), self.names) + return self.__class__(self.value & self._fixvalue(other), self.names) __rand__ = __and__ def __or__(self, other): - return self.__class__(self.value | int(other), self.names) + return self.__class__(self.value | self._fixvalue(other), self.names) __ror__ = __or__ def __lshift__(self, other): - return self.value << int(other) + return self.value << self._fixvalue(other) def __rshift__(self, other): - return self.value >> int(other) + return self.value >> self._fixvalue(other) def __nonzero__(self): return bool(self.value) def flagrepr(self): diff --git a/test/regression.uts b/test/regression.uts index d53fe23e1..2901e4276 100644 --- a/test/regression.uts +++ b/test/regression.uts @@ -7184,6 +7184,12 @@ assert not pkt.flags.MF assert pkt.flags.DF assert not pkt.flags.evil assert repr(pkt.flags) == '' +pkt.flags |= 'evil+MF' +pkt.flags &= 'DF+MF' +assert pkt.flags.MF +assert pkt.flags.DF +assert not pkt.flags.evil +assert repr(pkt.flags) == '' pkt = IP(flags=3) assert pkt.flags.MF @@ -7213,6 +7219,12 @@ assert pkt.flags.U assert pkt.flags.AU assert not any(getattr(pkt.flags, f) for f in 'FSRPECN') assert repr(pkt.flags) == '' +pkt.flags &= 'SFA' +pkt.flags |= 'P' +assert pkt.flags.P +assert pkt.flags.A +assert pkt.flags.PA +assert not any(getattr(pkt.flags, f) for f in 'FSRUECN') pkt = TCP(flags=56) assert all(getattr(pkt.flags, f) for f in 'PAU')