diff --git a/netlib/tcp.py b/netlib/tcp.py index e1318435b..414c12377 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -284,6 +284,9 @@ class BaseHandler: def handle(self): # pragma: no cover raise NotImplementedError + def settimeout(self, n): + self.connection.settimeout(n) + def close(self): """ Does a hard close of the socket, i.e. a shutdown, followed by a close. diff --git a/test/test_tcp.py b/test/test_tcp.py index 9d581939c..c833ce077 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -28,6 +28,11 @@ class ServerTestBase: cls.server.shutdown() + @property + def last_handler(self): + return self.server.server.last_handler + + class SNIHandler(tcp.BaseHandler): sni = None def handle_sni(self, connection): @@ -63,15 +68,27 @@ class HangHandler(tcp.BaseHandler): time.sleep(1) +class TimeoutHandler(tcp.BaseHandler): + def handle(self): + self.timeout = False + self.settimeout(0.01) + try: + self.rfile.read(10) + except tcp.NetLibTimeout: + self.timeout = True + + class TServer(tcp.TCPServer): - def __init__(self, addr, ssl, q, handler, v3_only=False): + def __init__(self, addr, ssl, q, handler_klass, v3_only=False): tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q self.v3_only = v3_only - self.handler = handler + self.handler_klass = handler_klass + self.last_handler = None def handle_connection(self, request, client_address): - h = self.handler(request, client_address, self) + h = self.handler_klass(request, client_address, self) + self.last_handler = h if self.ssl: if self.v3_only: method = tcp.SSLv3_METHOD @@ -194,12 +211,24 @@ class TestDisconnect(ServerTestBase): c.close() +class TestServerTimeOut(ServerTestBase): + @classmethod + def makeserver(cls): + return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler) + + def test_timeout(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + time.sleep(0.3) + assert self.last_handler.timeout + + class TestTimeOut(ServerTestBase): @classmethod def makeserver(cls): return TServer(("127.0.0.1", 0), False, cls.q, HangHandler) - def test_timeout_client(self): + def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() c.settimeout(0.1)