diff --git a/netlib/certutils.py b/netlib/certutils.py index b9c291d09..fafcb5fd3 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -4,6 +4,7 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL import tcp +import UserDict DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 # Generated with "openssl dhparam". It's too slow to generate this on startup. @@ -42,11 +43,11 @@ def create_ca(o, cn, exp): return key, cert -def dummy_cert(pkey, cacert, commonname, sans): +def dummy_cert(privkey, cacert, commonname, sans): """ Generates a dummy certificate. - pkey: CA private key + privkey: CA private key cacert: CA certificate commonname: Common name for the generated certificate. sans: A list of Subject Alternate Names. @@ -68,17 +69,55 @@ def dummy_cert(pkey, cacert, commonname, sans): cert.set_version(2) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.set_pubkey(cacert.get_pubkey()) - cert.sign(pkey, "sha1") + cert.sign(privkey, "sha1") return SSLCert(cert) +class _Node(UserDict.UserDict): + def __init__(self): + UserDict.UserDict.__init__(self) + self.value = None + + +class DNTree: + """ + Domain store that knows about wildcards. DNS wildcards are very + restricted - the only valid variety is an asterisk on the left-most + domain component, i.e.: + + *.foo.com + """ + def __init__(self): + self.d = _Node() + + def add(self, dn, cert): + parts = dn.split(".") + parts.reverse() + current = self.d + for i in parts: + current = current.setdefault(i, _Node()) + current.value = cert + + def get(self, dn): + parts = dn.split(".") + current = self.d + for i in reversed(parts): + if i in current: + current = current[i] + elif "*" in current: + return current["*"].value + else: + return None + return current.value + + class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, pkey, cert): - self.pkey, self.cert = pkey, cert - self.certs = {} + def __init__(self, privkey, cacert): + self.privkey, self.cacert = privkey, cacert + self.certs = DNTree() @classmethod def from_store(klass, path, basename): @@ -130,9 +169,29 @@ class CertStore: f.close() return key, ca + def add_cert_file(self, commonname, path): + raw = file(path, "rb").read() + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) + try: + privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + except Exception: + privkey = None + self.add_cert(SSLCert(cert), privkey, commonname) + + def add_cert(self, cert, privkey, *names): + """ + Adds a cert to the certstore. We register the CN in the cert plus + any SANs, and also the list of names provided as an argument. + """ + self.certs.add(cert.cn, (cert, privkey)) + for i in cert.altnames: + self.certs.add(i, (cert, privkey)) + for i in names: + self.certs.add(i, (cert, privkey)) + def get_cert(self, commonname, sans): """ - Returns an SSLCert object. + Returns an (cert, privkey) tuple. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. @@ -141,11 +200,12 @@ class CertStore: Return None if the certificate could not be found or generated. """ - if commonname in self.certs: - return self.certs[commonname] - c = dummy_cert(self.pkey, self.cert, commonname, sans) - self.certs[commonname] = c - return c + c = self.certs.get(commonname) + if not c: + c = dummy_cert(self.privkey, self.cacert, commonname, sans) + self.add_cert(c, None) + c = (c, None) + return (c[0], c[1] or self.privkey) class _GeneralName(univ.Choice): @@ -171,6 +231,9 @@ class SSLCert: """ self.x509 = cert + def __eq__(self, other): + return self.digest("sha1") == other.digest("sha1") + @classmethod def from_pem(klass, txt): x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) diff --git a/test/test_certutils.py b/test/test_certutils.py index f741bdece..7f320e7ef 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -1,7 +1,37 @@ import os from netlib import certutils +import OpenSSL import tutils +class TestDNTree: + def test_simple(self): + d = certutils.DNTree() + d.add("foo.com", "foo") + d.add("bar.com", "bar") + assert d.get("foo.com") == "foo" + assert d.get("bar.com") == "bar" + assert not d.get("oink.com") + assert not d.get("oink") + assert not d.get("") + assert not d.get("oink.oink") + + d.add("*.match.org", "match") + assert not d.get("match.org") + assert d.get("foo.match.org") == "match" + assert d.get("foo.foo.match.org") == "match" + + def test_wildcard(self): + d = certutils.DNTree() + d.add("foo.com", "foo") + assert not d.get("*.foo.com") + d.add("*.foo.com", "wild") + + d = certutils.DNTree() + d.add("*", "foo") + assert d.get("foo.com") == "foo" + assert d.get("*.foo.com") == "foo" + assert d.get("com") == "foo" + class TestCertStore: def test_create_explicit(self): @@ -12,7 +42,7 @@ class TestCertStore: ca2 = certutils.CertStore.from_store(d, "test") assert ca2.get_cert("foo", []) - assert ca.cert.get_serial_number() == ca2.cert.get_serial_number() + assert ca.cacert.get_serial_number() == ca2.cacert.get_serial_number() def test_create_tmp(self): with tutils.tmpdir() as d: @@ -21,14 +51,46 @@ class TestCertStore: assert ca.get_cert("foo.com", []) assert ca.get_cert("*.foo.com", []) + r = ca.get_cert("*.foo.com", []) + assert r[1] == ca.privkey + + def test_add_cert(self): + with tutils.tmpdir() as d: + ca = certutils.CertStore.from_store(d, "test") + + def test_sans(self): + with tutils.tmpdir() as d: + ca = certutils.CertStore.from_store(d, "test") + c1 = ca.get_cert("foo.com", ["*.bar.com"]) + c2 = ca.get_cert("foo.bar.com", []) + assert c1 == c2 + c3 = ca.get_cert("bar.com", []) + assert not c1 == c3 + + def test_overrides(self): + with tutils.tmpdir() as d: + ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") + ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") + assert not ca1.cacert.get_serial_number() == ca2.cacert.get_serial_number() + + dc = ca2.get_cert("foo.com", []) + dcp = os.path.join(d, "dc") + f = open(dcp, "wb") + f.write(dc[0].to_pem()) + f.close() + ca1.add_cert_file("foo.com", dcp) + + ret = ca1.get_cert("foo.com", []) + assert ret[0].serial == dc[0].serial + class TestDummyCert: def test_with_ca(self): with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") r = certutils.dummy_cert( - ca.pkey, - ca.cert, + ca.privkey, + ca.cacert, "foo.com", ["one.com", "two.com", "*.three.com"] )