Certificate flags
This commit is contained in:
parent
2a12aa3c47
commit
f5cc63d653
|
@ -7,4 +7,5 @@ MANIFEST
|
|||
*.swp
|
||||
*.swo
|
||||
.coverage
|
||||
.idea
|
||||
.idea
|
||||
__pycache__
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
import cffi
|
||||
import OpenSSL
|
||||
xffi = cffi.FFI()
|
||||
xffi.cdef ("""
|
||||
struct rsa_meth_st {
|
||||
int flags;
|
||||
...;
|
||||
};
|
||||
struct rsa_st {
|
||||
int pad;
|
||||
long version;
|
||||
struct rsa_meth_st *meth;
|
||||
...;
|
||||
};
|
||||
""")
|
||||
xffi.verify(
|
||||
"""#include <openssl/rsa.h>""",
|
||||
extra_compile_args=['-w']
|
||||
)
|
||||
|
||||
def handle(privkey):
|
||||
new = xffi.new("struct rsa_st*")
|
||||
newbuf = xffi.buffer(new)
|
||||
rsa = OpenSSL.SSL._lib.EVP_PKEY_get1_RSA(privkey._pkey)
|
||||
oldbuf = OpenSSL.SSL._ffi.buffer(rsa)
|
||||
newbuf[:] = oldbuf[:]
|
||||
return new
|
||||
|
||||
def set_flags(privkey, val):
|
||||
hdl = handle(privkey)
|
||||
hdl.meth.flags = val
|
||||
return privkey
|
||||
|
||||
def get_flags(privkey):
|
||||
hdl = handle(privkey)
|
||||
return hdl.meth.flags
|
|
@ -111,6 +111,7 @@ class DNTree:
|
|||
return current.value
|
||||
|
||||
|
||||
|
||||
class CertStore:
|
||||
"""
|
||||
Implements an in-memory certificate store.
|
||||
|
@ -222,6 +223,11 @@ class CertStore:
|
|||
c = (c, None)
|
||||
return (c[0], c[1] or self.privkey)
|
||||
|
||||
def gen_pkey(self, cert):
|
||||
import certffi
|
||||
certffi.set_flags(self.privkey, 1)
|
||||
return self.privkey
|
||||
|
||||
|
||||
class _GeneralName(univ.Choice):
|
||||
# We are only interested in dNSNames. We use a default handler to ignore
|
||||
|
@ -326,6 +332,7 @@ class SSLCert:
|
|||
return altnames
|
||||
|
||||
|
||||
|
||||
def get_remote_cert(host, port, sni):
|
||||
c = tcp.TCPClient((host, port))
|
||||
c.connect()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from netlib import certutils
|
||||
from netlib import certutils, certffi
|
||||
import OpenSSL
|
||||
import tutils
|
||||
|
||||
|
@ -83,6 +83,16 @@ class TestCertStore:
|
|||
ret = ca1.get_cert("foo.com", [])
|
||||
assert ret[0].serial == dc[0].serial
|
||||
|
||||
def test_gen_pkey(self):
|
||||
try:
|
||||
with tutils.tmpdir() as d:
|
||||
ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
|
||||
ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
|
||||
cert = ca1.get_cert("foo.com", [])
|
||||
assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1
|
||||
finally:
|
||||
certffi.set_flags(ca2.privkey, 0)
|
||||
|
||||
|
||||
class TestDummyCert:
|
||||
def test_with_ca(self):
|
||||
|
@ -125,3 +135,5 @@ class TestSSLCert:
|
|||
d = file(tutils.test_data.path("data/dercert"),"rb").read()
|
||||
s = certutils.SSLCert.from_der(d)
|
||||
assert s.cn
|
||||
|
||||
|
||||
|
|
127
test/test_tcp.py
127
test/test_tcp.py
|
@ -4,16 +4,6 @@ import mock
|
|||
import tutils
|
||||
from OpenSSL import SSL
|
||||
|
||||
class SNIHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(self.sni)
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class EchoHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle_sni(self, connection):
|
||||
|
@ -25,58 +15,19 @@ class EchoHandler(tcp.BaseHandler):
|
|||
self.wfile.flush()
|
||||
|
||||
|
||||
class ClientPeernameHandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
self.wfile.write(str(self.connection.getpeername()))
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class CertHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write("%s\n"%self.clientcert.serial)
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class ClientCipherListHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write("%s"%self.connection.get_cipher_list())
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class CurrentCipherHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle(self):
|
||||
self.wfile.write("%s"%str(self.get_current_cipher()))
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class DisconnectHandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
self.close()
|
||||
|
||||
|
||||
class HangHandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
while 1:
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
class TimeoutHandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
self.timeout = False
|
||||
self.settimeout(0.01)
|
||||
try:
|
||||
self.rfile.read(10)
|
||||
except tcp.NetLibTimeout:
|
||||
self.timeout = True
|
||||
|
||||
|
||||
class TestServer(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
def test_echo(self):
|
||||
|
@ -89,7 +40,10 @@ class TestServer(test.ServerTestBase):
|
|||
|
||||
|
||||
class TestServerBind(test.ServerTestBase):
|
||||
handler = ClientPeernameHandler
|
||||
class handler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
self.wfile.write(str(self.connection.getpeername()))
|
||||
self.wfile.flush()
|
||||
|
||||
def test_bind(self):
|
||||
""" Test to bind to a given random port. Try again if the random port turned out to be blocked. """
|
||||
|
@ -198,7 +152,14 @@ class TestSSLv3Only(test.ServerTestBase):
|
|||
|
||||
|
||||
class TestSSLClientCert(test.ServerTestBase):
|
||||
handler = CertHandler
|
||||
class handler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write("%s\n"%self.clientcert.serial)
|
||||
self.wfile.flush()
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
|
@ -222,7 +183,15 @@ class TestSSLClientCert(test.ServerTestBase):
|
|||
|
||||
|
||||
class TestSNI(test.ServerTestBase):
|
||||
handler = SNIHandler
|
||||
class handler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(self.sni)
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
|
@ -254,7 +223,11 @@ class TestServerCipherList(test.ServerTestBase):
|
|||
|
||||
|
||||
class TestServerCurrentCipher(test.ServerTestBase):
|
||||
handler = CurrentCipherHandler
|
||||
class handler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle(self):
|
||||
self.wfile.write("%s"%str(self.get_current_cipher()))
|
||||
self.wfile.flush()
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
|
@ -300,7 +273,9 @@ class TestClientCipherListError(test.ServerTestBase):
|
|||
|
||||
|
||||
class TestSSLDisconnect(test.ServerTestBase):
|
||||
handler = DisconnectHandler
|
||||
class handler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
self.close()
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
|
@ -329,7 +304,15 @@ class TestDisconnect(test.ServerTestBase):
|
|||
|
||||
|
||||
class TestServerTimeOut(test.ServerTestBase):
|
||||
handler = TimeoutHandler
|
||||
class handler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
self.timeout = False
|
||||
self.settimeout(0.01)
|
||||
try:
|
||||
self.rfile.read(10)
|
||||
except tcp.NetLibTimeout:
|
||||
self.timeout = True
|
||||
|
||||
def test_timeout(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
|
@ -383,6 +366,40 @@ class TestDHParams(test.ServerTestBase):
|
|||
assert ret[0] == "DHE-RSA-AES256-SHA"
|
||||
|
||||
|
||||
|
||||
class TestPrivkeyGen(test.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca1 = certutils.CertStore.from_store(d, "test2")
|
||||
ca2 = certutils.CertStore.from_store(d, "test3")
|
||||
cert, _ = ca1.get_cert("foo.com", [])
|
||||
key = ca2.gen_pkey(cert)
|
||||
self.convert_to_ssl(cert, key)
|
||||
|
||||
def test_privkey(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
tutils.raises("bad record mac", c.convert_to_ssl)
|
||||
|
||||
|
||||
class TestPrivkeyGenNoFlags(test.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca1 = certutils.CertStore.from_store(d, "test2")
|
||||
ca2 = certutils.CertStore.from_store(d, "test3")
|
||||
cert, _ = ca1.get_cert("foo.com", [])
|
||||
certffi.set_flags(ca2.privkey, 0)
|
||||
self.convert_to_ssl(cert, ca2.privkey)
|
||||
|
||||
def test_privkey(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
tutils.raises("unexpected eof", c.convert_to_ssl)
|
||||
|
||||
|
||||
|
||||
class TestTCPClient:
|
||||
def test_conerr(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
|
|
Loading…
Reference in New Issue