Rework function calls flow to be thread-safe in pyuvproxy

This commit is contained in:
Oleksii Shevchuk 2017-03-12 19:08:03 +02:00
parent 2fb25d0a70
commit 2a8e965996
1 changed files with 159 additions and 100 deletions

View File

@ -4,6 +4,9 @@ import rpyc
import sys, time
import pyuv
import struct
import os
os.putenv('UV_THREADPOOL_SIZE', '1')
from netaddr import IPAddress, AddrFormatError
from threading import Event, Thread, Lock
@ -15,6 +18,8 @@ from socket import SHUT_RD, SHUT_WR
from socket import error as socket_error
from socket import inet_ntop
from Queue import Queue, Empty
import socket
import random
@ -71,7 +76,7 @@ class ChannelIsNotReady(ValueError):
pass
class Connection(object):
def __init__(self, neighbor, remote_id=None, socket=None, buffer=None, socks5=False):
def __init__(self, neighbor, remote_id=None, socket=None, buffer=None, socks5=False, timeout=5):
self.neighbor = neighbor
self.loop = self.neighbor.manager.loop
self.socket = socket or pyuv.TCP(self.loop)
@ -80,39 +85,48 @@ class Connection(object):
self.remote_local_address = None
self.buffer = None
self.socks5 = socks5
self.timer = pyuv.Timer(self.loop)
self.timeout = timeout
def _connection_timeout(self, handle):
try:
handle.stop()
except:
pass
self.close(-1, mutual=self.remote_id is not None)
def register_remote_id(self, remote_id):
self.remote_id = remote_id
def on_connected(self, local_address, error):
if error:
if self.socks5:
self.socket.write(
struct.pack(
'BB', 0x5, ERRNO_TO_SOCKS5.get(reason, CODE_GENERAL_SRV_FAILURE)
) + self.socks5[2:])
self.close(error, mutual=False)
else:
if self.socks5:
try:
addr, port = IPAddress(local_address[0]), local_address[1]
try:
if self.socks5:
self.socket.write(
struct.pack(
'BBBB', 0x5,
0, 0,
ADDR_IPV4 if addr.version == 4 else ADDR_IPV6
) + addr.packed + struct.pack('>H', port)
)
'BB', 0x5, ERRNO_TO_SOCKS5.get(error, CODE_GENERAL_SRV_FAILURE)
) + self.socks5[2:])
finally:
self.close(error, mutual=False)
else:
try:
if self.socks5:
addr, port = IPAddress(local_address[0]), local_address[1]
self.socket.write(
struct.pack(
'BBBB', 0x5,
0, 0,
ADDR_IPV4 if addr.version == 4 else ADDR_IPV6
) + addr.packed + struct.pack('>H', port))
except Exception, e:
logging.debug('SOCKS5 response failed: {}'.format(e))
if self.buffer:
self._on_read_data(self.socket, self.buffer, None)
if self.buffer:
self._on_read_data(self.socket, self.buffer, None)
self.forward()
self.forward()
except:
self.close(-1)
def on_data(self, data):
if not self.socket:
@ -143,20 +157,14 @@ class Connection(object):
else:
self.close(error)
def _report_close_remote(self, reason):
self.socket.close()
try:
self.neighbor.callbacks.on_disconnect(
self.neighbor.remote_id,
self.remote_id,
reason
)
except EOFError:
self.neighbor.stop(dead=True)
def _on_connected(self, handle, error):
try:
self.timer.stop()
self.timer.close()
self.timer = None
except:
pass
try:
self.neighbor.callbacks.on_connected(
self.neighbor.remote_id,
@ -165,41 +173,55 @@ class Connection(object):
error=error
)
if error:
try:
self.socket.close()
except:
pass
else:
self.forward()
except EOFError:
self.neighbor.stop(dead=True)
if error:
self.socket.close()
else:
self.forward()
def _start_connect(self, address):
def connect(self, address):
try:
self.timer.start(self._connection_timeout, self.timeout, 0)
self.socket.connect(address, self._on_connected)
except:
except Exception, e:
self._on_connected(None, -1)
def connect(self, address):
self.loop.queue_work(lambda: self._start_connect(address))
def forward(self):
self.loop.queue_work(lambda: self.socket.start_read(self._on_read_data))
self.socket.start_read(self._on_read_data)
def close(self, reason, mutual=True):
unregistered = False
if mutual:
try:
self.loop.queue_work(lambda: self.socket.shutdown(
lambda handle, error: self._report_close_remote(reason)))
except:
self.socket.close()
else:
self.neighbor.callbacks.on_disconnect(
self.neighbor.remote_id,
self.remote_id,
reason
)
except EOFError:
self.neighbor.stop(dead=True)
unregistered = True
try:
self.socket.close()
except:
pass
self.neighbor.unregister_connection(self)
def read_exactly(self, size, callback):
pass
try:
if self.timer:
self.timer.stop()
except:
pass
if not unregistered:
self.neighbor.unregister_connection(self)
class Acceptor(object):
def __init__(self, neighbor, local_address, forward_address=None):
@ -209,7 +231,9 @@ class Acceptor(object):
self.forward_address = forward_address
self.associaction = {}
self.socket = pyuv.TCP(self.loop)
self.socket.bind(local_address)
def start(self):
self.socket.bind(self.local_address)
self.socket.listen(self._on_connection)
def _on_connection(self, handle, error):
@ -396,7 +420,10 @@ class Neighbor(object):
connection
].close(-1, mutual=not dead)
del self.manager.neighbors[self.local_id]
try:
del self.manager.neighbors[self.local_id]
except KeyError:
pass
def pair(self, local_id, remote_id):
self.local_id = local_id
@ -419,7 +446,8 @@ class Neighbor(object):
self.connections[connection.local_id] = connection
def unregister_connection(self, connection):
del self.connections[connection.local_id]
if connection.local_id in self.connections:
del self.connections[connection.local_id]
def register_acceptor(self, acceptor, port):
self.acceptors[port] = acceptor
@ -446,31 +474,47 @@ class Neighbor(object):
class Manager(Thread):
def __init__(self):
super(Manager, self).__init__()
self.loop = pyuv.Loop.default_loop()
self.loop = pyuv.Loop()
self.neighbors = {}
self.wakeup = Event()
self.stopped = Event()
self.ports = {}
self.daemon = True
self.wake = pyuv.Async(self.loop, self.sync)
self.queue = Queue()
def stop(self, dead=False):
def sync(self, handle):
while True:
try:
method, args = self.queue.get_nowait()
except Empty:
break
try:
method(*args)
except Exception, e:
logging.exception('Defered call exception: {}'.format(e))
def defer(self, method, *args):
self.queue.put((method, args))
self.wake.send()
def _stop(self, dead):
for neighbor_id in self.neighbors.keys():
self.neighbors[neighbor_id].stop(dead=dead)
self.wake.close()
for handle in self.loop.handles:
if not handle.closed:
handle.close()
self.stopped.set()
self.wakeup.set()
self.loop.stop()
def stop(self, dead=False):
self.defer(self._stop, dead)
def force_stop(self):
self.stop(dead=True)
def run(self):
while not self.stopped.is_set():
self.wakeup.wait()
if self.stopped.is_set():
break
self.wakeup.clear()
self.loop.run()
self.loop.run()
def get_neighbor(self, neighbor_id):
if not neighbor_id in self.neighbors:
@ -478,7 +522,7 @@ class Manager(Thread):
return self.neighbors[neighbor_id]
def bind(self, neighbor_id, host='127.0.0.1', port=8080, forward=None):
def _bind(self, neighbor_id, host, port, forward):
neighbor = self.get_neighbor(neighbor_id)
acceptor = Acceptor(
neighbor,
@ -487,8 +531,10 @@ class Manager(Thread):
)
neighbor.register_acceptor(acceptor, port)
acceptor.start()
self.wakeup.set()
def bind(self, neighbor_id, host='127.0.0.1', port=8080, forward=None):
self.defer(self._bind, neighbor_id, host, port, forward)
def unbind(self, port):
for neighbor in self.neighbors.itervalues():
@ -506,36 +552,43 @@ class Manager(Thread):
).create_connection(remote_id=remote_id)
def connect(self, neighbor_id, connection_id, address):
self.get_neighbor(
neighbor_id
).get_connection(
connection_id
).connect(address)
self.wakeup.set()
self.defer(
self.get_neighbor(
neighbor_id
).get_connection(
connection_id
).connect,
address
)
def forward(self, neighbor_id, connection_id):
self.get_neighbor(
neighbor_id
).get_connection(
connection_id
).forward()
self.wakeup.set()
self.defer(
self.get_neighbor(
neighbor_id
).get_connection(
connection_id
).forward
)
def on_connected(self, neighbor_id, connection_id, local_address, error=None):
connection = self.get_neighbor(
neighbor_id
).get_connection(
connection_id
).on_connected(local_address, error)
self.defer(
self.get_neighbor(
neighbor_id
).get_connection(
connection_id
).on_connected,
local_address, error
)
def on_data(self, neighbor_id, connection_id, data):
self.get_neighbor(
neighbor_id
).get_connection(
connection_id
).on_data(data)
self.defer(
self.get_neighbor(
neighbor_id
).get_connection(
connection_id
).on_data,
data
)
def on_disconnect(self, neighbor_id, connection_id, reason=None):
neighbor = self.get_neighbor(
@ -547,7 +600,10 @@ class Manager(Thread):
connection_id
)
connection.on_disconnect(reason)
self.defer(
connection.on_disconnect,
reason
)
except (ConnectionIsNotExists, ChannelIsNotReady):
pass
@ -581,12 +637,15 @@ class Manager(Thread):
return remote_id, local_id
def unpair(self, local_id, dead=False):
def _unpair(self, local_id, dead):
if not local_id in self.neighbors:
raise NeighborIsNotExists(local_id)
self.neighbors[local_id].stop(dead=dead)
def unpair(self, local_id, dead=False):
self.defer(self._unpair, local_id, dead)
def list(self, filter_by_local_id=None):
results = []
if filter_by_local_id: