diff --git a/tests/testlib.py b/tests/testlib.py index fb6e55e9..75494ef7 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -1,7 +1,10 @@ import os import random +import re +import socket import sys +import time import unittest import urlparse @@ -27,6 +30,75 @@ def data_path(suffix): return path +def wait_for_port( + host, + port, + pattern=None, + connect_timeout=0.5, + receive_timeout=0.5, + overall_timeout=5.0, + sleep=0.1, + ): + """Attempt to connect to host/port, for upto overall_timeout seconds. + If a regex pattern is supplied try to find it in the initial data. + Return None on success, or raise on error. + """ + start = time.time() + end = start + overall_timeout + addr = (host, port) + + while time.time() < end: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(connect_timeout) + try: + sock.connect(addr) + except socket.error: + # Failed to connect. So wait then retry. + time.sleep(sleep) + continue + + if not pattern: + # Success: We connected & there's no banner check to perform. + sock.shutdown(socket.SHUTD_RDWR) + sock.close() + return + + sock.settimeout(receive_timeout) + data = '' + found = False + while time.time() < end: + try: + resp = sock.recv(1024) + except socket.timeout: + # Server stayed up, but had no data. Retry the recv(). + continue + + if not resp: + # Server went away. Wait then retry the connection. + time.sleep(sleep) + break + + data += resp + if re.search(pattern, data): + found = True + break + + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + if found: + # Success: We received the banner & found the desired pattern + return + else: + # Failure: The overall timeout expired + if pattern: + raise socket.timeout('Timed out while searching for %r from %s:%s' + % (pattern, host, port)) + else: + raise socket.timeout('Timed out while connecting to %s:%s' + % (host, port)) + + class TestCase(unittest.TestCase): def assertRaises(self, exc, func, *args, **kwargs): """Like regular assertRaises, except return the exception that was @@ -61,6 +133,9 @@ class DockerizedSshDaemon(object): parsed = urlparse.urlparse(self.docker.api.base_url) return parsed.netloc.partition(':')[0] + def wait_for_sshd(self): + wait_for_port(self.get_host(), int(self.port), pattern='OpenSSH') + def close(self): self.container.stop() self.container.remove() @@ -86,6 +161,7 @@ class DockerMixin(RouterMixin): def setUpClass(cls): super(DockerMixin, cls).setUpClass() cls.dockerized_ssh = DockerizedSshDaemon() + cls.dockerized_ssh.wait_for_sshd() @classmethod def tearDownClass(cls):