commit b558997fd9db8406b2a24a1831d06e283dbf35a6 Author: Aldo Cortesi Date: Tue Jun 19 09:42:32 2012 +1200 Initial checkin. diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..99f57cb0f --- /dev/null +++ b/.coveragerc @@ -0,0 +1,2 @@ +[report] +include = *netlib* diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..f53cd2e25 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +MANIFEST +/build +/dist +/tmp +/doc +*.py[cdo] +*.swp +*.swo +.coverage diff --git a/README b/README new file mode 100644 index 000000000..1c86738cc --- /dev/null +++ b/README @@ -0,0 +1,2 @@ +Netlib is a collection of common utility functions, used by the pathod and +mitmproxy projects. diff --git a/netlib/__init__.py b/netlib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/netlib/odict.py b/netlib/odict.py new file mode 100644 index 000000000..afc33caa6 --- /dev/null +++ b/netlib/odict.py @@ -0,0 +1,160 @@ +import re, copy + +def safe_subn(pattern, repl, target, *args, **kwargs): + """ + There are Unicode conversion problems with re.subn. We try to smooth + that over by casting the pattern and replacement to strings. We really + need a better solution that is aware of the actual content ecoding. + """ + return re.subn(str(pattern), str(repl), target, *args, **kwargs) + + +class ODict: + """ + A dictionary-like object for managing ordered (key, value) data. + """ + def __init__(self, lst=None): + self.lst = lst or [] + + def _kconv(self, s): + return s + + def __eq__(self, other): + return self.lst == other.lst + + def __getitem__(self, k): + """ + Returns a list of values matching key. + """ + ret = [] + k = self._kconv(k) + for i in self.lst: + if self._kconv(i[0]) == k: + ret.append(i[1]) + return ret + + def _filter_lst(self, k, lst): + k = self._kconv(k) + new = [] + for i in lst: + if self._kconv(i[0]) != k: + new.append(i) + return new + + def __len__(self): + """ + Total number of (key, value) pairs. + """ + return len(self.lst) + + def __setitem__(self, k, valuelist): + """ + Sets the values for key k. If there are existing values for this + key, they are cleared. + """ + if isinstance(valuelist, basestring): + raise ValueError("ODict valuelist should be lists.") + new = self._filter_lst(k, self.lst) + for i in valuelist: + new.append([k, i]) + self.lst = new + + def __delitem__(self, k): + """ + Delete all items matching k. + """ + self.lst = self._filter_lst(k, self.lst) + + def __contains__(self, k): + for i in self.lst: + if self._kconv(i[0]) == self._kconv(k): + return True + return False + + def add(self, key, value): + self.lst.append([key, str(value)]) + + def get(self, k, d=None): + if k in self: + return self[k] + else: + return d + + def items(self): + return self.lst[:] + + def _get_state(self): + return [tuple(i) for i in self.lst] + + @classmethod + def _from_state(klass, state): + return klass([list(i) for i in state]) + + def copy(self): + """ + Returns a copy of this object. + """ + lst = copy.deepcopy(self.lst) + return self.__class__(lst) + + def __repr__(self): + elements = [] + for itm in self.lst: + elements.append(itm[0] + ": " + itm[1]) + elements.append("") + return "\r\n".join(elements) + + def in_any(self, key, value, caseless=False): + """ + Do any of the values matching key contain value? + + If caseless is true, value comparison is case-insensitive. + """ + if caseless: + value = value.lower() + for i in self[key]: + if caseless: + i = i.lower() + if value in i: + return True + return False + + def match_re(self, expr): + """ + Match the regular expression against each (key, value) pair. For + each pair a string of the following format is matched against: + + "key: value" + """ + for k, v in self.lst: + s = "%s: %s"%(k, v) + if re.search(expr, s): + return True + return False + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both keys and + values. Encoded content will be decoded before replacement, and + re-encoded afterwards. + + Returns the number of replacements made. + """ + nlst, count = [], 0 + for i in self.lst: + k, c = safe_subn(pattern, repl, i[0], *args, **kwargs) + count += c + v, c = safe_subn(pattern, repl, i[1], *args, **kwargs) + count += c + nlst.append([k, v]) + self.lst = nlst + return count + + +class ODictCaseless(ODict): + """ + A variant of ODict with "caseless" keys. This version _preserves_ key + case, but does not consider case when setting or getting items. + """ + def _kconv(self, s): + return s.lower() diff --git a/netlib/protocol.py b/netlib/protocol.py new file mode 100644 index 000000000..55bcf4405 --- /dev/null +++ b/netlib/protocol.py @@ -0,0 +1,218 @@ +import string, urlparse + +class ProtocolError(Exception): + def __init__(self, code, msg): + self.code, self.msg = code, msg + + def __str__(self): + return "ProtocolError(%s, %s)"%(self.code, self.msg) + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + """ + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + if not scheme: + return None + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + return scheme, host, port, path + + +def read_headers(fp): + """ + Read a set of headers from a file pointer. Stop once a blank line + is reached. Return a ODictCaseless object. + """ + ret = [] + name = '' + while 1: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i+1:].strip() + ret.append([name, value]) + return ret + + +def read_chunked(fp, limit): + content = "" + total = 0 + while 1: + line = fp.readline(128) + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + continue + try: + length = int(line,16) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise ProtocolError(400, "Invalid chunked encoding length: %s"%line) + if not length: + break + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large."\ + " Limit is %s, chunked content length was at least %s"%(limit, total) + raise ProtocolError(509, msg) + content += fp.read(length) + line = fp.readline(5) + if line != '\r\n': + raise IOError("Malformed chunked body") + while 1: + line = fp.readline() + if line == "": + raise IOError("Connection closed") + if line == '\r\n' or line == '\n': + break + return content + + +def has_chunked_encoding(headers): + for i in headers["transfer-encoding"]: + for j in i.split(","): + if j.lower() == "chunked": + return True + return False + + +def read_http_body(rfile, headers, all, limit): + if has_chunked_encoding(headers): + content = read_chunked(rfile, limit) + elif "content-length" in headers: + try: + l = int(headers["content-length"][0]) + except ValueError: + # FIXME: Not strictly correct - this could be from the server, in which + # case we should send a 502. + raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"]) + if limit is not None and l > limit: + raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + content = rfile.read(l) + elif all: + content = rfile.read(limit if limit else None) + else: + content = "" + return content + + +def parse_http_protocol(s): + if not s.startswith("HTTP/"): + return None + major, minor = s.split('/')[1].split('.') + major = int(major) + minor = int(minor) + return major, minor + + +def parse_init_connect(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if method != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + port = int(port) + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return host, port, httpversion + + +def parse_init_proxy(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + parts = parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, scheme, host, port, path, httpversion + + +def parse_init_http(line): + """ + Returns (method, url, httpversion) + """ + try: + method, url, protocol = string.split(line) + except ValueError: + return None + if not (url.startswith("/") or url == "*"): + return None + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + return method, url, httpversion + + +def request_connection_close(httpversion, headers): + """ + Checks the request to see if the client connection should be closed. + """ + if "connection" in headers: + for value in ",".join(headers['connection']).split(","): + value = value.strip() + if value == "close": + return True + elif value == "keep-alive": + return False + # HTTP 1.1 connections are assumed to be persistent + if httpversion == (1, 1): + return False + return True + + +def response_connection_close(httpversion, headers): + """ + Checks the response to see if the client connection should be closed. + """ + if request_connection_close(httpversion, headers): + return True + elif not has_chunked_encoding(headers) and "content-length" in headers: + return True + return False + + +def read_http_body_request(rfile, wfile, headers, httpversion, limit): + if "expect" in headers: + # FIXME: Should be forwarded upstream + expect = ",".join(headers['expect']) + if expect == "100-continue" and httpversion >= (1, 1): + wfile.write('HTTP/1.1 100 Continue\r\n') + wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) + wfile.write('\r\n') + del headers['expect'] + return read_http_body(rfile, headers, False, limit) diff --git a/netlib/tcp.py b/netlib/tcp.py new file mode 100644 index 000000000..08ccba091 --- /dev/null +++ b/netlib/tcp.py @@ -0,0 +1,182 @@ +import select, socket, threading, traceback, sys +from OpenSSL import SSL + + +class NetLibError(Exception): pass + + +class FileLike: + def __init__(self, o): + self.o = o + + def __getattr__(self, attr): + return getattr(self.o, attr) + + def flush(self): + pass + + def read(self, length): + result = '' + while len(result) < length: + try: + data = self.o.read(length) + except SSL.ZeroReturnError: + break + if not data: + break + result += data + return result + + def write(self, v): + self.o.sendall(v) + + def readline(self, size = None): + result = '' + bytes_read = 0 + while True: + if size is not None and bytes_read >= size: + break + ch = self.read(1) + bytes_read += 1 + if not ch: + break + else: + result += ch + if ch == '\n': + break + return result + + +class TCPClient: + def __init__(self, ssl, host, port, clientcert): + self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert + self.connection, self.rfile, self.wfile = None, None, None + self.cert = None + self.connect() + + def connect(self): + try: + addr = socket.gethostbyname(self.host) + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if self.ssl: + context = SSL.Context(SSL.SSLv23_METHOD) + if self.clientcert: + context.use_certificate_file(self.clientcert) + server = SSL.Connection(context, server) + server.connect((addr, self.port)) + if self.ssl: + self.cert = server.get_peer_certificate() + self.rfile, self.wfile = FileLike(server), FileLike(server) + else: + self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb') + except socket.error, err: + raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) + self.connection = server + + +class BaseHandler: + rbufsize = -1 + wbufsize = 0 + def __init__(self, connection, client_address, server): + self.connection = connection + self.rfile = self.connection.makefile('rb', self.rbufsize) + self.wfile = self.connection.makefile('wb', self.wbufsize) + + self.client_address = client_address + self.server = server + self.handle() + self.finish() + + def convert_to_ssl(self, cert, key): + ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.use_privatekey_file(key) + ctx.use_certificate_file(cert) + self.connection = SSL.Connection(ctx, self.connection) + self.connection.set_accept_state() + self.rfile = FileLike(self.connection) + self.wfile = FileLike(self.connection) + + def finish(self): + try: + if not getattr(self.wfile, "closed", False): + self.wfile.flush() + self.connection.close() + self.wfile.close() + self.rfile.close() + except IOError: # pragma: no cover + pass + + def handle(self): # pragma: no cover + raise NotImplementedError + + +class TCPServer: + request_queue_size = 20 + def __init__(self, server_address): + self.server_address = server_address + self.__is_shut_down = threading.Event() + self.__shutdown_request = False + self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.socket.bind(self.server_address) + self.server_address = self.socket.getsockname() + self.socket.listen(self.request_queue_size) + self.port = self.socket.getsockname()[1] + + def request_thread(self, request, client_address): + try: + self.handle_connection(request, client_address) + request.close() + except: + self.handle_error(request, client_address) + request.close() + + def serve_forever(self, poll_interval=0.5): + self.__is_shut_down.clear() + try: + while not self.__shutdown_request: + r, w, e = select.select([self.socket], [], [], poll_interval) + if self.socket in r: + try: + request, client_address = self.socket.accept() + except socket.error: + return + try: + t = threading.Thread( + target = self.request_thread, + args = (request, client_address) + ) + t.setDaemon(1) + t.start() + except: + self.handle_error(request, client_address) + request.close() + finally: + self.__shutdown_request = False + self.__is_shut_down.set() + + def shutdown(self): + self.__shutdown_request = True + self.__is_shut_down.wait() + self.handle_shutdown() + + def handle_error(self, request, client_address, fp=sys.stderr): + """ + Called when handle_connection raises an exception. + """ + print >> fp, '-'*40 + print >> fp, "Error processing of request from %s:%s"%client_address + print >> fp, traceback.format_exc() + print >> fp, '-'*40 + + def handle_connection(self, request, client_address): # pragma: no cover + """ + Called after client connection. + """ + raise NotImplementedError + + def handle_shutdown(self): + """ + Called after server shutdown. + """ + pass diff --git a/test/test_odict.py b/test/test_odict.py new file mode 100644 index 000000000..e7453e2dd --- /dev/null +++ b/test/test_odict.py @@ -0,0 +1,113 @@ +from netlib import odict +import tutils + + +class TestODict: + def setUp(self): + self.od = odict.ODict() + + def test_str_err(self): + h = odict.ODict() + tutils.raises(ValueError, h.__setitem__, "key", "foo") + + def test_dictToHeader1(self): + self.od.add("one", "uno") + self.od.add("two", "due") + self.od.add("two", "tre") + expected = [ + "one: uno\r\n", + "two: due\r\n", + "two: tre\r\n", + "\r\n" + ] + out = repr(self.od) + for i in expected: + assert out.find(i) >= 0 + + def test_dictToHeader2(self): + self.od["one"] = ["uno"] + expected1 = "one: uno\r\n" + expected2 = "\r\n" + out = repr(self.od) + assert out.find(expected1) >= 0 + assert out.find(expected2) >= 0 + + def test_match_re(self): + h = odict.ODict() + h.add("one", "uno") + h.add("two", "due") + h.add("two", "tre") + assert h.match_re("uno") + assert h.match_re("two: due") + assert not h.match_re("nonono") + + def test_getset_state(self): + self.od.add("foo", 1) + self.od.add("foo", 2) + self.od.add("bar", 3) + state = self.od._get_state() + nd = odict.ODict._from_state(state) + assert nd == self.od + + def test_in_any(self): + self.od["one"] = ["atwoa", "athreea"] + assert self.od.in_any("one", "two") + assert self.od.in_any("one", "three") + assert not self.od.in_any("one", "four") + assert not self.od.in_any("nonexistent", "foo") + assert not self.od.in_any("one", "TWO") + assert self.od.in_any("one", "TWO", True) + + def test_copy(self): + self.od.add("foo", 1) + self.od.add("foo", 2) + self.od.add("bar", 3) + assert self.od == self.od.copy() + + def test_del(self): + self.od.add("foo", 1) + self.od.add("Foo", 2) + self.od.add("bar", 3) + del self.od["foo"] + assert len(self.od.lst) == 2 + + def test_replace(self): + self.od.add("one", "two") + self.od.add("two", "one") + assert self.od.replace("one", "vun") == 2 + assert self.od.lst == [ + ["vun", "two"], + ["two", "vun"], + ] + + def test_get(self): + self.od.add("one", "two") + assert self.od.get("one") == ["two"] + assert self.od.get("two") == None + + +class TestODictCaseless: + def setUp(self): + self.od = odict.ODictCaseless() + + def test_override(self): + o = odict.ODictCaseless() + o.add('T', 'application/x-www-form-urlencoded; charset=UTF-8') + o["T"] = ["foo"] + assert o["T"] == ["foo"] + + def test_case_preservation(self): + self.od["Foo"] = ["1"] + assert "foo" in self.od + assert self.od.items()[0][0] == "Foo" + assert self.od.get("foo") == ["1"] + assert self.od.get("foo", [""]) == ["1"] + assert self.od.get("Foo", [""]) == ["1"] + assert self.od.get("xx", "yy") == "yy" + + def test_del(self): + self.od.add("foo", 1) + self.od.add("Foo", 2) + self.od.add("bar", 3) + del self.od["foo"] + assert len(self.od) == 1 diff --git a/test/test_protocol.py b/test/test_protocol.py new file mode 100644 index 000000000..028faadd3 --- /dev/null +++ b/test/test_protocol.py @@ -0,0 +1,163 @@ +import cStringIO, textwrap +from netlib import protocol, odict +import tutils + +def test_has_chunked_encoding(): + h = odict.ODictCaseless() + assert not protocol.has_chunked_encoding(h) + h["transfer-encoding"] = ["chunked"] + assert protocol.has_chunked_encoding(h) + + +def test_read_chunked(): + s = cStringIO.StringIO("1\r\na\r\n0\r\n") + tutils.raises(IOError, protocol.read_chunked, s, None) + + s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") + assert protocol.read_chunked(s, None) == "a" + + s = cStringIO.StringIO("\r\n") + tutils.raises(IOError, protocol.read_chunked, s, None) + + s = cStringIO.StringIO("1\r\nfoo") + tutils.raises(IOError, protocol.read_chunked, s, None) + + s = cStringIO.StringIO("foo\r\nfoo") + tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None) + + +def test_request_connection_close(): + h = odict.ODictCaseless() + assert protocol.request_connection_close((1, 0), h) + assert not protocol.request_connection_close((1, 1), h) + + h["connection"] = ["keep-alive"] + assert not protocol.request_connection_close((1, 1), h) + + +def test_read_http_body(): + h = odict.ODict() + s = cStringIO.StringIO("testing") + assert protocol.read_http_body(s, h, False, None) == "" + + h["content-length"] = ["foo"] + s = cStringIO.StringIO("testing") + tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None) + + h["content-length"] = [5] + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, False, None)) == 5 + s = cStringIO.StringIO("testing") + tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4) + + h = odict.ODict() + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, True, 4)) == 4 + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, True, 100)) == 7 + +def test_parse_http_protocol(): + assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) + assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not protocol.parse_http_protocol("foo/0.0") + + +def test_parse_init_connect(): + assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("bogus") + assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") + + +def test_prase_init_proxy(): + u = "GET http://foo.com:8888/test HTTP/1.1" + m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) + assert m == "GET" + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + assert httpversion == (1, 1) + + assert not protocol.parse_init_proxy("invalid") + assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") + assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + + +def test_parse_init_http(): + u = "GET /test HTTP/1.1" + m, u, httpversion= protocol.parse_init_http(u) + assert m == "GET" + assert u == "/test" + assert httpversion == (1, 1) + + assert not protocol.parse_init_http("invalid") + assert not protocol.parse_init_http("GET invalid HTTP/1.1") + assert not protocol.parse_init_http("GET /test foo/1.1") + + +class TestReadHeaders: + def test_read_simple(self): + data = """ + Header: one + Header2: two + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = protocol.read_headers(s) + assert h == [["Header", "one"], ["Header2", "two"]] + + def test_read_multi(self): + data = """ + Header: one + Header: two + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = protocol.read_headers(s) + assert h == [["Header", "one"], ["Header", "two"]] + + def test_read_continued(self): + data = """ + Header: one + \ttwo + Header2: three + \r\n + """ + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + h = protocol.read_headers(s) + assert h == [["Header", "one\r\n two"], ["Header2", "three"]] + + +def test_parse_url(): + assert not protocol.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = protocol.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = protocol.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = protocol.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = protocol.parse_url("https://foo") + assert po == 443 + + assert not protocol.parse_url("https://foo:bar") + assert not protocol.parse_url("https://foo:") + diff --git a/test/test_tcp.py b/test/test_tcp.py new file mode 100644 index 000000000..d7d4483e8 --- /dev/null +++ b/test/test_tcp.py @@ -0,0 +1,93 @@ +import cStringIO, threading, Queue +from netlib import tcp +import tutils + +class ServerThread(threading.Thread): + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase: + @classmethod + def setupAll(cls): + cls.server = ServerThread(cls.makeserver()) + cls.server.start() + + @classmethod + def teardownAll(cls): + cls.server.shutdown() + + +class THandler(tcp.BaseHandler): + def handle(self): + v = self.rfile.readline() + if v.startswith("echo"): + self.wfile.write(v) + elif v.startswith("error"): + raise ValueError("Testing an error.") + self.wfile.flush() + + +class TServer(tcp.TCPServer): + def __init__(self, addr, q): + tcp.TCPServer.__init__(self, addr) + self.q = q + + def handle_connection(self, request, client_address): + THandler(request, client_address, self) + + def handle_error(self, request, client_address): + s = cStringIO.StringIO() + tcp.TCPServer.handle_error(self, request, client_address, s) + self.q.put(s.getvalue()) + + +class TestServer(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), cls.q) + cls.port = s.port + return s + + def test_echo(self): + testval = "echo!\n" + c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c.wfile.write(testval) + c.wfile.flush() + assert c.rfile.readline() == testval + + def test_error(self): + testval = "error!\n" + c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c.wfile.write(testval) + c.wfile.flush() + assert "Testing an error" in self.q.get() + + +class TestTCPClient: + def test_conerr(self): + tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None) + + +class TestFileLike: + def test_wrap(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.FileLike(s) + s.flush() + assert s.readline() == "foobar\n" + assert s.readline() == "foobar" + # Test __getattr__ + assert s.isatty + + def test_limit(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.FileLike(s) + assert s.readline(3) == "foo" diff --git a/test/tutils.py b/test/tutils.py new file mode 100644 index 000000000..c8e06b962 --- /dev/null +++ b/test/tutils.py @@ -0,0 +1,56 @@ +import tempfile, os, shutil +from contextlib import contextmanager +from libpathod import utils + + +@contextmanager +def tmpdir(*args, **kwargs): + orig_workdir = os.getcwd() + temp_workdir = tempfile.mkdtemp(*args, **kwargs) + os.chdir(temp_workdir) + + yield temp_workdir + + os.chdir(orig_workdir) + shutil.rmtree(temp_workdir) + + +def raises(exc, obj, *args, **kwargs): + """ + Assert that a callable raises a specified exception. + + :exc An exception class or a string. If a class, assert that an + exception of this type is raised. If a string, assert that the string + occurs in the string representation of the exception, based on a + case-insenstivie match. + + :obj A callable object. + + :args Arguments to be passsed to the callable. + + :kwargs Arguments to be passed to the callable. + """ + try: + apply(obj, args, kwargs) + except Exception, v: + if isinstance(exc, basestring): + if exc.lower() in str(v).lower(): + return + else: + raise AssertionError( + "Expected %s, but caught %s"%( + repr(str(exc)), v + ) + ) + else: + if isinstance(v, exc): + return + else: + raise AssertionError( + "Expected %s, but caught %s %s"%( + exc.__name__, v.__class__.__name__, str(v) + ) + ) + raise AssertionError("No exception raised.") + +test_data = utils.Data(__name__)