// The contents of this file are subject to the BOINC Public License
// Version 1.0 (the "License"); you may not use this file except in
// compliance with the License. You may obtain a copy of the License at
// http://boinc.berkeley.edu/license_1.0.txt
//
// Software distributed under the License is distributed on an "AS IS"
// basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
// License for the specific language governing rights and limitations
// under the License.
//
// The Original Code is the Berkeley Open Infrastructure for Network Computing.
//
// The Initial Developer of the Original Code is the SETI@home project.
// Portions created by the SETI@home project are Copyright (C) 2002
// University of California at Berkeley. All Rights Reserved.
//
// Contributor(s):
//

#include "cpp.h"

#include <stdio.h>
#include <math.h>

#ifdef _WIN32
#include <afxwin.h>
#include <winsock.h>
#include "Win_net.h"
#include "wingui_mainwindow.h"
#endif

#if HAVE_SYS_TIME_H
#include <sys/time.h>
#endif
#if HAVE_SYS_SOCKET_H
#include <sys/socket.h>
#endif
#if HAVE_SYS_SELECT_H
#include <sys/select.h>
#endif
#if HAVE_NETINET_IN_H
#include <netinet/in.h>
#endif
#if HAVE_NETINET_TCP_H
#include <netinet/tcp.h>
#endif
#if HAVE_NETDB_H
#include <netdb.h>
#endif
#if HAVE_UNISTD_H
#include <unistd.h>
#endif
#if HAVE_FCNTL_H
#include <fcntl.h>
#endif

#include <sys/types.h>
#include <errno.h>
#include <stdlib.h>
#include <time.h>
#include <string.h>

#include "error_numbers.h"
#include "net_xfer.h"
#include "util.h"
#include "client_types.h"
#include "client_state.h"
#include "message.h"

// socklen_t is the length-argument type of getsockopt() (used by
// get_socket_error() below). Provide it on platforms where the system
// headers do not, or where getsockopt() is not declared with socklen_t
// (GETSOCKOPT_SOCKLEN_T is presumably a configure result — verify).
// NOTE(review): socklen_t is a typedef, not a macro, so the `#ifndef
// socklen_t` test is always true when reached; confirm this branch never
// runs on a platform whose headers already declare socklen_t.
#if defined(_WIN32) 
typedef int socklen_t;
#elif defined ( __APPLE__)
typedef int32_t socklen_t;
#elif !GETSOCKOPT_SOCKLEN_T
#ifndef socklen_t
typedef size_t socklen_t;
#endif
#endif

// Return the pending error on socket fd (its SO_ERROR value), or 0 if
// there is none. do_select() calls this once a nonblocking connect()
// reports the socket ready, to learn whether the connection succeeded.
//
// If getsockopt() itself fails, the OS error code is returned instead
// (or -1 if none is available). A failed query is therefore never
// mistaken for success, and n is never read uninitialized.
//
int get_socket_error(int fd) {
    socklen_t intsize = sizeof(int);
    int n = 0;
#ifdef WIN32
    if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (char *)&n, &intsize)) {
        n = WSAGetLastError();
        return n ? n : -1;
    }
#elif __APPLE__
    if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &n, (int *)&intsize)) {
        return errno ? errno : -1;
    }
#else
    if (getsockopt(fd, SOL_SOCKET, SO_ERROR, (void*)&n, &intsize)) {
        return errno ? errno : -1;
    }
#endif
    return n;
}

// Resolve `hostname` with gethostbyname() and store its first IPv4
// address in ip_addr. The bytes are copied exactly as found in the
// hostent address list, i.e. in network byte order.
//
// On Windows, NetOpen() is called first; if resolution fails, NetClose()
// is called before returning.
//
// Returns 0 on success. Returns NetOpen()'s error if that fails.
// Otherwise logs a message describing the resolver error and returns
// ERR_GETHOSTBYNAME.
//
int NET_XFER::get_ip_addr(char *hostname, int &ip_addr) {
    hostent* hep;
    const char* reason = NULL;  // non-NULL means resolution failed

#ifdef WIN32
    int retval;
    retval = NetOpen();
    if (retval) return retval;
#endif
    hep = gethostbyname(hostname);
    if (!hep) {
        reason = "";
#ifdef WIN32
        switch (WSAGetLastError()) {
        case WSANOTINITIALISED:
            break;
        case WSAENETDOWN:
            reason = "(the network subsystem has failed)";
            break;
        case WSAHOST_NOT_FOUND:
            reason = "(host name not found)";
            break;
        case WSATRY_AGAIN:
            reason = "(no response from server)";
            break;
        case WSANO_RECOVERY:
            reason = "(a nonrecoverable error occurred)";
            break;
        case WSANO_DATA:
            reason = "(valid name, no data record of requested type)";
            break;
        case WSAEINPROGRESS:
            reason = "(a blocking socket call in progress)";
            break;
        case WSAEFAULT:
            reason = "(invalid part of user address space)";
            break;
        case WSAEINTR:
            reason = "(a blocking socket call was canceled)";
            break;
        }
#else
        switch (h_errno) {
        case HOST_NOT_FOUND:
            reason = "(host not found)";
            break;
        case NO_DATA:
            reason = "(valid name, no data record of requested type)";
            break;
        case NO_RECOVERY:
            reason = "(a nonrecoverable error occurred)";
            break;
        case TRY_AGAIN:
            reason = "(host not found or server failure)";
            break;
        }
#endif
    } else if (!hep->h_addr_list[0]) {
        // Defensive: a "successful" lookup with an empty address list
        // must not be dereferenced below.
        reason = "(no address returned)";
    }

    if (reason) {
#ifdef WIN32
        NetClose();
#endif
        // Previously msg[256] was filled with an unbounded hostname plus a
        // suffix and could overflow. The precision bound guarantees the
        // text fits: 23 + 255 + 1 + longest reason (46) + NUL < 512.
        char msg[512];
        sprintf(msg, "Can't resolve hostname %.255s %s", hostname, reason);
        msg_printf(0, MSG_ERROR, "%s\n", msg);
        return ERR_GETHOSTBYNAME;
    }

    // memcpy rather than *(int*): the address bytes need not be int-aligned.
    memcpy(&ip_addr, hep->h_addr_list[0], sizeof(ip_addr));

    return 0;
}


// Attempt to open a nonblocking TCP socket to hostname:port.
// connect() is started here but normally completes later; do_select()
// detects completion or failure and sets is_connected / error.
//
// On success, the descriptor is stored in `socket` and 0 is returned.
// On failure, any descriptor created here is closed and an ERR_* code is
// returned (or the error from get_ip_addr()).
//
int NET_XFER::open_server() {
    sockaddr_in addr;
    int fd, ipaddr, retval;

    retval = get_ip_addr(hostname, ipaddr);
    if (retval) return retval;

    fd = ::socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
#ifdef WIN32
        NetClose();
#endif
        return ERR_SOCKET;
    }

    // Make the socket nonblocking so connect() and later I/O never stall
    // the select() loop. If that fails, give up on this attempt and close
    // the descriptor rather than leaking it.
#ifdef WIN32
    unsigned long one = 1;
    if (ioctlsocket(fd, FIONBIO, &one)) {
        closesocket(fd);
        NetClose();
        return ERR_SOCKET;
    }
#else
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags < 0 || fcntl(fd, F_SETFL, flags|O_NONBLOCK) < 0) {
        close(fd);
        return ERR_FCNTL;
    }
#endif

    // Zero the whole struct so padding (sin_zero) is not stack garbage.
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    // ipaddr was copied straight from the hostent address list, so it is
    // already in network byte order.
    addr.sin_addr.s_addr = ((long)ipaddr);
    retval = connect(fd, (sockaddr*)&addr, sizeof(addr));
    if (retval) {
#ifdef WIN32
        errno = WSAGetLastError();
        if (errno != WSAEINPROGRESS && errno != WSAEWOULDBLOCK) {
            closesocket(fd);
            NetClose();
            return ERR_CONNECT;
        }
#ifndef _CONSOLE
        if (WSAAsyncSelect( fd, g_myWnd->GetSafeHwnd(), g_myWnd->m_nNetActivityMsg, FD_READ|FD_WRITE )) {
            errno = WSAGetLastError();
            if (errno != WSAEINPROGRESS && errno != WSAEWOULDBLOCK) {
                closesocket(fd);
                NetClose();
                return ERR_ASYNCSELECT;
            }
        }
#endif
#else
        // EINPROGRESS is the normal result of a nonblocking connect().
        if (errno != EINPROGRESS) {
            close(fd);
            perror("connect");
            return ERR_CONNECT;
        }
#endif
    } else {
        is_connected = true;
    }
    socket = fd;
    return 0;
}

// Close this transfer's socket, if one is open, and reset `socket` to the
// -1 "no socket" value that init() uses.
//
// Previously `if (socket)` tried to close the -1 sentinel, skipped a valid
// descriptor 0, and left the old number in place, so a second call could
// close an unrelated descriptor the OS had since reused.
//
// On Windows, NetClose() is still called unconditionally. It presumably
// balances the NetOpen() in get_ip_addr() — verify. It now runs after
// closesocket(), so the socket is closed while the network layer is
// still open.
//
void NET_XFER::close_socket() {
#ifdef WIN32
    if (socket >= 0) closesocket(socket);
    NetClose();
#else
    if (socket >= 0) close(socket);
#endif
    socket = -1;
}

// Reset this transfer to a fresh, unconnected state targeting host:p.
// b is the requested block size. It is capped at MAX_BLOCKSIZE, the size
// of the stack buffer do_xfer() reads into.
//
void NET_XFER::init(char* host, int p, int b) {
    // Destination and connection state
    safe_strcpy(hostname, host);
    port = p;
    socket = -1;
    is_connected = false;

    // What to do and what has happened so far
    want_upload = false;
    want_download = false;
    do_file_io = false;
    io_ready = false;
    io_done = false;
    error = 0;
    file = NULL;

    // Buffering and throughput bookkeeping
    if (b > MAX_BLOCKSIZE) {
        blocksize = MAX_BLOCKSIZE;
    } else {
        blocksize = b;
    }
    file_read_buf_len = 0;
    file_read_buf_offset = 0;
    bytes_xferred = 0;
    xfer_speed = 0;
    last_speed_update = dtime();
}

// Accessor: the host name stored by init(); the storage belongs to this object.
char* NET_XFER::get_hostname() {
    return &hostname[0];
}

// Start with zero rate limits, empty rate budgets, zero byte counters, and
// no recorded activity.
//
// last_time is zeroed because do_select() compares against it on every
// call. It was previously left uninitialized. With 0, the first
// do_select() immediately performs its once-per-second rate-limit
// accounting.
//
NET_XFER_SET::NET_XFER_SET() {
    max_bytes_sec_up = 0;
    max_bytes_sec_down = 0;
    bytes_left_up = 0;
    bytes_left_down = 0;
    bytes_up = 0;
    bytes_down = 0;
    up_active = false;
    down_active = false;
    last_time = 0;
}

// Start connecting nxp to its server and, only if that succeeds, add it to
// the set of active transfers. Returns open_server()'s result (0 on success).
//
int NET_XFER_SET::insert(NET_XFER* nxp) {
    int retval = nxp->open_server();
    if (!retval) {
        net_xfers.push_back(nxp);
    }
    return retval;
}

// Close nxp's socket and drop it from the set.
// The socket is closed even if nxp turns out not to be in the set. In that
// case an error is logged and ERR_NOT_FOUND is returned; otherwise 0.
//
int NET_XFER_SET::remove(NET_XFER* nxp) {
    nxp->close_socket();

    for (vector<NET_XFER*>::iterator it = net_xfers.begin(); it != net_xfers.end(); ++it) {
        if (*it != nxp) continue;
        net_xfers.erase(it);
        return 0;
    }

    msg_printf(NULL, MSG_ERROR, "NET_XFER_SET::remove(): not found\n");
    return ERR_NOT_FOUND;
}

// Transfer data to/from active sockets using zero-timeout selects.
// Keeps going until a select fails, a round moves no data, or the
// wall-clock second (time(0)) changes, so it runs for at most about a
// second.
// Returns true if at least one round moved data.
//
bool NET_XFER_SET::poll() {
    const time_t start = time(0);
    double nbytes;
    bool did_io = false;

    while (do_select(nbytes, 0) == 0 && nbytes != 0) {
        did_io = true;
        if (time(0) != start) break;
    }
    return did_io;
}

// Convert a duration in seconds to a timeval suitable for select().
// Negative (and NaN) durations become zero. select() rejects a negative
// timeval with EINVAL, which do_select() would report as ERR_SELECT.
// Fractional seconds are truncated to whole microseconds.
//
static void double_to_timeval(double x, timeval& t) {
    if (!(x > 0)) x = 0;
    double whole = floor(x);
    t.tv_sec = static_cast<long>(whole);
    t.tv_usec = static_cast<long>(1000000*(x - whole));
}

// Wait at most x seconds for network I/O to become possible. If that
// select moved any data, continue with poll() for up to about a second
// more.
// Returns do_select()'s error code if it fails.
// NOTE(review): otherwise the result is poll()'s bool converted to int
// (1 when further I/O happened), or 0 — confirm callers don't treat a
// nonzero value as an error.
//
int NET_XFER_SET::net_sleep(double x) {
    double nbytes;
    int retval = do_select(nbytes, x);
    if (retval) return retval;
    return nbytes ? poll() : 0;
}

// Do a select() with the given timeout (seconds), then do I/O on every
// socket it reports ready, subject to the rate limits:
//  - a pending connect is checked for success or failure;
//  - a transfer doing file I/O moves one block (uploads keep sending until
//    the socket would block or the wall-clock second changes);
//  - any other ready transfer just gets io_ready set.
//
// bytes_transferred receives the number of bytes moved. A newly completed
// connection counts as 1 so callers see it as progress.
// Returns 0, or ERR_SELECT if select() fails.
//
int NET_XFER_SET::do_select(double& bytes_transferred, double timeout) {
    int n, fd;
    unsigned int i;
    NET_XFER *nxp;
    struct timeval tv;

    ScopeMessages scope_messages(log_messages, ClientMessages::DEBUG_NET_XFER);

    // If a second has gone by, do rate-limit accounting.
    // NOTE(review): with max_bytes_sec_* == 0 the budget is never refilled,
    // so that direction is throttled completely — confirm callers set a
    // large value to mean "unlimited".
    //
    time_t t = time(0);
    if (t != last_time) {
        last_time = t;
        if (bytes_left_up < max_bytes_sec_up) {
            bytes_left_up += max_bytes_sec_up;
        }
        if (bytes_left_down < max_bytes_sec_down) {
            bytes_left_down += max_bytes_sec_down;
        }
    }

    bytes_transferred = 0;

    fd_set read_fds, write_fds, error_fds;

    FD_ZERO(&read_fds);
    FD_ZERO(&write_fds);
    FD_ZERO(&error_fds);

    // Do a select on all active (non-throttled) sockets.
    // A pending connect shows up as writability.
    //
    int nsocks_queried = 0;
    for (i=0; i<net_xfers.size(); i++) {
        nxp = net_xfers[i];
        if (!nxp->is_connected) {
            FD_SET(nxp->socket, &write_fds);
            nsocks_queried++;
        } else if (nxp->want_download) {
            if (bytes_left_down > 0) {
                FD_SET(nxp->socket, &read_fds);
                nsocks_queried++;
            } else {
                scope_messages.printf("NET_XFER_SET::do_select(): Throttling download\n");
            }
        } else if (nxp->want_upload) {
            if (bytes_left_up > 0) {
                FD_SET(nxp->socket, &write_fds);
                nsocks_queried++;
            } else {
                scope_messages.printf("NET_XFER_SET::do_select(): Throttling upload\n");
            }
        }
        FD_SET(nxp->socket, &error_fds);
    }
    if (nsocks_queried==0) {
        // Nothing to wait on (all idle or throttled): just sleep out the timeout.
        boinc_sleep(timeout);
        return 0;
    }

    double_to_timeval(timeout, tv);
    n = select(FD_SETSIZE, &read_fds, &write_fds, &error_fds, &tv);
    scope_messages.printf(
        "NET_XFER_SET::do_select(): queried %d, returned %d\n",
        nsocks_queried, n
    );
    if (n == 0) return 0;
    if (n < 0) return ERR_SELECT;

    // Do I/O on each ready descriptor.
    // TODO: use round-robin order
    //
    for (i=0; i<net_xfers.size(); i++) {
        nxp = net_xfers[i];
        fd = nxp->socket;
        if (FD_ISSET(fd, &read_fds) || FD_ISSET(fd, &write_fds)) {
            if (FD_ISSET(fd, &read_fds)) {
                scope_messages.printf("NET_XFER_SET::do_select(): read enabled on socket %d\n", fd);
            }
            if (FD_ISSET(fd, &write_fds)) {
                scope_messages.printf("NET_XFER_SET::do_select(): write enabled on socket %d\n", fd);
            }
            if (!nxp->is_connected) {
                n = get_socket_error(fd);
                if (n) {
                    scope_messages.printf(
                        "NET_XFER_SET::do_select(): socket %d connection to %s failed\n",
                        fd, nxp->get_hostname()
                    );
                    nxp->error = ERR_CONNECT;
                    nxp->io_done = true;
                } else {
                    scope_messages.printf("NET_XFER_SET::do_select(): socket %d is connected\n", fd);
                    nxp->is_connected = true;
                    bytes_transferred += 1;
                }
            } else if (nxp->do_file_io) {
                // Record the direction before transferring. do_xfer() clears
                // want_download at end-of-stream, and reading it afterwards
                // booked the final download read as upload activity.
                bool is_download = nxp->want_download;
                time_t now = time(0);
                do {
                    nxp->do_xfer(n);
                    nxp->update_speed(n);
                    bytes_transferred += n;
                    if (is_download) {
                        down_active = true;
                        bytes_left_down -= n;
                        bytes_down += n;
                    } else {
                        up_active = true;
                        bytes_left_up -= n;
                        bytes_up += n;
                    }
                    // For uploads, keep trying to send until we fill
                    // the buffers or 1 second has passed
                } while(nxp->want_upload && n > 0 && time(0) == now);
            } else {
                nxp->io_ready = true;
            }
        } else if (FD_ISSET(fd, &error_fds)) {
            scope_messages.printf("NET_XFER_SET::do_select(): got error on socket %d\n", fd);
            // nxp is already the transfer that owns fd; no lookup needed.
            nxp->got_error();
        }
    }
    return 0;
}

// Find the first transfer in the set whose socket is fd.
// Returns a null pointer if none matches.
//
NET_XFER* NET_XFER_SET::lookup_fd(int fd) {
    vector<NET_XFER*>::iterator it;
    for (it = net_xfers.begin(); it != net_xfers.end(); ++it) {
        if ((*it)->socket == fd) return *it;
    }
    return NULL;
}

// Transfer up to one block of data on this connection.
//
// Download (want_download): read up to blocksize bytes from the socket
//   and append them to `file`. A 0-byte read ends the download.
// Upload (want_upload): once the previous block has been fully sent, read
//   the next block of `file` into file_read_buf. Then send as much of the
//   unsent part as the socket accepts without blocking.
//
// Completion and failure are reported through io_done / error.
// nbytes_transferred is set to the number of bytes moved on the socket;
// the return value is always 0.
//
int NET_XFER::do_xfer(int& nbytes_transferred) {
    // Leave these as signed ints so recv/send can return errors
    int n, m, nleft;
    bool would_block;
    char buf[MAX_BLOCKSIZE];

    nbytes_transferred = 0;

    ScopeMessages scope_messages(log_messages, ClientMessages::DEBUG_NET_XFER);

    if (want_download) {
        // Capture the "would block" status immediately, before any other
        // call can overwrite errno / the Winsock error.
#ifdef WIN32
        n = recv(socket, buf, blocksize, 0);
        would_block = (n < 0 && WSAGetLastError() == WSAEWOULDBLOCK);
#else
        n = read(socket, buf, blocksize);
        would_block = (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK));
#endif
        scope_messages.printf("NET_XFER::do_xfer(): read %d bytes from socket %d\n", n, socket);
        if (would_block) {
            // select() readiness can be spurious. Nothing was read, so
            // leave the transfer running and retry on a later pass.
            return 0;
        }
        if (n == 0) {
            io_done = true;
            want_download = false;
        } else if (n < 0) {
            io_done = true;
            error = ERR_READ;
        } else {
            nbytes_transferred += n;
            bytes_xferred += n;
            m = (int)fwrite(buf, 1, n, file);
            if (n != m) {
                fprintf(stdout, "Error: incomplete disk write\n");
                io_done = true;
                error = ERR_FWRITE;
            }
        }
    } else if (want_upload) {
        // If we've sent the current contents of
        // the buffer, then read the next block
        if (file_read_buf_len == file_read_buf_offset) {
            m = (int)fread(file_read_buf, 1, blocksize, file);
            if (m == 0) {
                // fread() returns an unsigned count, so errors can't show up
                // as m < 0. Distinguish a read error from end-of-file here;
                // previously an error silently ended the upload as if it
                // were complete.
                if (ferror(file)) {
                    io_done = true;
                    error = ERR_FREAD;
                    return 0;
                }
                want_upload = false;
                io_done = true;
                return 0;
            }
            file_read_buf_len = m;
            file_read_buf_offset = 0;
        }
        nleft = file_read_buf_len - file_read_buf_offset;
        while (nleft) {
            // Only consult the error code when the call actually failed;
            // after a successful send it may hold a stale value.
#ifdef WIN32
            n = send(socket, file_read_buf+file_read_buf_offset, nleft, 0);
            would_block = (n < 0 && WSAGetLastError() == WSAEWOULDBLOCK);
#else
            n = write(socket, file_read_buf+file_read_buf_offset, nleft);
            would_block = (n < 0 && (errno == EAGAIN || errno == EWOULDBLOCK));
#endif
            if (would_block) n = 0;
            scope_messages.printf("NET_XFER::do_xfer(): wrote %d bytes to socket %d%s\n", n, socket,
                                  (would_block?", would have blocked":""));
            if (n < 0) {
                error = ERR_WRITE;
                io_done = true;
                break;
            }

            file_read_buf_offset += n;
            nbytes_transferred += n;
            bytes_xferred += n;

            // A short write means the socket buffer is full; stop for now.
            if (n < nleft || would_block) {
                break;
            }

            nleft -= n;
        }
    }
    return 0;
}

// Update the transfer-rate estimate xfer_speed after nbytes were moved.
// xfer_speed is an exponentially decaying average of bytes per second
// (per dtime() unit), whose weight on old data halves every 3 seconds.
// do_select() calls this after each do_xfer().
//
// Previously the raw byte count was blended in instead of bytes/delta_t.
// The estimate then converged to "bytes per call" rather than a rate.
// Blending rate*(1-x) stays bounded even for tiny delta_t, because
// (1-x) shrinks proportionally to delta_t.
//
void NET_XFER::update_speed(int nbytes) {
    double now = dtime();
    double delta_t = now - last_speed_update;

    // Clock did not advance: there is no interval to average over.
    if (delta_t <= 0) return;

    double x = exp(-delta_t*log(2.0)/3.0);    // weight kept by the old estimate
    double rate = (double)nbytes / delta_t;   // rate over this interval
    xfer_speed = xfer_speed*x + rate*(1-x);
    last_speed_update = now;
}

void NET_XFER::got_error() {
    error = ERR_IO;
    io_done = true;
    log_messages.printf(ClientMessages::DEBUG_NET_XFER,
                        "IO error on socket %d\n", socket);
}

// return true if an upload is currently in progress
// or has been since the last call to this.
// Similar for download.
//
void NET_XFER_SET::check_active(bool& up, bool& down) {
    unsigned int i;
    NET_XFER* nxp;

    up = up_active;
    down = down_active;
    for (i=0; i<net_xfers.size(); i++) {
        nxp = net_xfers[i];
        if (nxp->is_connected && nxp->do_file_io) {
            nxp->want_download?down=true:up=true;
        }
    }
    up_active = false;
    down_active = false;
}