remove subclassing of tuple in tcp.Address, move StateObject into netlib
This commit is contained in:
parent
e18ac4b672
commit
ff9656be80
|
@ -3,6 +3,7 @@ from pyasn1.type import univ, constraint, char, namedtype, tag
|
|||
from pyasn1.codec.der.decoder import decode
|
||||
from pyasn1.error import PyAsn1Error
|
||||
import OpenSSL
|
||||
from netlib.stateobject import StateObject
|
||||
import tcp
|
||||
|
||||
default_exp = 62208000 # =24 * 60 * 60 * 720
|
||||
|
@ -152,13 +153,22 @@ class _GeneralNames(univ.SequenceOf):
|
|||
sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024)
|
||||
|
||||
|
||||
class SSLCert:
|
||||
class SSLCert(StateObject):
|
||||
def __init__(self, cert):
|
||||
"""
|
||||
Returns a (common name, [subject alternative names]) tuple.
|
||||
"""
|
||||
self.x509 = cert
|
||||
|
||||
def _get_state(self):
|
||||
return self.to_pem()
|
||||
|
||||
def _load_state(self, state):
|
||||
self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state)
|
||||
|
||||
def _from_state(cls, state):
|
||||
return cls.from_pem(state)
|
||||
|
||||
@classmethod
|
||||
def from_pem(klass, txt):
|
||||
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import re, copy
|
||||
from netlib.stateobject import StateObject
|
||||
|
||||
|
||||
def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||
"""
|
||||
|
@ -9,7 +11,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
|
|||
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
|
||||
|
||||
|
||||
class ODict:
|
||||
class ODict(StateObject):
|
||||
"""
|
||||
A dictionary-like object for managing ordered (key, value) data.
|
||||
"""
|
||||
|
@ -98,6 +100,9 @@ class ODict:
|
|||
def _get_state(self):
|
||||
return [tuple(i) for i in self.lst]
|
||||
|
||||
def _load_state(self, state):
|
||||
self.list = [list(i) for i in state]
|
||||
|
||||
@classmethod
|
||||
def _from_state(klass, state):
|
||||
return klass([list(i) for i in state])
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
from types import ClassType
|
||||
|
||||
|
||||
class StateObject:
|
||||
def _get_state(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _load_state(self, state):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _from_state(cls, state):
|
||||
raise NotImplementedError
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
return self._get_state() == other._get_state()
|
||||
except AttributeError: # we may compare with something that's not a StateObject
|
||||
return False
|
||||
|
||||
|
||||
class SimpleStateObject(StateObject):
|
||||
"""
|
||||
A StateObject with opionated conventions that tries to keep everything DRY.
|
||||
|
||||
Simply put, you agree on a list of attributes and their type.
|
||||
Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves.
|
||||
SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods.
|
||||
Overriding _get_state or _load_state to add custom adjustments is always possible.
|
||||
"""
|
||||
|
||||
_stateobject_attributes = None # none by default to raise an exception if definition was forgotten
|
||||
"""
|
||||
An attribute-name -> class-or-type dict containing all attributes that should be serialized
|
||||
If the attribute is a class, this class must be a subclass of StateObject.
|
||||
"""
|
||||
|
||||
def _get_state(self):
|
||||
return {attr: self.__get_state_attr(attr, cls)
|
||||
for attr, cls in self._stateobject_attributes.iteritems()}
|
||||
|
||||
def __get_state_attr(self, attr, cls):
|
||||
"""
|
||||
helper for _get_state.
|
||||
returns the value of the given attribute
|
||||
"""
|
||||
if getattr(self, attr) is None:
|
||||
return None
|
||||
if isinstance(cls, ClassType):
|
||||
return getattr(self, attr)._get_state()
|
||||
else:
|
||||
return getattr(self, attr)
|
||||
|
||||
def _load_state(self, state):
|
||||
for attr, cls in self._stateobject_attributes.iteritems():
|
||||
self.__load_state_attr(attr, cls, state)
|
||||
|
||||
def __load_state_attr(self, attr, cls, state):
|
||||
"""
|
||||
helper for _load_state.
|
||||
loads the given attribute from the state.
|
||||
"""
|
||||
if state[attr] is not None: # First, catch None as value.
|
||||
if isinstance(cls, ClassType): # Is the attribute a StateObject itself?
|
||||
assert issubclass(cls, StateObject)
|
||||
curr = getattr(self, attr)
|
||||
if curr: # if the attribute is already present, delegate to the objects ._load_state method.
|
||||
curr._load_state(state[attr])
|
||||
else: # otherwise, create a new object.
|
||||
setattr(self, attr, cls._from_state(state[attr]))
|
||||
else:
|
||||
setattr(self, attr, cls(state[attr]))
|
||||
else:
|
||||
setattr(self, attr, None)
|
||||
|
||||
@classmethod
|
||||
def _from_state(cls, state):
|
||||
f = cls() # the default implementation assumes an empty constructor. Override accordingly.
|
||||
f._load_state(state)
|
||||
return f
|
|
@ -1,6 +1,7 @@
|
|||
import select, socket, threading, sys, time, traceback
|
||||
from OpenSSL import SSL
|
||||
import certutils
|
||||
from netlib.stateobject import StateObject
|
||||
|
||||
SSLv2_METHOD = SSL.SSLv2_METHOD
|
||||
SSLv3_METHOD = SSL.SSLv3_METHOD
|
||||
|
@ -173,14 +174,13 @@ class Reader(_FileLike):
|
|||
return result
|
||||
|
||||
|
||||
class Address(tuple):
|
||||
class Address(StateObject):
|
||||
"""
|
||||
This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
|
||||
"""
|
||||
def __new__(cls, address, use_ipv6=False):
|
||||
a = super(Address, cls).__new__(cls, tuple(address))
|
||||
a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET
|
||||
return a
|
||||
def __init__(self, address, use_ipv6=False):
|
||||
self.address = address
|
||||
self.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET
|
||||
|
||||
@classmethod
|
||||
def wrap(cls, t):
|
||||
|
@ -189,18 +189,35 @@ class Address(tuple):
|
|||
else:
|
||||
return cls(t)
|
||||
|
||||
def __call__(self):
|
||||
return self.address
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return self[0]
|
||||
return self.address[0]
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
return self[1]
|
||||
return self.address[1]
|
||||
|
||||
@property
|
||||
def is_ipv6(self):
|
||||
def use_ipv6(self):
|
||||
return self.family == socket.AF_INET6
|
||||
|
||||
def _load_state(self, state):
|
||||
self.address = state["address"]
|
||||
self.family = socket.AF_INET6 if state["use_ipv6"] else socket.AF_INET
|
||||
|
||||
def _get_state(self):
|
||||
return dict(
|
||||
address=self.address,
|
||||
use_ipv6=self.use_ipv6
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_state(cls, state):
|
||||
return cls(**state)
|
||||
|
||||
|
||||
class SocketCloseMixin:
|
||||
def finish(self):
|
||||
|
@ -240,7 +257,7 @@ class TCPClient(SocketCloseMixin):
|
|||
wbufsize = -1
|
||||
def __init__(self, address, source_address=None):
|
||||
self.address = Address.wrap(address)
|
||||
self.source_address = source_address
|
||||
self.source_address = Address.wrap(source_address) if source_address else None
|
||||
self.connection, self.rfile, self.wfile = None, None, None
|
||||
self.cert = None
|
||||
self.ssl_established = False
|
||||
|
@ -275,12 +292,12 @@ class TCPClient(SocketCloseMixin):
|
|||
try:
|
||||
connection = socket.socket(self.address.family, socket.SOCK_STREAM)
|
||||
if self.source_address:
|
||||
connection.bind(self.source_address)
|
||||
connection.connect(self.address)
|
||||
connection.bind(self.source_address())
|
||||
connection.connect(self.address())
|
||||
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
|
||||
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
|
||||
except (socket.error, IOError), err:
|
||||
raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err))
|
||||
raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err))
|
||||
self.connection = connection
|
||||
|
||||
def settimeout(self, n):
|
||||
|
@ -376,7 +393,7 @@ class TCPServer:
|
|||
self.__shutdown_request = False
|
||||
self.socket = socket.socket(self.address.family, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.bind(self.address)
|
||||
self.socket.bind(self.address())
|
||||
self.address = Address.wrap(self.socket.getsockname())
|
||||
self.socket.listen(self.request_queue_size)
|
||||
|
||||
|
@ -427,7 +444,7 @@ class TCPServer:
|
|||
if traceback:
|
||||
exc = traceback.format_exc()
|
||||
print >> fp, '-'*40
|
||||
print >> fp, "Error in processing of request from %s:%s"%client_address
|
||||
print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port)
|
||||
print >> fp, exc
|
||||
print >> fp, '-'*40
|
||||
|
||||
|
|
Loading…
Reference in New Issue