mirror of https://github.com/n1nj4sec/pupy.git
560 lines
11 KiB
Go
560 lines
11 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
"context"
|
|
"errors"
|
|
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"math/rand"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
kcp "github.com/xtaci/kcp-go"
|
|
)
|
|
|
|
func NewNetReader(mtu int, in, out net.Conn) *NetReader {
|
|
return &NetReader{
|
|
mtu: mtu,
|
|
in: in,
|
|
out: out,
|
|
wait: make(chan error),
|
|
}
|
|
}
|
|
|
|
func (n *NetReader) recv() ([]byte, error) {
|
|
portion := 1024 * n.mtu
|
|
if portion <= 0 {
|
|
portion = 4 * 1024 * 1024
|
|
}
|
|
|
|
buffer := make([]byte, portion)
|
|
l, err := n.in.Read(buffer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return buffer[:l], err
|
|
}
|
|
|
|
func (n *NetReader) send(buffer []byte) error {
|
|
toSend := len(buffer)
|
|
offset := 0
|
|
|
|
for {
|
|
if toSend == 0 {
|
|
break
|
|
}
|
|
|
|
portion := toSend
|
|
if n.mtu > 0 && portion > n.mtu {
|
|
portion = n.mtu
|
|
}
|
|
|
|
cnt, err := n.out.Write(buffer[offset : offset+portion])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
offset += cnt
|
|
toSend -= cnt
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (n *NetReader) Serve() {
|
|
for {
|
|
buffer, err := n.recv()
|
|
if err != nil {
|
|
n.ReportError(err)
|
|
return
|
|
}
|
|
|
|
err = n.send(buffer)
|
|
if err != nil {
|
|
n.ReportError(err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (n *NetReader) ReportError(err error) {
|
|
log.Error(
|
|
"NetReader: ", n.in.RemoteAddr(), " -> ",
|
|
n.out.RemoteAddr(), ": ", err,
|
|
)
|
|
|
|
n.err = err
|
|
select {
|
|
case n.wait <- err:
|
|
default:
|
|
}
|
|
}
|
|
|
|
func NewNetForwarder(pproxy, remote net.Conn) *NetForwarder {
|
|
return &NetForwarder{
|
|
pproxy: pproxy,
|
|
remote: remote,
|
|
}
|
|
}
|
|
|
|
func (n *NetForwarder) sendRemoteConnectionInfo() error {
|
|
localAddr := strings.Split(n.remote.LocalAddr().String(), ":")
|
|
remoteAddr := strings.Split(n.remote.RemoteAddr().String(), ":")
|
|
|
|
localPort, _ := strconv.Atoi(localAddr[1])
|
|
remotePort, _ := strconv.Atoi(remoteAddr[1])
|
|
|
|
return SendMessage(n.pproxy, ConnectionAcceptHeader{
|
|
LocalHost: localAddr[0],
|
|
LocalPort: localPort,
|
|
RemoteHost: remoteAddr[0],
|
|
RemotePort: remotePort,
|
|
})
|
|
}
|
|
|
|
func (n *NetForwarder) Serve(ctx context.Context, l7ready chan error, remoteMtu int) error {
|
|
log.Warning(
|
|
"Forwarder: ", n.remote.LocalAddr(), " <- ", n.remote.RemoteAddr(),
|
|
)
|
|
|
|
err := n.sendRemoteConnectionInfo()
|
|
if err != nil {
|
|
log.Error("Notification handler send failed: ", err)
|
|
return err
|
|
}
|
|
|
|
select {
|
|
case err = <-l7ready:
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case <-ctx.Done():
|
|
return errors.New("Cancelled")
|
|
}
|
|
|
|
log.Debug("Start network readers")
|
|
|
|
remoteLocalReader := NewNetReader(remoteMtu, n.remote, n.pproxy)
|
|
localRemoteReader := NewNetReader(remoteMtu, n.pproxy, n.remote)
|
|
|
|
go remoteLocalReader.Serve()
|
|
go localRemoteReader.Serve()
|
|
|
|
select {
|
|
case err = <-remoteLocalReader.wait:
|
|
log.Debug("NetForwarder: Remote->Local Closed: ", err)
|
|
return err
|
|
|
|
case err = <-localRemoteReader.wait:
|
|
log.Debug("NetForwarder: Local->Remote Closed: ", err)
|
|
return err
|
|
|
|
case <-ctx.Done():
|
|
log.Debug("NetForwarder: Cancellation received")
|
|
return errors.New("Cancelled")
|
|
}
|
|
}
|
|
|
|
func (d *Daemon) Accept(in net.Conn, port int, createListener func(net.Conn) (net.Listener, error)) (net.Conn, error) {
|
|
var (
|
|
listener *Listener
|
|
ok bool
|
|
)
|
|
|
|
d.ListenersLock.Lock()
|
|
if listener, ok = d.Listeners[port]; !ok {
|
|
log.Debug(fmt.Sprintf("Create new listener [%d]", port))
|
|
l, err := createListener(in)
|
|
if err != nil {
|
|
log.Error(fmt.Sprintf("Create new listener [%d]: failed: %s", port, err.Error()))
|
|
d.ListenersLock.Unlock()
|
|
return nil, err
|
|
}
|
|
|
|
listener = &Listener{
|
|
Listener: l,
|
|
refcnt: 0,
|
|
}
|
|
|
|
d.Listeners[port] = listener
|
|
log.Debug(fmt.Sprintf("New listener [%d] created", port))
|
|
}
|
|
|
|
listener.refcnt += 1
|
|
log.Info(fmt.Sprintf("Create new listener [%d]: ok: refcnt=%d", port, listener.refcnt))
|
|
d.ListenersLock.Unlock()
|
|
|
|
return listener.Listener.Accept()
|
|
}
|
|
|
|
func (d *Daemon) Remove(port int) {
|
|
var (
|
|
listener *Listener
|
|
ok bool
|
|
)
|
|
|
|
d.ListenersLock.Lock()
|
|
if listener, ok = d.Listeners[port]; ok {
|
|
listener.refcnt -= 1
|
|
log.Info(fmt.Sprintf("Remove listener [%d]; refcnt=%d", port, listener.refcnt))
|
|
if listener.refcnt == 0 {
|
|
log.Info(fmt.Sprintf("Close listener [%d]", port))
|
|
listener.Listener.Close()
|
|
delete(d.Listeners, port)
|
|
}
|
|
}
|
|
|
|
d.ListenersLock.Unlock()
|
|
}
|
|
|
|
func (d *Daemon) listenAcceptTCP(in net.Conn, port int) (net.Conn, error) {
|
|
conn, err := d.Accept(in, port, func(in net.Conn) (net.Listener, error) {
|
|
log.Println("New listener requested, port:", port)
|
|
return net.Listen("tcp", fmt.Sprintf("%s:%d", ExternalBindHost, port))
|
|
})
|
|
|
|
log.Debug("TCP: Accepted connection")
|
|
|
|
if conn != nil {
|
|
conn.(*net.TCPConn).SetKeepAlive(true)
|
|
conn.(*net.TCPConn).SetKeepAlivePeriod(1 * time.Minute)
|
|
conn.(*net.TCPConn).SetNoDelay(true)
|
|
}
|
|
|
|
log.Debug("TCP Acceptor completed: ", conn, err)
|
|
return conn, err
|
|
}
|
|
|
|
func (d *Daemon) listenAcceptTLS(in net.Conn, port int) (net.Conn, error) {
|
|
conn, err := d.Accept(in, port, func(in net.Conn) (net.Listener, error) {
|
|
log.Debug("Load certificates")
|
|
err := SendMessage(in, &Extra{
|
|
Extra: true,
|
|
Data: "certs",
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
config := &TLSAcceptorConfig{}
|
|
err = RecvMessage(in, config)
|
|
if err != nil {
|
|
log.Error("Couldn't receive TLS certificate")
|
|
return nil, err
|
|
}
|
|
|
|
pool := x509.NewCertPool()
|
|
if !pool.AppendCertsFromPEM([]byte(config.CACert)) {
|
|
log.Error("Invalid CA cert")
|
|
return nil, errors.New("Invalid CA cert")
|
|
}
|
|
|
|
cert, err := tls.X509KeyPair([]byte(config.Cert), []byte(config.Key))
|
|
if err != nil {
|
|
log.Error("Invalid SSL Key/Cert")
|
|
return nil, errors.New("Invalid SSL Key/Cert: " + err.Error())
|
|
}
|
|
|
|
log.Debug("SSL: New listener requested, port:", port)
|
|
return tls.Listen("tcp", fmt.Sprintf("%s:%d", ExternalBindHost, port), &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
ClientCAs: pool,
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
})
|
|
})
|
|
|
|
log.Debug("SSL: Accepted connection: ", conn, err)
|
|
return conn, err
|
|
}
|
|
|
|
func NewKCPConn(in net.Conn) net.Conn {
|
|
localId := [4]byte{}
|
|
for i := 0; i < 4; i++ {
|
|
localId[i] = byte(rand.Intn(255))
|
|
}
|
|
|
|
kcpconn := &KCPConn{
|
|
localId: localId,
|
|
Conn: in,
|
|
}
|
|
return kcpconn
|
|
}
|
|
|
|
func (c *KCPConn) sendEOF() {
|
|
end := [5]byte{}
|
|
end[0] = KCP_END
|
|
copy(end[1:], c.localId[:])
|
|
c.Conn.Write(end[:])
|
|
}
|
|
|
|
// compareId reports whether id1 and id2 are the same 4-byte session ID.
// Slices of any other length never match, even if they are equal.
func compareId(id1, id2 []byte) bool {
	if len(id1) != 4 || len(id2) != 4 {
		return false
	}
	// The string conversions here are compared without allocating.
	return string(id1) == string(id2)
}
|
|
|
|
func (c *KCPConn) Read(b []byte) (n int, err error) {
|
|
buf := make([]byte, len(b)+5)
|
|
|
|
n, err = c.Conn.Read(buf)
|
|
|
|
if err != nil || n < 5 {
|
|
log.Debug(
|
|
"KCP: Invalid KCP header (too small or error) ",
|
|
n, err,
|
|
)
|
|
return 0, io.EOF
|
|
}
|
|
|
|
switch buf[0] {
|
|
case KCP_NEW:
|
|
if !c.initialized {
|
|
log.Debug("KCP: NEW received")
|
|
copy(c.remoteId[:], buf[1:5])
|
|
c.initialized = true
|
|
} else {
|
|
log.Debug("KCP: Unexpected NEW")
|
|
c.sendEOF()
|
|
return 0, io.EOF
|
|
}
|
|
case KCP_DAT:
|
|
if !c.initialized || !compareId(c.remoteId[:], buf[1:5]) {
|
|
log.Debug("KCP: Unexpected DAT")
|
|
c.sendEOF()
|
|
return 0, io.EOF
|
|
}
|
|
case KCP_END:
|
|
log.Debug("KCP: EOF Received")
|
|
return 0, io.EOF
|
|
|
|
default:
|
|
log.Debug("KCP: Unknown flag")
|
|
return 0, io.EOF
|
|
}
|
|
|
|
return copy(b[:], buf[5:n]), nil
|
|
}
|
|
|
|
func (c *KCPConn) Write(b []byte) (n int, err error) {
|
|
buf := make([]byte, len(b)+5)
|
|
if c.new_sent {
|
|
buf[0] = KCP_DAT
|
|
} else {
|
|
buf[0] = KCP_NEW
|
|
c.new_sent = true
|
|
}
|
|
|
|
copy(buf[1:5], c.localId[:])
|
|
copy(buf[5:], b[:])
|
|
|
|
n, err = c.Conn.Write(buf)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return n - 5, nil
|
|
}
|
|
|
|
func (c *KCPConn) Close() error {
|
|
log.Debug("KCP: Close() called, send EOF")
|
|
c.sendEOF()
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
func (c *KCPConn) LocalAddr() net.Addr {
|
|
return c.Conn.LocalAddr()
|
|
}
|
|
|
|
func (c *KCPConn) RemoteAddr() net.Addr {
|
|
return c.Conn.RemoteAddr()
|
|
}
|
|
|
|
func (c *KCPConn) SetDeadline(t time.Time) error {
|
|
return c.Conn.SetDeadline(t)
|
|
}
|
|
|
|
func (c *KCPConn) SetReadDeadline(t time.Time) error {
|
|
return c.Conn.SetReadDeadline(t)
|
|
}
|
|
|
|
func (c *KCPConn) SetWriteDeadline(t time.Time) error {
|
|
return c.Conn.SetWriteDeadline(t)
|
|
}
|
|
|
|
func (d *Daemon) listenAcceptKCP(in net.Conn, port int) (net.Conn, error) {
|
|
conn, err := d.Accept(in, port, func(in net.Conn) (net.Listener, error) {
|
|
log.Debug("New KCP listener requested, port:", port)
|
|
|
|
ll, err := kcp.Listen(fmt.Sprintf("%s:%d", ExternalBindHost, port))
|
|
if err != nil {
|
|
log.Error("KCP Listen Error: ", err)
|
|
return nil, err
|
|
}
|
|
|
|
l := ll.(*kcp.Listener)
|
|
|
|
l.SetReadBuffer(1024 * 1024)
|
|
l.SetWriteBuffer(1024 * 1024)
|
|
return l, nil
|
|
})
|
|
|
|
log.Debug("KCP: Accepted connection", conn, err)
|
|
return NewKCPConn(conn), err
|
|
}
|
|
|
|
func l7KeepAliveSender(ctx context.Context, conn net.Conn, cherr chan error) {
|
|
ticker := time.NewTicker(5 * time.Second)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
ticker.Stop()
|
|
return
|
|
case t := <-ticker.C:
|
|
err := SendKeepAlive(conn, t)
|
|
if err != nil {
|
|
select {
|
|
case cherr <- err:
|
|
default:
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func l7KeepAliveReceiver(ctx context.Context, conn net.Conn, cherr chan error) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
keepalive := &KeepAlive{}
|
|
err := RecvMessage(conn, keepalive)
|
|
if err != nil {
|
|
select {
|
|
case cherr <- err:
|
|
default:
|
|
}
|
|
return
|
|
}
|
|
|
|
rtt := time.Now().Unix() - keepalive.Tick
|
|
|
|
log.Debug(
|
|
"KeepAlive: ", keepalive.Tick,
|
|
" RTT: ", rtt,
|
|
" Last: ", keepalive.Last)
|
|
|
|
if keepalive.Last {
|
|
select {
|
|
case cherr <- nil:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// withAccept runs acceptor(in, port) and hands the outcome over without
// blocking: the connection on chconn, or the error on cherr.
//
// If nobody is ready to take the accepted connection (the caller already gave
// up, e.g. after a keepalive failure), the connection is closed instead of
// being leaked. A dropped error needs no cleanup.
func withAccept(
	in net.Conn, port int,
	acceptor func(net.Conn, int) (net.Conn, error),
	chconn chan net.Conn, cherr chan error) {

	conn, err := acceptor(in, port)
	if err != nil {
		select {
		case cherr <- err:
		default:
		}
		return
	}

	select {
	case chconn <- conn:
	default:
		if conn != nil {
			conn.Close()
		}
	}
}
|
|
|
|
func acceptOrWait(
|
|
in net.Conn, port int,
|
|
acceptor func(net.Conn, int) (net.Conn, error)) (chan error, net.Conn, error) {
|
|
|
|
chconn := make(chan net.Conn)
|
|
cherr := make(chan error)
|
|
l7rcherr := make(chan error)
|
|
|
|
pingContext, pingCancel := context.WithCancel(context.Background())
|
|
defer pingCancel()
|
|
|
|
pingRecvContext, pingRecvCancel := context.WithCancel(
|
|
context.Background())
|
|
|
|
defer pingRecvCancel()
|
|
|
|
go withAccept(in, port, acceptor, chconn, cherr)
|
|
go l7KeepAliveSender(pingContext, in, cherr)
|
|
go l7KeepAliveReceiver(pingRecvContext, in, l7rcherr)
|
|
|
|
select {
|
|
case conn := <-chconn:
|
|
return l7rcherr, conn, nil
|
|
case err := <-l7rcherr:
|
|
log.Error("L7 KeepAlive Receiver failed for ", in.RemoteAddr())
|
|
return nil, nil, err
|
|
case err := <-cherr:
|
|
log.Error("Accept failed for ", in.RemoteAddr())
|
|
return nil, nil, err
|
|
}
|
|
}
|
|
|
|
func (d *Daemon) serveStream(
|
|
mtu int, pproxy net.Conn, bind string,
|
|
acceptor func(net.Conn, int) (net.Conn, error),
|
|
) {
|
|
|
|
defer pproxy.Close()
|
|
|
|
port, err := strconv.Atoi(bind)
|
|
if err != nil {
|
|
log.Error("Invalid port: ", err.Error())
|
|
SendError(pproxy, err)
|
|
return
|
|
}
|
|
|
|
for _, mapping := range PortMaps {
|
|
if port == mapping.From {
|
|
port = mapping.To
|
|
break
|
|
}
|
|
}
|
|
|
|
defer d.Remove(port)
|
|
|
|
done, remote, err := acceptOrWait(pproxy, port, acceptor)
|
|
if err != nil {
|
|
SendError(pproxy, err)
|
|
return
|
|
}
|
|
|
|
defer remote.Close()
|
|
|
|
forwarder := NewNetForwarder(pproxy, remote)
|
|
forwarder.Serve(context.Background(), done, mtu)
|
|
}
|