package main import ( "fmt" "io" "net" "strconv" "strings" "time" "context" "errors" "crypto/tls" "crypto/x509" "math/rand" log "github.com/sirupsen/logrus" kcp "github.com/xtaci/kcp-go" ) func NewNetReader(mtu int, in, out net.Conn) *NetReader { return &NetReader{ mtu: mtu, in: in, out: out, wait: make(chan error), } } func (n *NetReader) recv() ([]byte, error) { portion := 1024 * n.mtu if portion <= 0 { portion = 4 * 1024 * 1024 } buffer := make([]byte, portion) l, err := n.in.Read(buffer) if err != nil { return nil, err } return buffer[:l], err } func (n *NetReader) send(buffer []byte) error { toSend := len(buffer) offset := 0 for { if toSend == 0 { break } portion := toSend if n.mtu > 0 && portion > n.mtu { portion = n.mtu } cnt, err := n.out.Write(buffer[offset : offset+portion]) if err != nil { return err } offset += cnt toSend -= cnt } return nil } func (n *NetReader) Serve() { for { buffer, err := n.recv() if err != nil { n.ReportError(err) return } err = n.send(buffer) if err != nil { n.ReportError(err) return } } } func (n *NetReader) ReportError(err error) { log.Error( "NetReader: ", n.in.RemoteAddr(), " -> ", n.out.RemoteAddr(), ": ", err, ) n.err = err select { case n.wait <- err: default: } } func NewNetForwarder(pproxy, remote net.Conn) *NetForwarder { return &NetForwarder{ pproxy: pproxy, remote: remote, } } func (n *NetForwarder) sendRemoteConnectionInfo() error { localAddr := strings.Split(n.remote.LocalAddr().String(), ":") remoteAddr := strings.Split(n.remote.RemoteAddr().String(), ":") localPort, _ := strconv.Atoi(localAddr[1]) remotePort, _ := strconv.Atoi(remoteAddr[1]) return SendMessage(n.pproxy, ConnectionAcceptHeader{ LocalHost: localAddr[0], LocalPort: localPort, RemoteHost: remoteAddr[0], RemotePort: remotePort, }) } func (n *NetForwarder) Serve(ctx context.Context, l7ready chan error, remoteMtu int) error { log.Warning( "Forwarder: ", n.remote.LocalAddr(), " <- ", n.remote.RemoteAddr(), ) err := n.sendRemoteConnectionInfo() if err != nil { log.Error("Notification handler send failed: ", err) return err } select { case err = <-l7ready: if err != nil { return err } case <-ctx.Done(): return errors.New("Cancelled") } log.Debug("Start network readers") remoteLocalReader := NewNetReader(remoteMtu, n.remote, n.pproxy) localRemoteReader := NewNetReader(remoteMtu, n.pproxy, n.remote) go remoteLocalReader.Serve() go localRemoteReader.Serve() select { case err = <-remoteLocalReader.wait: log.Debug("NetForwarder: Remote->Local Closed: ", err) return err case err = <-localRemoteReader.wait: log.Debug("NetForwarder: Local->Remote Closed: ", err) return err case <-ctx.Done(): log.Debug("NetForwarder: Cancellation received") return errors.New("Cancelled") } } func (d *Daemon) Accept(in net.Conn, port int, createListener func(net.Conn) (net.Listener, error)) (net.Conn, error) { var ( listener *Listener ok bool ) d.ListenersLock.Lock() if listener, ok = d.Listeners[port]; !ok { log.Debug(fmt.Sprintf("Create new listener [%d]", port)) l, err := createListener(in) if err != nil { log.Error(fmt.Sprintf("Create new listener [%d]: failed: %s", port, err.Error())) d.ListenersLock.Unlock() return nil, err } listener = &Listener{ Listener: l, refcnt: 0, } d.Listeners[port] = listener log.Debug(fmt.Sprintf("New listener [%d] created", port)) } listener.refcnt += 1 log.Info(fmt.Sprintf("Create new listener [%d]: ok: refcnt=%d", port, listener.refcnt)) d.ListenersLock.Unlock() return listener.Listener.Accept() } func (d *Daemon) Remove(port int) { var ( listener *Listener ok bool ) d.ListenersLock.Lock() if listener, ok = d.Listeners[port]; ok { listener.refcnt -= 1 log.Info(fmt.Sprintf("Remove listener [%d]; refcnt=%d", port, listener.refcnt)) if listener.refcnt == 0 { log.Info(fmt.Sprintf("Close listener [%d]", port)) listener.Listener.Close() delete(d.Listeners, port) } } d.ListenersLock.Unlock() } func (d *Daemon) listenAcceptTCP(in net.Conn, port int) (net.Conn, error) { conn, err := d.Accept(in, port, func(in net.Conn) (net.Listener, error) { log.Println("New listener requested, port:", port) return net.Listen("tcp", fmt.Sprintf("%s:%d", ExternalBindHost, port)) }) log.Debug("TCP: Accepted connection") if conn != nil { conn.(*net.TCPConn).SetKeepAlive(true) conn.(*net.TCPConn).SetKeepAlivePeriod(1 * time.Minute) conn.(*net.TCPConn).SetNoDelay(true) } log.Debug("TCP Acceptor completed: ", conn, err) return conn, err } func (d *Daemon) listenAcceptTLS(in net.Conn, port int) (net.Conn, error) { conn, err := d.Accept(in, port, func(in net.Conn) (net.Listener, error) { log.Debug("Load certificates") err := SendMessage(in, &Extra{ Extra: true, Data: "certs", }) if err != nil { return nil, err } config := &TLSAcceptorConfig{} err = RecvMessage(in, config) if err != nil { log.Error("Couldn't receive TLS certificate") return nil, err } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM([]byte(config.CACert)) { log.Error("Invalid CA cert") return nil, errors.New("Invalid CA cert") } cert, err := tls.X509KeyPair([]byte(config.Cert), []byte(config.Key)) if err != nil { log.Error("Invalid SSL Key/Cert") return nil, errors.New("Invalid SSL Key/Cert: " + err.Error()) } log.Debug("SSL: New listener requested, port:", port) return tls.Listen("tcp", fmt.Sprintf("%s:%d", ExternalBindHost, port), &tls.Config{ Certificates: []tls.Certificate{cert}, ClientCAs: pool, ClientAuth: tls.RequireAndVerifyClientCert, }) }) log.Debug("SSL: Accepted connection: ", conn, err) return conn, err } func NewKCPConn(in net.Conn) net.Conn { localId := [4]byte{} for i := 0; i < 4; i++ { localId[i] = byte(rand.Intn(255)) } kcpconn := &KCPConn{ localId: localId, Conn: in, } return kcpconn } func (c *KCPConn) sendEOF() { end := [5]byte{} end[0] = KCP_END copy(end[1:], c.localId[:]) c.Conn.Write(end[:]) } func compareId(id1, id2 []byte) bool { if len(id1) != 4 || len(id2) != 4 { return false } for i := 0; i < 4; i++ { if id1[i] != id2[i] { return false } } return true } func (c *KCPConn) Read(b []byte) (n int, err error) { buf := make([]byte, len(b)+5) n, err = c.Conn.Read(buf) if err != nil || n < 5 { log.Debug( "KCP: Invalid KCP header (too small or error) ", n, err, ) return 0, io.EOF } switch buf[0] { case KCP_NEW: if !c.initialized { log.Debug("KCP: NEW received") copy(c.remoteId[:], buf[1:5]) c.initialized = true } else { log.Debug("KCP: Unexpected NEW") c.sendEOF() return 0, io.EOF } case KCP_DAT: if !c.initialized || !compareId(c.remoteId[:], buf[1:5]) { log.Debug("KCP: Unexpected DAT") c.sendEOF() return 0, io.EOF } case KCP_END: log.Debug("KCP: EOF Received") return 0, io.EOF default: log.Debug("KCP: Unknown flag") return 0, io.EOF } return copy(b[:], buf[5:n]), nil } func (c *KCPConn) Write(b []byte) (n int, err error) { buf := make([]byte, len(b)+5) if c.new_sent { buf[0] = KCP_DAT } else { buf[0] = KCP_NEW c.new_sent = true } copy(buf[1:5], c.localId[:]) copy(buf[5:], b[:]) n, err = c.Conn.Write(buf) if err != nil { return 0, err } return n - 5, nil } func (c *KCPConn) Close() error { log.Debug("KCP: Close() called, send EOF") c.sendEOF() return c.Conn.Close() } func (c *KCPConn) LocalAddr() net.Addr { return c.Conn.LocalAddr() } func (c *KCPConn) RemoteAddr() net.Addr { return c.Conn.RemoteAddr() } func (c *KCPConn) SetDeadline(t time.Time) error { return c.Conn.SetDeadline(t) } func (c *KCPConn) SetReadDeadline(t time.Time) error { return c.Conn.SetReadDeadline(t) } func (c *KCPConn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) } func (d *Daemon) listenAcceptKCP(in net.Conn, port int) (net.Conn, error) { conn, err := d.Accept(in, port, func(in net.Conn) (net.Listener, error) { log.Debug("New KCP listener requested, port:", port) ll, err := kcp.Listen(fmt.Sprintf("%s:%d", ExternalBindHost, port)) if err != nil { log.Error("KCP Listen Error: ", err) return nil, err } l := ll.(*kcp.Listener) l.SetReadBuffer(1024 * 1024) l.SetWriteBuffer(1024 * 1024) return l, nil }) log.Debug("KCP: Accepted connection", conn, err) return NewKCPConn(conn), err } func l7KeepAliveSender(ctx context.Context, conn net.Conn, cherr chan error) { ticker := time.NewTicker(5 * time.Second) for { select { case <-ctx.Done(): ticker.Stop() return case t := <-ticker.C: err := SendKeepAlive(conn, t) if err != nil { select { case cherr <- err: default: } return } } } } func l7KeepAliveReceiver(ctx context.Context, conn net.Conn, cherr chan error) { for { select { case <-ctx.Done(): return default: keepalive := &KeepAlive{} err := RecvMessage(conn, keepalive) if err != nil { select { case cherr <- err: default: } return } rtt := time.Now().Unix() - keepalive.Tick log.Debug( "KeepAlive: ", keepalive.Tick, " RTT: ", rtt, " Last: ", keepalive.Last) if keepalive.Last { select { case cherr <- nil: default: } } } } } func withAccept( in net.Conn, port int, acceptor func(net.Conn, int) (net.Conn, error), chconn chan net.Conn, cherr chan error) { conn, err := acceptor(in, port) if err != nil { select { case cherr <- err: default: } } else { select { case chconn <- conn: default: } } } func acceptOrWait( in net.Conn, port int, acceptor func(net.Conn, int) (net.Conn, error)) (chan error, net.Conn, error) { chconn := make(chan net.Conn) cherr := make(chan error) l7rcherr := make(chan error) pingContext, pingCancel := context.WithCancel(context.Background()) defer pingCancel() pingRecvContext, pingRecvCancel := context.WithCancel( context.Background()) defer pingRecvCancel() go withAccept(in, port, acceptor, chconn, cherr) go l7KeepAliveSender(pingContext, in, cherr) go l7KeepAliveReceiver(pingRecvContext, in, l7rcherr) select { case conn := <-chconn: return l7rcherr, conn, nil case err := <-l7rcherr: log.Error("L7 KeepAlive Receiver failed for ", in.RemoteAddr()) return nil, nil, err case err := <-cherr: log.Error("Accept failed for ", in.RemoteAddr()) return nil, nil, err } } func (d *Daemon) serveStream( mtu int, pproxy net.Conn, bind string, acceptor func(net.Conn, int) (net.Conn, error), ) { defer pproxy.Close() port, err := strconv.Atoi(bind) if err != nil { log.Error("Invalid port: ", err.Error()) SendError(pproxy, err) return } for _, mapping := range PortMaps { if port == mapping.From { port = mapping.To break } } defer d.Remove(port) done, remote, err := acceptOrWait(pproxy, port, acceptor) if err != nil { SendError(pproxy, err) return } defer remote.Close() forwarder := NewNetForwarder(pproxy, remote) forwarder.Serve(context.Background(), done, mtu) }