diff --git a/netlib/tcp.py b/netlib/tcp.py index 0fed7380d..e1318435b 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -39,10 +39,11 @@ class NetLibDisconnect(Exception): pass class NetLibTimeout(Exception): pass -class FileLike: +class _FileLike: BLOCKSIZE = 1024 * 32 def __init__(self, o): self.o = o + self._log = None def set_descriptor(self, o): self.o = o @@ -50,6 +51,37 @@ class FileLike: def __getattr__(self, attr): return getattr(self.o, attr) + def start_log(self): + """ + Starts or resets the log. + + This will store all bytes read or written. + """ + self._log = [] + + def stop_log(self): + """ + Stops the log. + """ + self._log = None + + def is_logging(self): + return self._log is not None + + def get_log(self): + """ + Returns the log as a string. + """ + if not self.is_logging(): + raise ValueError("Not logging!") + return "".join(self._log) + + def add_log(self, v): + if self.is_logging(): + self._log.append(v) + + +class Writer(_FileLike): def flush(self): try: if hasattr(self.o, "flush"): @@ -57,6 +89,21 @@ class FileLike: except socket.error, v: raise NetLibDisconnect(str(v)) + def write(self, v): + if v: + try: + if hasattr(self.o, "sendall"): + self.add_log(v) + return self.o.sendall(v) + else: + r = self.o.write(v) + self.add_log(v[:r]) + return r + except (SSL.Error, socket.error): + raise NetLibDisconnect() + + +class Reader(_FileLike): def read(self, length): """ If length is None, we read until connection closes. @@ -85,19 +132,9 @@ class FileLike: result += data if length != -1: length -= len(data) + self.add_log(result) return result - def write(self, v): - if v: - try: - if hasattr(self.o, "sendall"): - return self.o.sendall(v) - else: - r = self.o.write(v) - return r - except (SSL.Error, socket.error): - raise NetLibDisconnect() - def readline(self, size = None): result = '' bytes_read = 0 @@ -151,8 +188,8 @@ class TCPClient: addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect((addr, self.port)) - self.rfile = FileLike(connection.makefile('rb', self.rbufsize)) - self.wfile = FileLike(connection.makefile('wb', self.wbufsize)) + self.rfile = Reader(connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except socket.error, err: raise NetLibError('Error connecting to "%s": %s' % (self.host, err)) self.connection = connection @@ -186,8 +223,8 @@ class BaseHandler: wbufsize = -1 def __init__(self, connection, client_address, server): self.connection = connection - self.rfile = FileLike(self.connection.makefile('rb', self.rbufsize)) - self.wfile = FileLike(self.connection.makefile('wb', self.wbufsize)) + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) self.client_address = client_address self.server = server diff --git a/test/test_tcp.py b/test/test_tcp.py index 67c56a374..9d581939c 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -228,8 +228,8 @@ class TestTCPClient: class TestFileLike: def test_wrap(self): s = cStringIO.StringIO("foobar\nfoobar") - s = tcp.FileLike(s) s.flush() + s = tcp.Reader(s) assert s.readline() == "foobar\n" assert s.readline() == "foobar" # Test __getattr__ @@ -237,11 +237,39 @@ class TestFileLike: def test_limit(self): s = cStringIO.StringIO("foobar\nfoobar") - s = tcp.FileLike(s) + s = tcp.Reader(s) assert s.readline(3) == "foo" def test_limitless(self): s = cStringIO.StringIO("f"*(50*1024)) - s = tcp.FileLike(s) + s = tcp.Reader(s) ret = s.read(-1) assert len(ret) == 50 * 1024 + + def test_readlog(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.Reader(s) + assert not s.is_logging() + s.start_log() + assert s.is_logging() + s.readline() + assert s.get_log() == "foobar\n" + s.read(1) + assert s.get_log() == "foobar\nf" + s.start_log() + assert s.get_log() == "" + s.read(1) + assert s.get_log() == "o" + s.stop_log() + tutils.raises(ValueError, s.get_log) + + def test_writelog(self): + s = cStringIO.StringIO() + s = tcp.Writer(s) + s.start_log() + assert s.is_logging() + s.write("x") + assert s.get_log() == "x" + s.write("x") + assert s.get_log() == "xx" +