diff --git a/Lib/httplib.py b/Lib/httplib.py index 8238f1ae027..1db930ea932 100644 --- a/Lib/httplib.py +++ b/Lib/httplib.py @@ -754,13 +754,59 @@ def getresponse(self): return response -class SSLFile: +# The next several classes are used to define FakeSocket,a socket-like +# interface to an SSL connection. + +# The primary complexity comes from faking a makefile() method. The +# standard socket makefile() implementation calls dup() on the socket +# file descriptor. As a consequence, clients can call close() on the +# parent socket and its makefile children in any order. The underlying +# socket isn't closed until they are all closed. + +# The implementation uses reference counting to keep the socket open +# until the last client calls close(). SharedSocket keeps track of +# the reference counting and SharedSocketClient provides an constructor +# and close() method that call incref() and decref() correctly. + +class SharedSocket: + + def __init__(self, sock): + self.sock = sock + self._refcnt = 0 + + def incref(self): + self._refcnt += 1 + + def decref(self): + self._refcnt -= 1 + assert self._refcnt >= 0 + if self._refcnt == 0: + self.sock.close() + + def __del__(self): + self.sock.close() + +class SharedSocketClient: + + def __init__(self, shared): + self._closed = 0 + self._shared = shared + self._shared.incref() + self._sock = shared.sock + + def close(self): + if not self._closed: + self._shared.decref() + self._closed = 1 + self._shared = None + +class SSLFile(SharedSocketClient): """File-like object wrapping an SSL socket.""" BUFSIZE = 8192 def __init__(self, sock, ssl, bufsize=None): - self._sock = sock + SharedSocketClient.__init__(self, sock) self._ssl = ssl self._buf = '' self._bufsize = bufsize or self.__class__.BUFSIZE @@ -829,30 +875,36 @@ def readline(self): self._buf = all[i:] return line - def close(self): - self._sock.close() +class FakeSocket(SharedSocketClient): + + class _closedsocket: + def __getattr__(self, name): + raise error(9, 'Bad file descriptor') -class FakeSocket: def __init__(self, sock, ssl): - self.__sock = sock - self.__ssl = ssl + sock = SharedSocket(sock) + SharedSocketClient.__init__(self, sock) + self._ssl = ssl + + def close(self): + SharedSocketClient.close(self) + self._sock = self.__class__._closedsocket() def makefile(self, mode, bufsize=None): if mode != 'r' and mode != 'rb': raise UnimplementedFileMode() - return SSLFile(self.__sock, self.__ssl, bufsize) + return SSLFile(self._shared, self._ssl, bufsize) def send(self, stuff, flags = 0): - return self.__ssl.write(stuff) + return self._ssl.write(stuff) - def sendall(self, stuff, flags = 0): - return self.__ssl.write(stuff) + sendall = send def recv(self, len = 1024, flags = 0): - return self.__ssl.read(len) + return self._ssl.read(len) def __getattr__(self, attr): - return getattr(self.__sock, attr) + return getattr(self._sock, attr) class HTTPSConnection(HTTPConnection): @@ -1101,15 +1153,11 @@ def readlines(self, size=None): else: return L + self._file.readlines(size) -# -# snarfed from httplib.py for now... -# def test(): """Test this module. - The test consists of retrieving and displaying the Python - home page, along with the error code and error string returned - by the www.python.org server. + A hodge podge of tests collected here, because they have too many + external dependencies for the regular test suite. """ import sys @@ -1130,11 +1178,11 @@ def test(): status, reason, headers = h.getreply() print 'status =', status print 'reason =', reason + print "read", len(h.getfile().read()) print if headers: for header in headers.headers: print header.strip() print - print "read", len(h.getfile().read()) # minimal test that code to extract host from url works class HTTP11(HTTP): @@ -1148,22 +1196,26 @@ class HTTP11(HTTP): h.close() if hasattr(socket, 'ssl'): - host = 'sourceforge.net' - selector = '/projects/python' - hs = HTTPS() - hs.connect(host) - hs.putrequest('GET', selector) - hs.endheaders() - status, reason, headers = hs.getreply() - # XXX why does this give a 302 response? - print 'status =', status - print 'reason =', reason - print - if headers: - for header in headers.headers: print header.strip() - print - print "read", len(hs.getfile().read()) + + for host, selector in (('sourceforge.net', '/projects/python'), + ('dbserv2.theopalgroup.com', '/mediumfile'), + ('dbserv2.theopalgroup.com', '/smallfile'), + ): + print "https://%s%s" % (host, selector) + hs = HTTPS() + hs.connect(host) + hs.putrequest('GET', selector) + hs.endheaders() + status, reason, headers = hs.getreply() + print 'status =', status + print 'reason =', reason + print "read", len(hs.getfile().read()) + print + if headers: + for header in headers.headers: print header.strip() + print + return # Test a buggy server -- returns garbled status line. # http://www.yahoo.com/promotions/mom_com97/supermom.html