diff --git a/netlib/certutils.py b/netlib/certutils.py index 6c9a5c57b..180e1ac07 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -2,6 +2,7 @@ import os, ssl, hashlib, socket, time, datetime from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode import OpenSSL +import tcp CERT_SLEEP_TIME = 1 CERT_EXPIRY = str(365 * 3) @@ -218,7 +219,8 @@ class SSLCert: return altnames -def get_remote_cert(host, port): # pragma: no cover - addr = socket.gethostbyname(host) - s = ssl.get_server_certificate((addr, port)) - return SSLCert(s) +def get_remote_cert(host, port, sni): + c = tcp.TCPClient(host, port) + c.connect() + c.convert_to_ssl(sni=sni) + return c.cert diff --git a/netlib/tcp.py b/netlib/tcp.py index ef3298d59..6c5b49769 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,5 +1,6 @@ import select, socket, threading, traceback, sys from OpenSSL import SSL +import certutils class NetLibError(Exception): pass diff --git a/test/test_tcp.py b/test/test_tcp.py index a2ee5e368..969daf1ec 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,5 +1,5 @@ import cStringIO, threading, Queue -from netlib import tcp +from netlib import tcp, certutils import tutils class ServerThread(threading.Thread): @@ -110,6 +110,9 @@ class TestServerSSL(ServerTestBase): c.wfile.flush() assert c.rfile.readline() == testval + def test_get_remote_cert(self): + assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") + class TestSNI(ServerTestBase): @classmethod