mirror of https://github.com/n1nj4sec/pupy.git
Rework function calls flow to be thread-safe in pyuvproxy
This commit is contained in:
parent
2fb25d0a70
commit
2a8e965996
|
@ -4,6 +4,9 @@ import rpyc
|
|||
import sys, time
|
||||
import pyuv
|
||||
import struct
|
||||
import os
|
||||
|
||||
os.putenv('UV_THREADPOOL_SIZE', '1')
|
||||
|
||||
from netaddr import IPAddress, AddrFormatError
|
||||
from threading import Event, Thread, Lock
|
||||
|
@ -15,6 +18,8 @@ from socket import SHUT_RD, SHUT_WR
|
|||
from socket import error as socket_error
|
||||
from socket import inet_ntop
|
||||
|
||||
from Queue import Queue, Empty
|
||||
|
||||
import socket
|
||||
|
||||
import random
|
||||
|
@ -71,7 +76,7 @@ class ChannelIsNotReady(ValueError):
|
|||
pass
|
||||
|
||||
class Connection(object):
|
||||
def __init__(self, neighbor, remote_id=None, socket=None, buffer=None, socks5=False):
|
||||
def __init__(self, neighbor, remote_id=None, socket=None, buffer=None, socks5=False, timeout=5):
|
||||
self.neighbor = neighbor
|
||||
self.loop = self.neighbor.manager.loop
|
||||
self.socket = socket or pyuv.TCP(self.loop)
|
||||
|
@ -80,39 +85,48 @@ class Connection(object):
|
|||
self.remote_local_address = None
|
||||
self.buffer = None
|
||||
self.socks5 = socks5
|
||||
self.timer = pyuv.Timer(self.loop)
|
||||
self.timeout = timeout
|
||||
|
||||
def _connection_timeout(self, handle):
|
||||
try:
|
||||
handle.stop()
|
||||
except:
|
||||
pass
|
||||
|
||||
self.close(-1, mutual=self.remote_id is not None)
|
||||
|
||||
def register_remote_id(self, remote_id):
|
||||
self.remote_id = remote_id
|
||||
|
||||
def on_connected(self, local_address, error):
|
||||
if error:
|
||||
if self.socks5:
|
||||
self.socket.write(
|
||||
struct.pack(
|
||||
'BB', 0x5, ERRNO_TO_SOCKS5.get(reason, CODE_GENERAL_SRV_FAILURE)
|
||||
) + self.socks5[2:])
|
||||
|
||||
self.close(error, mutual=False)
|
||||
else:
|
||||
if self.socks5:
|
||||
try:
|
||||
addr, port = IPAddress(local_address[0]), local_address[1]
|
||||
|
||||
try:
|
||||
if self.socks5:
|
||||
self.socket.write(
|
||||
struct.pack(
|
||||
'BBBB', 0x5,
|
||||
0, 0,
|
||||
ADDR_IPV4 if addr.version == 4 else ADDR_IPV6
|
||||
) + addr.packed + struct.pack('>H', port)
|
||||
)
|
||||
'BB', 0x5, ERRNO_TO_SOCKS5.get(error, CODE_GENERAL_SRV_FAILURE)
|
||||
) + self.socks5[2:])
|
||||
finally:
|
||||
self.close(error, mutual=False)
|
||||
else:
|
||||
try:
|
||||
if self.socks5:
|
||||
addr, port = IPAddress(local_address[0]), local_address[1]
|
||||
self.socket.write(
|
||||
struct.pack(
|
||||
'BBBB', 0x5,
|
||||
0, 0,
|
||||
ADDR_IPV4 if addr.version == 4 else ADDR_IPV6
|
||||
) + addr.packed + struct.pack('>H', port))
|
||||
|
||||
except Exception, e:
|
||||
logging.debug('SOCKS5 response failed: {}'.format(e))
|
||||
if self.buffer:
|
||||
self._on_read_data(self.socket, self.buffer, None)
|
||||
|
||||
if self.buffer:
|
||||
self._on_read_data(self.socket, self.buffer, None)
|
||||
self.forward()
|
||||
|
||||
self.forward()
|
||||
except:
|
||||
self.close(-1)
|
||||
|
||||
def on_data(self, data):
|
||||
if not self.socket:
|
||||
|
@ -143,20 +157,14 @@ class Connection(object):
|
|||
else:
|
||||
self.close(error)
|
||||
|
||||
def _report_close_remote(self, reason):
|
||||
self.socket.close()
|
||||
|
||||
try:
|
||||
self.neighbor.callbacks.on_disconnect(
|
||||
self.neighbor.remote_id,
|
||||
self.remote_id,
|
||||
reason
|
||||
)
|
||||
|
||||
except EOFError:
|
||||
self.neighbor.stop(dead=True)
|
||||
|
||||
def _on_connected(self, handle, error):
|
||||
try:
|
||||
self.timer.stop()
|
||||
self.timer.close()
|
||||
self.timer = None
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.neighbor.callbacks.on_connected(
|
||||
self.neighbor.remote_id,
|
||||
|
@ -165,41 +173,55 @@ class Connection(object):
|
|||
error=error
|
||||
)
|
||||
|
||||
if error:
|
||||
try:
|
||||
self.socket.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
else:
|
||||
self.forward()
|
||||
|
||||
except EOFError:
|
||||
self.neighbor.stop(dead=True)
|
||||
|
||||
if error:
|
||||
self.socket.close()
|
||||
else:
|
||||
self.forward()
|
||||
|
||||
def _start_connect(self, address):
|
||||
def connect(self, address):
|
||||
try:
|
||||
self.timer.start(self._connection_timeout, self.timeout, 0)
|
||||
self.socket.connect(address, self._on_connected)
|
||||
except:
|
||||
except Exception, e:
|
||||
self._on_connected(None, -1)
|
||||
|
||||
def connect(self, address):
|
||||
self.loop.queue_work(lambda: self._start_connect(address))
|
||||
|
||||
def forward(self):
|
||||
self.loop.queue_work(lambda: self.socket.start_read(self._on_read_data))
|
||||
self.socket.start_read(self._on_read_data)
|
||||
|
||||
def close(self, reason, mutual=True):
|
||||
unregistered = False
|
||||
if mutual:
|
||||
try:
|
||||
self.loop.queue_work(lambda: self.socket.shutdown(
|
||||
lambda handle, error: self._report_close_remote(reason)))
|
||||
except:
|
||||
self.socket.close()
|
||||
else:
|
||||
self.neighbor.callbacks.on_disconnect(
|
||||
self.neighbor.remote_id,
|
||||
self.remote_id,
|
||||
reason
|
||||
)
|
||||
|
||||
except EOFError:
|
||||
self.neighbor.stop(dead=True)
|
||||
unregistered = True
|
||||
|
||||
try:
|
||||
self.socket.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
self.neighbor.unregister_connection(self)
|
||||
|
||||
def read_exactly(self, size, callback):
|
||||
pass
|
||||
try:
|
||||
if self.timer:
|
||||
self.timer.stop()
|
||||
except:
|
||||
pass
|
||||
|
||||
if not unregistered:
|
||||
self.neighbor.unregister_connection(self)
|
||||
|
||||
class Acceptor(object):
|
||||
def __init__(self, neighbor, local_address, forward_address=None):
|
||||
|
@ -209,7 +231,9 @@ class Acceptor(object):
|
|||
self.forward_address = forward_address
|
||||
self.associaction = {}
|
||||
self.socket = pyuv.TCP(self.loop)
|
||||
self.socket.bind(local_address)
|
||||
|
||||
def start(self):
|
||||
self.socket.bind(self.local_address)
|
||||
self.socket.listen(self._on_connection)
|
||||
|
||||
def _on_connection(self, handle, error):
|
||||
|
@ -396,7 +420,10 @@ class Neighbor(object):
|
|||
connection
|
||||
].close(-1, mutual=not dead)
|
||||
|
||||
del self.manager.neighbors[self.local_id]
|
||||
try:
|
||||
del self.manager.neighbors[self.local_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def pair(self, local_id, remote_id):
|
||||
self.local_id = local_id
|
||||
|
@ -419,7 +446,8 @@ class Neighbor(object):
|
|||
self.connections[connection.local_id] = connection
|
||||
|
||||
def unregister_connection(self, connection):
|
||||
del self.connections[connection.local_id]
|
||||
if connection.local_id in self.connections:
|
||||
del self.connections[connection.local_id]
|
||||
|
||||
def register_acceptor(self, acceptor, port):
|
||||
self.acceptors[port] = acceptor
|
||||
|
@ -446,31 +474,47 @@ class Neighbor(object):
|
|||
class Manager(Thread):
|
||||
def __init__(self):
|
||||
super(Manager, self).__init__()
|
||||
self.loop = pyuv.Loop.default_loop()
|
||||
self.loop = pyuv.Loop()
|
||||
self.neighbors = {}
|
||||
self.wakeup = Event()
|
||||
self.stopped = Event()
|
||||
self.ports = {}
|
||||
self.daemon = True
|
||||
self.wake = pyuv.Async(self.loop, self.sync)
|
||||
self.queue = Queue()
|
||||
|
||||
def stop(self, dead=False):
|
||||
def sync(self, handle):
|
||||
while True:
|
||||
try:
|
||||
method, args = self.queue.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
|
||||
try:
|
||||
method(*args)
|
||||
except Exception, e:
|
||||
logging.exception('Defered call exception: {}'.format(e))
|
||||
|
||||
def defer(self, method, *args):
|
||||
self.queue.put((method, args))
|
||||
self.wake.send()
|
||||
|
||||
def _stop(self, dead):
|
||||
for neighbor_id in self.neighbors.keys():
|
||||
self.neighbors[neighbor_id].stop(dead=dead)
|
||||
self.wake.close()
|
||||
for handle in self.loop.handles:
|
||||
if not handle.closed:
|
||||
handle.close()
|
||||
|
||||
self.stopped.set()
|
||||
self.wakeup.set()
|
||||
self.loop.stop()
|
||||
|
||||
def stop(self, dead=False):
|
||||
self.defer(self._stop, dead)
|
||||
|
||||
def force_stop(self):
|
||||
self.stop(dead=True)
|
||||
|
||||
def run(self):
|
||||
while not self.stopped.is_set():
|
||||
self.wakeup.wait()
|
||||
if self.stopped.is_set():
|
||||
break
|
||||
|
||||
self.wakeup.clear()
|
||||
self.loop.run()
|
||||
self.loop.run()
|
||||
|
||||
def get_neighbor(self, neighbor_id):
|
||||
if not neighbor_id in self.neighbors:
|
||||
|
@ -478,7 +522,7 @@ class Manager(Thread):
|
|||
|
||||
return self.neighbors[neighbor_id]
|
||||
|
||||
def bind(self, neighbor_id, host='127.0.0.1', port=8080, forward=None):
|
||||
def _bind(self, neighbor_id, host, port, forward):
|
||||
neighbor = self.get_neighbor(neighbor_id)
|
||||
acceptor = Acceptor(
|
||||
neighbor,
|
||||
|
@ -487,8 +531,10 @@ class Manager(Thread):
|
|||
)
|
||||
|
||||
neighbor.register_acceptor(acceptor, port)
|
||||
acceptor.start()
|
||||
|
||||
self.wakeup.set()
|
||||
def bind(self, neighbor_id, host='127.0.0.1', port=8080, forward=None):
|
||||
self.defer(self._bind, neighbor_id, host, port, forward)
|
||||
|
||||
def unbind(self, port):
|
||||
for neighbor in self.neighbors.itervalues():
|
||||
|
@ -506,36 +552,43 @@ class Manager(Thread):
|
|||
).create_connection(remote_id=remote_id)
|
||||
|
||||
def connect(self, neighbor_id, connection_id, address):
|
||||
self.get_neighbor(
|
||||
neighbor_id
|
||||
).get_connection(
|
||||
connection_id
|
||||
).connect(address)
|
||||
|
||||
self.wakeup.set()
|
||||
self.defer(
|
||||
self.get_neighbor(
|
||||
neighbor_id
|
||||
).get_connection(
|
||||
connection_id
|
||||
).connect,
|
||||
address
|
||||
)
|
||||
|
||||
def forward(self, neighbor_id, connection_id):
|
||||
self.get_neighbor(
|
||||
neighbor_id
|
||||
).get_connection(
|
||||
connection_id
|
||||
).forward()
|
||||
|
||||
self.wakeup.set()
|
||||
self.defer(
|
||||
self.get_neighbor(
|
||||
neighbor_id
|
||||
).get_connection(
|
||||
connection_id
|
||||
).forward
|
||||
)
|
||||
|
||||
def on_connected(self, neighbor_id, connection_id, local_address, error=None):
|
||||
connection = self.get_neighbor(
|
||||
neighbor_id
|
||||
).get_connection(
|
||||
connection_id
|
||||
).on_connected(local_address, error)
|
||||
self.defer(
|
||||
self.get_neighbor(
|
||||
neighbor_id
|
||||
).get_connection(
|
||||
connection_id
|
||||
).on_connected,
|
||||
local_address, error
|
||||
)
|
||||
|
||||
def on_data(self, neighbor_id, connection_id, data):
|
||||
self.get_neighbor(
|
||||
neighbor_id
|
||||
).get_connection(
|
||||
connection_id
|
||||
).on_data(data)
|
||||
self.defer(
|
||||
self.get_neighbor(
|
||||
neighbor_id
|
||||
).get_connection(
|
||||
connection_id
|
||||
).on_data,
|
||||
data
|
||||
)
|
||||
|
||||
def on_disconnect(self, neighbor_id, connection_id, reason=None):
|
||||
neighbor = self.get_neighbor(
|
||||
|
@ -547,7 +600,10 @@ class Manager(Thread):
|
|||
connection_id
|
||||
)
|
||||
|
||||
connection.on_disconnect(reason)
|
||||
self.defer(
|
||||
connection.on_disconnect,
|
||||
reason
|
||||
)
|
||||
except (ConnectionIsNotExists, ChannelIsNotReady):
|
||||
pass
|
||||
|
||||
|
@ -581,12 +637,15 @@ class Manager(Thread):
|
|||
|
||||
return remote_id, local_id
|
||||
|
||||
def unpair(self, local_id, dead=False):
|
||||
def _unpair(self, local_id, dead):
|
||||
if not local_id in self.neighbors:
|
||||
raise NeighborIsNotExists(local_id)
|
||||
|
||||
self.neighbors[local_id].stop(dead=dead)
|
||||
|
||||
def unpair(self, local_id, dead=False):
|
||||
self.defer(self._unpair, local_id, dead)
|
||||
|
||||
def list(self, filter_by_local_id=None):
|
||||
results = []
|
||||
if filter_by_local_id:
|
||||
|
|
Loading…
Reference in New Issue