# -*- coding: utf-8 -*- import sys import time import errno import socket import threading from boltons.socketutils import (BufferedSocket, NetstringSocket, ConnectionClosed, NetstringMessageTooLong, MessageTooLong, Timeout) import pytest # skip if there's no socketpair pytestmark = pytest.mark.skipif(getattr(socket, 'socketpair', None) is None, reason='no socketpair (likely Py2 on Windows)') def test_short_lines(): for ms in (2, 4, 6, 1024, None): x, y = socket.socketpair() bs = BufferedSocket(x) y.sendall(b'1\n2\n3\n') assert bs.recv_until(b'\n', maxsize=ms) == b'1' assert bs.recv_until(b'\n', maxsize=ms) == b'2' y.close() assert bs.recv_close(maxsize=ms) == b'3\n' try: bs.recv_size(1) except ConnectionClosed: pass else: assert False, 'expected ConnectionClosed' bs.close() return def test_multibyte_delim(): """Primarily tests recv_until with various maxsizes and True/False for with_delimiter. """ delim = b'\r\n' for with_delim in (True, False): if with_delim: cond_delim = b'\r\n' else: cond_delim = b'' empty = b'' small_one = b'1' big_two = b'2' * 2048 for ms in (3, 5, 1024, None): x, y = socket.socketpair() bs = BufferedSocket(x) y.sendall(empty + delim) y.sendall(small_one + delim) y.sendall(big_two + delim) kwargs = {'maxsize': ms, 'with_delimiter': with_delim} assert bs.recv_until(delim, **kwargs) == empty + cond_delim assert bs.recv_until(delim, **kwargs) == small_one + cond_delim try: assert bs.recv_until(delim, **kwargs) == big_two + cond_delim except MessageTooLong: if ms is None: assert False, 'unexpected MessageTooLong' else: if ms is not None: assert False, 'expected MessageTooLong' return def test_props(): x, y = socket.socketpair() bs = BufferedSocket(x) assert bs.type == x.type assert bs.proto == x.proto assert bs.family == x.family return def test_buffers(): x, y = socket.socketpair() bx, by = BufferedSocket(x), BufferedSocket(y) assert by.getrecvbuffer() == b'' assert by.getsendbuffer() == b'' assert bx.getrecvbuffer() == b'' by.buffer(b'12') by.sendall(b'3') assert bx.recv_size(1) == b'1' assert bx.getrecvbuffer() == b'23' return def test_client_disconnecting(): def get_bs_pair(): x, y = socket.socketpair() bx, by = BufferedSocket(x), BufferedSocket(y) # sanity check by.sendall(b'123') bx.recv_size(3) == b'123' return bx, by bx, by = get_bs_pair() assert bx.fileno() > 0 bx.close() assert bx.getrecvbuffer() == b'' try: bx.recv(1) except socket.error: pass else: assert False, 'expected socket.error on closed recv' assert bx.fileno() == -1 by.buffer(b'123') assert by.getsendbuffer() try: by.flush() except socket.error: assert by.getsendbuffer() == b'123' else: if sys.platform != 'win32': # Windows socketpairs are kind of bad assert False, 'expected socket.error broken pipe' by.shutdown(socket.SHUT_RDWR) by.close() assert not by.getsendbuffer() try: by.send(b'123') except socket.error: pass else: assert False, 'expected socket.error on closed send' return def test_split_delim(): delim = b'\r\n' first = b'1234\r' second = b'\n5' x, y = socket.socketpair() bs = BufferedSocket(x) y.sendall(first) try: bs.recv_until(delim, timeout=0.0001) except Timeout: pass y.sendall(second) assert bs.recv_until(delim, with_delimiter=True) == b'1234\r\n' assert bs.recv_size(1) == b'5' return def test_basic_nonblocking(): delim = b'\n' # test with per-call timeout x, y = socket.socketpair() bs = BufferedSocket(x) try: bs.recv_until(delim, timeout=0) except socket.error as se: assert se.errno == errno.EWOULDBLOCK y.sendall(delim) # sending an empty message, effectively assert bs.recv_until(delim) == b'' # test with instance-level default timeout x, y = socket.socketpair() bs = BufferedSocket(x, timeout=0) try: bs.recv_until(delim) except socket.error as se: assert se.errno == errno.EWOULDBLOCK y.sendall(delim) assert bs.recv_until(delim) == b'' # test with setblocking(0) on the underlying socket x, y = socket.socketpair() x.setblocking(0) bs = BufferedSocket(x) try: bs.recv_until(delim) except socket.error as se: assert se.errno == errno.EWOULDBLOCK y.sendall(delim) assert bs.recv_until(delim) == b'' return def test_simple_buffered_socket_passthroughs(): x, y = socket.socketpair() bs = BufferedSocket(x) assert bs.getsockname() == x.getsockname() assert bs.getpeername() == x.getpeername() def test_timeout_setters_getters(): x, y = socket.socketpair() bs = BufferedSocket(x) assert bs.settimeout(1.0) is None assert bs.gettimeout() == 1.0 assert bs.setblocking(False) is None assert bs.gettimeout() == 0.0 assert bs.setblocking(True) is None assert bs.gettimeout() is None def netstring_server(server_socket): "A basic netstring server loop, supporting a few operations" running = True try: while running: clientsock, addr = server_socket.accept() client = NetstringSocket(clientsock) while 1: request = client.read_ns() if request == b'close': clientsock.close() break elif request == b'shutdown': running = False break elif request == b'reply4k': client.write_ns(b'a' * 4096) elif request == b'ping': client.write_ns(b'pong') elif request == b'reply128k': client.setmaxsize(128 * 1024) client.write_ns(b'huge' * 32 * 1024) # 128kb client.setmaxsize(32768) # back to default except Exception as e: print(u'netstring_server exiting with error: %r' % e) raise return def test_socketutils_netstring(): """A holistic feature test of BufferedSocket via the NetstringSocket wrapper. Runs """ print("running self tests") # Set up server server_socket = socket.socket() server_socket.bind(('127.0.0.1', 0)) # localhost with ephemeral port server_socket.listen(100) ip, port = server_socket.getsockname() start_server = lambda: netstring_server(server_socket) threading.Thread(target=start_server).start() # set up client def client_connect(): clientsock = socket.create_connection((ip, port)) client = NetstringSocket(clientsock) return client # connect, ping-pong client = client_connect() client.write_ns(b'ping') assert client.read_ns() == b'pong' s = time.time() for i in range(1000): client.write_ns(b'ping') assert client.read_ns() == b'pong' dur = time.time() - s print("netstring ping-pong latency", dur, "ms") s = time.time() for i in range(1000): client.write_ns(b'ping') resps = [] for i in range(1000): resps.append(client.read_ns()) e = time.time() assert all([r == b'pong' for r in resps]) assert client.bsock.getrecvbuffer() == b'' dur = e - s print("netstring pipelined ping-pong latency", dur, "ms") # tell the server to close the socket and then try a failure case client.write_ns(b'close') try: client.read_ns() raise Exception('read from closed socket') except ConnectionClosed: print("raised ConnectionClosed correctly") # test big messages client = client_connect() client.setmaxsize(128 * 1024) client.write_ns(b'reply128k') res = client.read_ns() assert len(res) == (128 * 1024) client.write_ns(b'close') # test that read timeouts work client = client_connect() client.settimeout(0.1) try: client.read_ns() raise Exception('did not timeout') except Timeout: print("read_ns raised timeout correctly") client.write_ns(b'close') # test that netstring max sizes work client = client_connect() client.setmaxsize(2048) client.write_ns(b'reply4k') try: client.read_ns() raise Exception('read more than maxsize') except NetstringMessageTooLong: print("raised MessageTooLong correctly") try: client.bsock.recv_until(b'b', maxsize=4096) raise Exception('recv_until did not raise MessageTooLong') except MessageTooLong: print("raised MessageTooLong correctly") assert client.bsock.recv_size(4097) == b'a' * 4096 + b',' print('correctly maintained buffer after exception raised') # test BufferedSocket read timeouts with recv_until and recv_size client.bsock.settimeout(0.01) try: client.bsock.recv_until(b'a') raise Exception('recv_until did not raise Timeout') except Timeout: print('recv_until correctly raised Timeout') try: client.bsock.recv_size(1) raise Exception('recv_size did not raise Timeout') except Timeout: print('recv_size correctly raised Timeout') client.write_ns(b'shutdown') print("all passed")