diff --git a/netlib/tcp.py b/netlib/tcp.py index 56cc0dead..a79f3ac43 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -87,6 +87,9 @@ class _FileLike: class Writer(_FileLike): def flush(self): + """ + May raise NetLibDisconnect + """ if hasattr(self.o, "flush"): try: self.o.flush() @@ -94,6 +97,9 @@ class Writer(_FileLike): raise NetLibDisconnect(str(v)) def write(self, v): + """ + May raise NetLibDisconnect + """ if v: try: if hasattr(self.o, "sendall"): diff --git a/test/test_tcp.py b/test/test_tcp.py index ad09143d5..e7524fdc0 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,5 +1,6 @@ -import cStringIO, threading, Queue, time +import cStringIO, threading, Queue, time, socket from netlib import tcp, certutils, test +import mock import tutils class SNIHandler(tcp.BaseHandler): @@ -275,6 +276,22 @@ class TestFileLike: s.write("x") assert s.get_log() == "xx" + def test_writer_flush_error(self): + s = cStringIO.StringIO() + s = tcp.Writer(s) + o = mock.MagicMock() + o.flush = mock.MagicMock(side_effect=socket.error) + s.o = o + tutils.raises(tcp.NetLibDisconnect, s.flush) + + def test_reader_read_error(self): + s = cStringIO.StringIO("foobar\nfoobar") + s = tcp.Reader(s) + o = mock.MagicMock() + o.read = mock.MagicMock(side_effect=socket.error) + s.o = o + tutils.raises(tcp.NetLibDisconnect, s.read, 10) + def test_reset_timestamps(self): s = cStringIO.StringIO("foobar\nfoobar") s = tcp.Reader(s)