/*
 * socket.cpp
 *
 * (C) 2000-2002 Murat Deligonul
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 */

#include <sys/types.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>

#ifdef HAVE_POLL_H
#   include <poll.h>
#endif

#ifdef HAVE_SYS_POLL_H
#   include <sys/poll.h>
#endif


#include "autoconf.h"
#include "socket.h"
#include "server.h"
#include "debug.h"

/* Number of pollsocket objects currently registered (incremented by a
 * successful constructor, decremented by close()). */
int             pollsocket::num_socks = 0;
/* Capacity of `table`, set by create_table(); the constructor rejects
 * descriptors >= MAX_SOCKETS. */
int             pollsocket::MAX_SOCKETS = 0;
/* Lookup table indexed by file descriptor -> owning pollsocket (0 if none). */
pollsocket      **pollsocket::table = 0;

#ifdef _USE_SSL
/* Shared SSL context, created by init_ssl() and freed by shutdown_ssl(). */
SSL_CTX         *pollsocket::ssl_ctx = 0;
/* PRNG seed file chosen by seed_PRNG(); shutdown_ssl() writes it back. */
char            *pollsocket::tls_rand_file = 0;
#endif

#ifdef USE_SELECT
/* select() back end: fds_* hold each socket's registered interest,
 * res_* receive the readiness results of the last poll_all() call. */
fd_set pollsocket::fds_read  = { };
fd_set pollsocket::fds_write = { };
fd_set pollsocket::res_read = { };
fd_set pollsocket::res_write = { };
/* Highest descriptor registered by the constructor; poll_all() passes
 * highest_fd + 1 to select(). */
int pollsocket::highest_fd = -1;

#warning -----------------------------------------------
#warning Your system does not appear to support poll():
#warning Using select().
#warning -----------------------------------------------

#else
/* poll() back end: num_pfds pollfd slots shared by all sockets; a slot
 * whose fd is -1 has been released by close() and may be reused. */
struct pollfd * pollsocket::pfds = 0;
int             pollsocket::num_pfds = 0;

/* static */ struct pollfd * pollsocket::find_empty_slot(void)
{
    /* Walk the first num_pfds entries and hand back the first one that
     * close() has released (fd == -1); 0 means the array is full. */
    int pos = 0;
    while (pos < num_pfds && pfds[pos].fd != -1)
        ++pos;

    return (pos < num_pfds) ? &pfds[pos] : 0;
}

#endif

/* static */ int pollsocket::create_table(int m)
{
    delete[] pollsocket::table;
    pollsocket::MAX_SOCKETS = m;
    pollsocket::table = new pollsocket*[m];
    for (int i = 0; i < MAX_SOCKETS; ++i)
        pollsocket::table[i] = 0;
    return m;
}

/*
 * Wrap descriptor f and register it with the lookup table and the active
 * readiness back end (select() fd_sets or the shared pollfd array).
 *
 *   f        - descriptor to manage; must be >= 0 (asserted).
 *   success  - receives 1 on success, 0 on failure. On failure fd is -1
 *              and nothing has been registered.
 *   events   - initial POLLIN/POLLOUT interest mask.
 *   buffMin, buffMax - passed to the input/output dynbuff objects; no
 *              buffers are created when buffMin <= 0.
 */
pollsocket::pollsocket(int f, bool *success, int events, int buffMin, int buffMax)
{
    assert(f >= 0);

    fd = f;
    ibuff = obuff = 0;

#ifdef _USE_SSL
    ssl = NULL;
#endif
#ifndef USE_SELECT
    /* Keep idx defined on the early-failure paths as well. */
    idx = -1;
#endif

    /* The lookup table only has MAX_SOCKETS entries. */
    if (f >= MAX_SOCKETS)
    {
        *success = 0;
        fd = -1;
        return;
    }

#ifdef USE_SELECT
    /* FD_SET() on a descriptor >= FD_SETSIZE is undefined behaviour. */
    if (f >= FD_SETSIZE)
    {
        *success = 0;
        fd = -1;
        return;
    }
    if (!num_socks)
    {
        FD_ZERO(&fds_read);
        FD_ZERO(&fds_write);
        FD_ZERO(&res_read);
        FD_ZERO(&res_write);
    }
    if (events & POLLIN)
        FD_SET(fd, &fds_read);
    if (events & POLLOUT)
        FD_SET(fd, &fds_write);
    if (fd > highest_fd)
        highest_fd = fd;
#else
    /* Reuse a slot released by close(); otherwise grow the array by one. */
    struct pollfd * pfd = find_empty_slot();
    if (pfd)
    {
        DEBUG("pollsocket::pollsocket() -- Base is: %p diff is %d\n", pfds, static_cast<int>(pfd - pfds));
        idx = static_cast<int>(pfd - pfds);
    }
    else
    {
        /* realloc(NULL, n) behaves like malloc(n), so this also covers the
         * very first socket. Grow through a temporary so the existing
         * array survives an allocation failure. */
        struct pollfd * grown = static_cast<struct pollfd *>(
            realloc(pfds, sizeof(struct pollfd) * (num_pfds + 1)));
        if (!grown)
        {
            *success = 0;
            fd = -1;
            return;
        }
        pfds = grown;
        idx = num_pfds++;
        pfd = &pfds[idx];
    }
    DEBUG("pollsocket::pollsocket() [%p] @ %p idx: %d fd: %d\n", this, pfd, idx, f);
    pfd->fd = f;
    pfd->events = events;
    pfd->revents = 0;
#endif
    /* Only publish the socket once back-end registration succeeded. */
    table[f] = this;

    /* Create buffers etc.*/
    if (buffMin > 0)
    {
        ibuff = new dynbuff(buffMin, buffMax);
        obuff = new dynbuff(buffMin, buffMax);
    }
    num_socks++;
    *success = 1;
}

/*
 * Free the I/O buffers and, if the socket is open, shut down any SSL
 * session, close the descriptor and unregister it from the lookup table
 * and the readiness back end.
 * Returns 1 if a descriptor was closed, 0 if it was already closed
 * (the buffers are released in either case).
 */
int pollsocket::close()
{
    DEBUG("pollsocket::close() for %p\n", this);

    delete ibuff;
    delete obuff;
    obuff = ibuff = 0;

    if (fd < 0)
        return 0;

#ifdef _USE_SSL
    if (ssl)
    {
        DEBUG("Shutting down SSL ...\n");
        SSL_shutdown(ssl);
        SSL_free(ssl);
        ssl = NULL;
    }
#endif

    ::close(fd);
#ifdef USE_SELECT
    FD_CLR(fd, &fds_write);
    FD_CLR(fd, &fds_read);
    /* Also drop pending results, so a descriptor number reused during the
     * same poll_all() pass does not inherit this socket's readiness. */
    FD_CLR(fd, &res_write);
    FD_CLR(fd, &res_read);
#else
    /* idx is -1 when no slot is held; never index pfds[-1]. */
    if (idx > -1)
    {
        struct pollfd * pfd = &pfds[idx];
        pfd->fd = -1;
        pfd->events = pfd->revents = 0;

        /* Shrink the active range past every released slot at its end. */
        while (num_pfds > 0 && pfds[num_pfds - 1].fd == -1)
            num_pfds--;
        idx = -1;
    }
#endif
    num_socks--;

    if (fd < MAX_SOCKETS)
        table[fd] = 0;
    fd = -1;
    return 1;
}

/*
 * Close the socket (a no-op if it is already closed). In the poll()
 * build, once no sockets remain registered (num_socks == 0) the shared
 * pollfd array is released as well.
 */
pollsocket::~pollsocket()
{
    close();
#ifndef USE_SELECT
    if (num_socks == 0)
    {
        free(pfds);
        num_pfds = 0;
        pfds = 0;
    }
    DEBUG("pollsocket::~pollsocket(): %p num_socks: %d num_pfds: %d\n", this, num_socks, num_pfds);
#endif
}

/*
 * Overwrite the pending result mask (POLLIN/POLLOUT) for this socket.
 * Returns the mask that was stored, or 0 if the socket is closed.
 */
int pollsocket::set_revents(int revents)
{
    if (fd < 0)
        return 0;

#ifdef USE_SELECT
    FD_CLR(fd, &res_read);
    if (revents & POLLIN)
        FD_SET(fd, &res_read);

    FD_CLR(fd, &res_write);
    if (revents & POLLOUT)
        FD_SET(fd, &res_write);
#else
    pfds[idx].revents = revents;
#endif
    return revents;
}

/*
 * Replace the interest mask (POLLIN/POLLOUT) that poll_all() waits on
 * for this socket. Returns the new mask, or 0 if the socket is closed.
 */
int pollsocket::set_events(int events)
{
    if (fd < 0)
        return 0;

#ifdef USE_SELECT
    FD_CLR(fd, &fds_read);
    if (events & POLLIN)
        FD_SET(fd, &fds_read);

    FD_CLR(fd, &fds_write);
    if (events & POLLOUT)
        FD_SET(fd, &fds_write);
#else
    pfds[idx].events = events;
#endif
    return events;
}

/* static */ int pollsocket::poll_all(int wait)
{
#ifdef USE_SELECT
    struct timeval tv;
    int secs = (wait / 1000);
    tv.tv_sec = secs;
    tv.tv_usec = 0;
    memcpy(&res_read, &fds_read, sizeof(fd_set));
    memcpy(&res_write, &fds_write, sizeof(fd_set));

    int r = select(highest_fd + 1, &res_read, &res_write, (fd_set *) 0, &tv);

    if (r > 0)
    {
        pollsocket * p;
        update_time_ctr();
        for (int i = 0; i < MAX_SOCKETS; ++i)
        {
            p = table[i];
            if (!p)
                continue;

            if (p->fd > -1)
            {
                /* Since the event handler expects a valid pollfd structure,
                 * we must create one */
                struct pollfd pfd;
                pfd.revents = 0;
                pfd.events = 0;
                pfd.fd = p->fd;
                if (FD_ISSET(p->fd, &res_read))
                    pfd.revents |= POLLIN;
                else if (FD_ISSET(p->fd, &res_write))
                    pfd.revents |= POLLOUT;
                if (pfd.revents)
                {
                    int fd = p->fd;
                    p->event_handler(&pfd);
                    FD_CLR(fd, &res_read);
                    FD_CLR(fd, &res_write);
                }
            }
        }
    }

#else
    int r = poll(pfds, num_pfds, wait);
    if (r > 0)
    {
        update_time_ctr();
        for (int i = 0; i < num_pfds; i++)
            if (pfds[i].fd > -1 && pfds[i].revents)
                table[pfds[i].fd]->event_handler(&pfds[i]);
    }
#endif
    return r;
}

#ifdef USE_SELECT
int pollsocket::events() const
{
	int r = 0;
	if (fd < 0)
		return 0;
	if (FD_ISSET(fd, &fds_read))
		r |= POLLIN;
	if (FD_ISSET(fd, &fds_write))
		r |= POLLOUT;
	return r;
}

int pollsocket::revents() const
{
    int r = 0;
    if (fd < 0)
        return 0;
    if (FD_ISSET(fd, &res_read))
        r |= POLLIN;
    if (FD_ISSET(fd, &res_write))
        r |= POLLOUT;
    return r;
}

/* select() build: the subset of `flags` present in the current result mask. */
int pollsocket::revents(int flags) const
{
    const int current = revents();
    return flags & current;
}
#endif

/*
 * Reclaim space held by released back-end slots.
 *
 * select() build: lowers highest_fd to the largest descriptor still in
 * the lookup table (so select() scans fewer bits); always returns 0.
 * poll() build: packs the live pollfd entries into a new array sized
 * exactly for them and updates each owner's idx. Returns 1 when the
 * array was rebuilt, 0 when nothing needed doing or allocation failed
 * (the old array is kept and stays valid in that case).
 */
/* static */ int pollsocket::compress(void)
{
#ifdef USE_SELECT
    int top = -1;
    for (int i = MAX_SOCKETS - 1; i >= 0; --i)
    {
        if (table[i] && table[i]->fd > -1)
        {
            top = i;
            break;
        }
    }
    highest_fd = top;
    return 0;
#else
    if (num_pfds == num_socks)
        return 0;
    DEBUG("pollsocket::compress(): entering with array size of %d\n", num_pfds);

    /* Size the new array from the entries actually in use. */
    int live = 0;
    for (int i = 0; i < num_pfds; ++i)
        if (pfds[i].fd > -1)
            ++live;

    struct pollfd * new_pfds = 0;
    if (live > 0)
    {
        new_pfds = static_cast<struct pollfd *>(malloc(live * sizeof(struct pollfd)));
        if (!new_pfds)
            return 0;   /* best effort: the sparse array is still usable */
    }

    int j = 0;
    for (int i = 0; i < num_pfds; ++i)
    {
        if (pfds[i].fd > -1)
        {
            new_pfds[j] = pfds[i];
            table[new_pfds[j].fd]->idx = j;
            ++j;
        }
    }
    free(pfds);
    pfds = new_pfds;
    num_pfds = j;
    DEBUG("pollsocket::compress(): Exiting with new array size of %d\n", num_pfds);
    return 1;
#endif
}

/* Ask whichever I/O buffers exist (output first, then input) to
 * optimize their storage. */
void pollsocket::optimize_buffers(void)
{
    if (obuff != 0) obuff->optimize();
    if (ibuff != 0) ibuff->optimize();
}

/*
 * printf-style helper: format the arguments into the output queue via
 * printf_raw(), then return the result of flushO().
 */
int pollsocket::printf(const char * format, ...)
{
    va_list args;

    va_start(args, format);
    printf_raw(format, &args);
    va_end(args);

    return flushO();
}

/* TODO
int pollsocket::printfQ(const char * format, ...)
{
}
*/

/*
 * Format `format`/`*ap` into the shared 1024-byte __mbuffer and pass the
 * result to queue(). Output longer than the buffer is truncated.
 * Returns queue()'s result, or 0 if formatting failed.
 */
int pollsocket::printf_raw(const char * format, va_list * ap)
{
    extern char __mbuffer[1024];
    __mbuffer[0] = 0;
    int len = vsnprintf(__mbuffer, sizeof(__mbuffer), format, *ap);

    /* NOTE(review): 0 on an encoding error is a best guess at queue()'s
     * "nothing queued" convention -- confirm against queue(). */
    if (len < 0)
        return 0;

    /* vsnprintf() reports the length the full output WOULD have had;
     * only sizeof - 1 characters were actually stored. Clamp so queue()
     * never reads past the end of __mbuffer. */
    if (len >= static_cast<int>(sizeof(__mbuffer)))
        len = static_cast<int>(sizeof(__mbuffer)) - 1;

    return queue(__mbuffer, len);
}


#ifdef _USE_SSL

/* static */ int pollsocket::init_ssl(const char * certfile)
{
    if (ssl_ctx == NULL) {
	    DEBUG("Initializing SSL...\n");
      	SSL_load_error_strings();
  	    OpenSSL_add_ssl_algorithms();
      	ssl_ctx=SSL_CTX_new(SSLv23_method());

      	if (!ssl_ctx)  {
       		DEBUG("SSL_CTX_new() failed\n");
        	return 0;
       	}

      	if (seed_PRNG()) {  		
    		DEBUG("Wasn't able to properly seed the PRNG!\n");	
            if (ssl_ctx != NULL)
                SSL_CTX_free(ssl_ctx);
            ssl_ctx=NULL;
    		return 0;
	    }

        SSL_CTX_use_certificate_file(ssl_ctx, certfile,SSL_FILETYPE_PEM);
        SSL_CTX_use_RSAPrivateKey_file(ssl_ctx,certfile,SSL_FILETYPE_PEM);
        if (!SSL_CTX_check_private_key(ssl_ctx)) {
                DEBUG("Error loading private key/certificate, set correct file in options...\n");
                if (ssl_ctx != NULL)
                    SSL_CTX_free(ssl_ctx);
                ssl_ctx=NULL;
                return 0;
        }
    }
    return 1;
}

/*
 * Release the shared SSL context, if any, then save the PRNG state to
 * the seed file recorded by seed_PRNG(), if one was chosen.
 * Always returns 1.
 */
/* static */ int pollsocket::shutdown_ssl(void)
{
    if (ssl_ctx != NULL)
    {
        DEBUG("Freeing SSL context...");
        SSL_CTX_free(ssl_ctx);
        ssl_ctx = NULL;
    }

    if (tls_rand_file != NULL)
        RAND_write_file(tls_rand_file);

    return 1;
}

bool pollsocket::switch_to_ssl()
{
    int err;
    if (ssl) {
		return 1;
    }

    ssl = SSL_new(ssl_ctx);
    if (!ssl) {
		DEBUG("SSL_new() failed\n");
		return 0;
    }
    SSL_set_fd(ssl, fd);
    err = SSL_connect(ssl);

    while (err <= 0) {
        if (!BIO_sock_should_retry(err)) {
    	    DEBUG("Error while SSL_connect()\n");
            SSL_shutdown(ssl);
            SSL_free(ssl);
            ssl = NULL;
            return 0;
    	}
        usleep(1000);
        err = SSL_connect(ssl);
    }

    if (err==1) {
         return 1;
    }
    DEBUG("Error while SSL_connect()");
    SSL_shutdown(ssl);
    SSL_free(ssl);
    ssl = NULL;
    return 0;
}

bool pollsocket::accept_to_ssl()
{
    int err;
    if (ssl) {
		return 1;
    }

    ssl = SSL_new(ssl_ctx);
    if (!ssl) {
		DEBUG("SSL_new() failed\n");
		return 0;
    }
    SSL_set_fd(ssl, fd);
    err = SSL_accept(ssl);

    while (err <= 0) {
        if (!BIO_sock_should_retry(err)) {
	        DEBUG("Error while SSL_accept(): %s\n", ERR_error_string(ERR_get_error(),NULL));
            SSL_shutdown(ssl);
            SSL_free(ssl);
            ssl = NULL;
            return 0;
    	}
        usleep(1000);
        err = SSL_accept(ssl);
    }
    return 1;
}

/*
 * Make sure OpenSSL's PRNG is seeded before the SSL context is used.
 *
 * Returns 0 when the PRNG reports itself seeded, or when /dev/urandom
 * exists (assumed to be used by OpenSSL automatically); 1 when no seed
 * file name could be determined; 2 when the PRNG is still poorly seeded
 * after loading/creating seed data (only checked on OpenSSL >= 0.9.5).
 * Side effect: records the seed-file path in tls_rand_file so that
 * shutdown_ssl() can write the state back.
 */
int pollsocket::seed_PRNG(void)
{
    static char rand_file[300];
    FILE *fh;

#if OPENSSL_VERSION_NUMBER >= 0x00905100
    if (RAND_status())
        return 0;     /* PRNG already good seeded */
#endif
    /* If /dev/urandom is present, OpenSSL is expected to use it by
     * default; otherwise we have to supply seed data ourselves. */
    if ((fh = fopen("/dev/urandom", "r"))) {
        fclose(fh);
        return 0;
    }
    if (RAND_file_name(rand_file, sizeof(rand_file)))
        tls_rand_file = rand_file;
    else
        return 1;
    if (!RAND_load_file(rand_file, 1024)) {
        /* No seed file yet: mix in time, pid and a microsecond timestamp.
         * (This replaces seeding from an uninitialized stack array, which
         * read indeterminate memory.) */
        unsigned int c;
        struct timeval now;

        c = time(NULL);
        RAND_seed(&c, sizeof(c));
        c = getpid();
        RAND_seed(&c, sizeof(c));

        memset(&now, 0, sizeof(now));
        gettimeofday(&now, NULL);
        RAND_seed(&now, sizeof(now));
    }
#if OPENSSL_VERSION_NUMBER >= 0x00905100
    if (!RAND_status())
        return 2;   /* PRNG still badly seeded */
#endif
    return 0;
}

#endif


/* syntax highlighted by Code2HTML, v. 0.9.1 */