/*
* socket.cpp
*
* (C) 2000-2002 Murat Deligonul
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*
*/
#include <sys/types.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#ifdef HAVE_POLL_H
# include <poll.h>
#endif
#ifdef HAVE_SYS_POLL_H
# include <sys/poll.h>
#endif
#include "autoconf.h"
#include "socket.h"
#include "server.h"
#include "debug.h"
/* Number of currently open pollsockets: incremented at the end of a
 * successful constructor, decremented by close(). */
int pollsocket::num_socks = 0;
/* Capacity of table[]; set by create_table().  The constructor rejects
 * descriptors >= MAX_SOCKETS. */
int pollsocket::MAX_SOCKETS = 0;
/* Lookup table indexed by file descriptor, pointing at the pollsocket
 * registered for that descriptor (0 when unused). */
pollsocket **pollsocket::table = 0;
#ifdef _USE_SSL
/* Shared SSL context: created by init_ssl(), freed by shutdown_ssl(). */
SSL_CTX *pollsocket::ssl_ctx = 0;
/* PRNG seed file chosen by seed_PRNG(); shutdown_ssl() writes the PRNG
 * state back to it when set. */
char *pollsocket::tls_rand_file = 0;
#endif
#ifdef USE_SELECT
/* select() backend.  fds_read/fds_write are the interest sets (what each
 * socket asked for via the constructor / set_events()) ... */
fd_set pollsocket::fds_read = { };
fd_set pollsocket::fds_write = { };
/* ... and res_read/res_write hold the results of the most recent select()
 * issued by poll_all(). */
fd_set pollsocket::res_read = { };
fd_set pollsocket::res_write = { };
/* Highest descriptor registered so far; select() is called with
 * highest_fd + 1.  Nothing in this file ever lowers it. */
int pollsocket::highest_fd = -1;
#warning -----------------------------------------------
#warning Your system does not appear to support poll():
#warning Using select().
#warning -----------------------------------------------
#else
/* poll() backend: array of pollfd entries handed directly to poll().
 * Grown with malloc()/realloc() by the constructor; free slots are marked
 * with fd == -1 by close(). */
struct pollfd * pollsocket::pfds = 0;
/* Logical length of pfds[] (entries passed to poll(), free slots included). */
int pollsocket::num_pfds = 0;
/* static */ struct pollfd * pollsocket::find_empty_slot(void)
{
	/* close() marks a released slot by setting its fd to -1; hand back the
	 * first such slot so it can be reused, or 0 when the array is full. */
	struct pollfd * const stop = pfds + num_pfds;
	for (struct pollfd * slot = pfds; slot != stop; ++slot)
	{
		if (slot->fd == -1)
			return slot;
	}
	return 0;
}
#endif
/* static */ int pollsocket::create_table(int m)
{
	/* Throw away any previous table.  Only the pointer array itself is
	 * released; the pollsockets it referenced are not deleted. */
	delete[] pollsocket::table;
	pollsocket::MAX_SOCKETS = m;
	/* The trailing "()" value-initialises the array, i.e. every slot is 0. */
	pollsocket::table = new pollsocket*[m]();
	return m;
}
/*
 * Register descriptor f with the poll()/select() machinery.
 *
 *   f        - open file descriptor (must be >= 0)
 *   success  - set to 1 on success, 0 on failure.  On failure the object is
 *              left closed (fd == -1) and is not counted in num_socks.
 *   events   - initial poll() interest flags (POLLIN / POLLOUT)
 *   buffMin, buffMax - when buffMin > 0, input and output dynbuffs are
 *              allocated with these limits; otherwise no buffers are created.
 */
pollsocket::pollsocket(int f, bool *success, int events, int buffMin, int buffMax)
{
	assert(f >= 0);
	fd = f;
	ibuff = obuff = 0;
#ifdef _USE_SSL
	// tls_rand_file = NULL;
	ssl = NULL;
#endif
#ifndef USE_SELECT
	/* No pollfd slot yet; close() only touches pfds[idx] when idx > -1. */
	idx = -1;
#endif
	/* The fd-indexed table must have room for this descriptor. */
	if (f >= MAX_SOCKETS)
	{
		*success = 0;
		fd = -1;
		return;
	}
#ifdef USE_SELECT
	/* FD_SET() on a descriptor >= FD_SETSIZE writes outside the fd_set. */
	if (f >= FD_SETSIZE)
	{
		*success = 0;
		fd = -1;
		return;
	}
	/* First socket: start from empty descriptor sets. */
	if (!num_socks)
	{
		FD_ZERO(&fds_read);
		FD_ZERO(&fds_write);
		FD_ZERO(&res_read);
		FD_ZERO(&res_write);
	}
	if (events & POLLIN)
		FD_SET(fd, &fds_read);
	if (events & POLLOUT)
		FD_SET(fd, &fds_write);
	if (fd > highest_fd)
		highest_fd = fd;
#else
	/* Register this file descriptor, reusing a free slot when possible. */
	struct pollfd * pfd = find_empty_slot();
	if (pfd)
	{
		DEBUG("pollsocket::pollsocket() -- Base is: %p diff is %d\n", pfds, (int) (pfd - pfds));
		idx = (int) (pfd - pfds);
	}
	else
	{
		/* No free slot: grow the array by one.  realloc(NULL, n) acts like
		 * malloc(n), covering the very first allocation too.  The array may
		 * move, which is fine because sockets remember their slot by index. */
		int count = pfds ? num_pfds + 1 : 1;
		struct pollfd * grown = (struct pollfd *) realloc(pfds, sizeof(struct pollfd) * count);
		if (!grown)
		{
			/* Out of memory: the old array (if any) is untouched and still
			 * owned by pfds, so every other socket keeps working. */
			*success = 0;
			fd = -1;
			return;
		}
		pfds = grown;
		num_pfds = count;
		idx = num_pfds - 1;
		pfd = &pfds[idx];
	}
	DEBUG("pollsocket::pollsocket() [%p] @ %p idx: %d fd: %d\n", this, pfd, idx, f);
	pfd->fd = f;
	pfd->events = events;
	pfd->revents = 0;
#endif
	/* Only publish the socket once registration can no longer fail. */
	table[f] = this;
	/* Create buffers etc.*/
	if (buffMin > 0)
	{
		ibuff = new dynbuff(buffMin, buffMax);
		obuff = new dynbuff(buffMin, buffMax);
	}
	num_socks++;
	*success = 1;
}
/*
 * Release this socket: free its buffers, shut down SSL (if active), close
 * the descriptor and unregister it from the poll()/select() bookkeeping.
 * Returns 1 if a descriptor was closed, 0 if the socket was already closed
 * (the buffers are freed in both cases).
 */
int pollsocket::close()
{
	DEBUG("pollsocket::close() for %p\n", this);
	delete ibuff;
	delete obuff;
	obuff = ibuff = 0;
	if (fd < 0)
		return 0;
#ifdef _USE_SSL
	if (ssl) {
		DEBUG("Shutting down SSL ...\n");
		SSL_shutdown(ssl);
		SSL_free(ssl);
		ssl = NULL;
	}
#endif
	::close(fd);
#ifdef USE_SELECT
	FD_CLR(fd, &fds_write);
	FD_CLR(fd, &fds_read);
	/* Also drop any pending select() results for this descriptor, so a new
	 * socket that reuses it during the current poll_all() pass is not handed
	 * a stale event (the poll() branch likewise resets revents below). */
	FD_CLR(fd, &res_write);
	FD_CLR(fd, &res_read);
#else
	if (idx > -1)
	{
		/* Shrink the logical length when the last slot is released. */
		if (idx + 1 == num_pfds)
			num_pfds--;
		struct pollfd * pfd = &pfds[idx];
		pfd->fd = -1;
		pfd->events = pfd->revents = 0;
		idx = -1;
	}
#endif
	num_socks--;
	if (fd < MAX_SOCKETS)
		table[fd] = 0;
	fd = -1;
	return 1;
}
pollsocket::~pollsocket()
{
	/* Both backends start by closing the socket. */
	close();
#ifndef USE_SELECT
	/* When the last open socket goes away, release the shared pollfd array. */
	if (num_socks == 0)
	{
		free(pfds);
		pfds = 0;
		num_pfds = 0;
	}
	DEBUG("pollsocket::~pollsocket(): %p num_socks: %d num_pfds: %d\n", this, num_socks, num_pfds);
#endif
}
/*
 * Overwrite the pending result flags (POLLIN / POLLOUT) for this socket.
 * Returns the flags that were stored, or 0 if the socket is closed.
 */
int pollsocket::set_revents(int revents)
{
	/* A closed socket has nothing registered to update. */
	if (fd < 0)
		return 0;
#ifdef USE_SELECT
	if (revents & POLLIN)
		FD_SET(fd, &res_read);
	else
		FD_CLR(fd, &res_read);
	if (revents & POLLOUT)
		FD_SET(fd, &res_write);
	else
		FD_CLR(fd, &res_write);
#else
	pfds[idx].revents = revents;
#endif
	return revents;
}
/*
 * Replace the interest flags (POLLIN / POLLOUT) watched for this socket.
 * Returns the flags that were stored, or 0 if the socket is closed.
 */
int pollsocket::set_events(int events)
{
	/* A closed socket has nothing registered to update. */
	if (fd < 0)
		return 0;
#ifdef USE_SELECT
	if (events & POLLIN)
		FD_SET(fd, &fds_read);
	else
		FD_CLR(fd, &fds_read);
	if (events & POLLOUT)
		FD_SET(fd, &fds_write);
	else
		FD_CLR(fd, &fds_write);
#else
	pfds[idx].events = events;
#endif
	return events;
}
/* static */ int pollsocket::poll_all(int wait)
{
#ifdef USE_SELECT
struct timeval tv;
int secs = (wait / 1000);
tv.tv_sec = secs;
tv.tv_usec = 0;
memcpy(&res_read, &fds_read, sizeof(fd_set));
memcpy(&res_write, &fds_write, sizeof(fd_set));
int r = select(highest_fd + 1, &res_read, &res_write, (fd_set *) 0, &tv);
if (r > 0)
{
pollsocket * p;
update_time_ctr();
for (int i = 0; i < MAX_SOCKETS; ++i)
{
p = table[i];
if (!p)
continue;
if (p->fd > -1)
{
/* Since the event handler expects a valid pollfd structure,
* we must create one */
struct pollfd pfd;
pfd.revents = 0;
pfd.events = 0;
pfd.fd = p->fd;
if (FD_ISSET(p->fd, &res_read))
pfd.revents |= POLLIN;
else if (FD_ISSET(p->fd, &res_write))
pfd.revents |= POLLOUT;
if (pfd.revents)
{
int fd = p->fd;
p->event_handler(&pfd);
FD_CLR(fd, &res_read);
FD_CLR(fd, &res_write);
}
}
}
}
#else
int r = poll(pfds, num_pfds, wait);
if (r > 0)
{
update_time_ctr();
for (int i = 0; i < num_pfds; i++)
if (pfds[i].fd > -1 && pfds[i].revents)
table[pfds[i].fd]->event_handler(&pfds[i]);
}
#endif
return r;
}
#ifdef USE_SELECT
/* select() backend: rebuild the poll()-style interest flags of this socket
 * from its membership in the fds_read / fds_write sets (0 when closed). */
int pollsocket::events() const
{
	if (fd < 0)
		return 0;
	const int want_read = FD_ISSET(fd, &fds_read) ? POLLIN : 0;
	const int want_write = FD_ISSET(fd, &fds_write) ? POLLOUT : 0;
	return want_read | want_write;
}
int pollsocket::revents() const
{
int r = 0;
if (fd < 0)
return 0;
if (FD_ISSET(fd, &res_read))
r |= POLLIN;
if (FD_ISSET(fd, &res_write))
r |= POLLOUT;
return r;
}
/* The pending result flags, masked down to the bits the caller asks about. */
int pollsocket::revents(int flags) const
{
	const int pending = revents();
	return pending & flags;
}
#endif
/*
 * Squeeze free slots (fd == -1) out of the pollfd array and renumber each
 * live socket's idx accordingly.  Returns 1 if the array was rebuilt, 0 if
 * nothing was done (no free slots, select() backend, or out of memory).
 */
/* static */ int pollsocket::compress(void)
{
#ifdef USE_SELECT
	/* FIXME: We don't need to do anything. However, we could loop
	 * through and find a new highest_fd */
	return 0;
#else
	if (num_pfds == num_socks)
		return 0;
	DEBUG("pollsocket::compress(): entering with array size of %d\n", num_pfds);
	/* Size the new array by counting live entries rather than trusting
	 * num_socks, so the copy loop below can never overrun it. */
	int live = 0;
	for (int i = 0; i < num_pfds; ++i)
		if (pfds[i].fd > -1)
			++live;
	struct pollfd * new_pfds = 0;
	if (live > 0)
	{
		new_pfds = (struct pollfd *) malloc(live * sizeof(struct pollfd));
		if (!new_pfds)
			return 0;	/* out of memory: the sparse array is still valid */
		int j = 0;
		for (int i = 0; i < num_pfds; ++i)
		{
			if (pfds[i].fd > -1)
			{
				memcpy(&new_pfds[j], &pfds[i], sizeof(struct pollfd));
				table[new_pfds[j].fd]->idx = j;
				++j;
			}
		}
	}
	free(pfds);
	pfds = new_pfds;
	num_pfds = live;
	DEBUG("pollsocket::compress(): Exiting with new array size of %d\n", num_pfds);
	return 1;
#endif
}
/* Ask both I/O buffers to optimize() themselves.  Either may be absent:
 * the constructor only creates them when buffMin > 0, and close() frees them. */
void pollsocket::optimize_buffers(void)
{
	if (obuff != 0) obuff->optimize();
	if (ibuff != 0) ibuff->optimize();
}
int pollsocket::printf(const char * format, ...)
{
va_list ap;
va_start(ap, format);
printf_raw(format, &ap);
va_end(ap);
return flushO();
}
/* TODO
int pollsocket::printfQ(const char * format, ...)
{
}
*/
/*
 * Format into the shared __mbuffer scratch buffer and queue() the result
 * without flushing.  Output longer than the buffer is truncated.
 * Returns whatever queue() returns.
 */
int pollsocket::printf_raw(const char * format, va_list * ap)
{
	extern char __mbuffer[1024];
	int len;
	__mbuffer[0] = 0;
	len = vsnprintf(__mbuffer, sizeof(__mbuffer), format, *ap);
	/* vsnprintf() returns the length the output *would* have had, which is
	 * >= the buffer size when the text was truncated; pre-C99 C libraries
	 * return -1 instead.  Either way, queue only what is really in the
	 * buffer rather than reading past its end. */
	if (len < 0 || (size_t) len >= sizeof(__mbuffer))
	{
		__mbuffer[sizeof(__mbuffer) - 1] = 0;
		len = (int) strlen(__mbuffer);
	}
	return queue(__mbuffer, len);
}
#ifdef _USE_SSL
/*
 * Create the shared SSL context (once), seed the PRNG and load the
 * certificate and private key from certfile (PEM).
 * Returns 1 when a usable context exists, 0 on failure (ssl_ctx is left NULL).
 */
/* static */ int pollsocket::init_ssl(const char * certfile)
{
	/* Already initialised: nothing to do. */
	if (ssl_ctx != NULL)
		return 1;

	DEBUG("Initializing SSL...\n");
	SSL_load_error_strings();
	OpenSSL_add_ssl_algorithms();
	ssl_ctx = SSL_CTX_new(SSLv23_method());
	if (ssl_ctx == NULL) {
		DEBUG("SSL_CTX_new() failed\n");
		return 0;
	}
	if (seed_PRNG() != 0) {
		DEBUG("Wasn't able to properly seed the PRNG!\n");
		SSL_CTX_free(ssl_ctx);
		ssl_ctx = NULL;
		return 0;
	}
	/* Certificate and key are both read from the same PEM file.  Their
	 * individual load results are not checked; SSL_CTX_check_private_key()
	 * below is what decides whether the pair is usable. */
	SSL_CTX_use_certificate_file(ssl_ctx, certfile, SSL_FILETYPE_PEM);
	SSL_CTX_use_RSAPrivateKey_file(ssl_ctx, certfile, SSL_FILETYPE_PEM);
	if (!SSL_CTX_check_private_key(ssl_ctx)) {
		DEBUG("Error loading private key/certificate, set correct file in options...\n");
		SSL_CTX_free(ssl_ctx);
		ssl_ctx = NULL;
		return 0;
	}
	return 1;
}
/*
 * Free the shared SSL context and, when seed_PRNG() chose a seed file,
 * save the PRNG state to it for the next run.  Always returns 1.
 */
/* static */ int pollsocket::shutdown_ssl(void)
{
	if (ssl_ctx != NULL)
	{
		DEBUG("Freeing SSL context...");
		SSL_CTX_free(ssl_ctx);
		ssl_ctx = NULL;
	}
	if (tls_rand_file != NULL)
		RAND_write_file(tls_rand_file);
	return 1;
}
bool pollsocket::switch_to_ssl()
{
int err;
if (ssl) {
return 1;
}
ssl = SSL_new(ssl_ctx);
if (!ssl) {
DEBUG("SSL_new() failed\n");
return 0;
}
SSL_set_fd(ssl, fd);
err = SSL_connect(ssl);
while (err <= 0) {
if (!BIO_sock_should_retry(err)) {
DEBUG("Error while SSL_connect()\n");
SSL_shutdown(ssl);
SSL_free(ssl);
ssl = NULL;
return 0;
}
usleep(1000);
err = SSL_connect(ssl);
}
if (err==1) {
return 1;
}
DEBUG("Error while SSL_connect()");
SSL_shutdown(ssl);
SSL_free(ssl);
ssl = NULL;
return 0;
}
bool pollsocket::accept_to_ssl()
{
int err;
if (ssl) {
return 1;
}
ssl = SSL_new(ssl_ctx);
if (!ssl) {
DEBUG("SSL_new() failed\n");
return 0;
}
SSL_set_fd(ssl, fd);
err = SSL_accept(ssl);
while (err <= 0) {
if (!BIO_sock_should_retry(err)) {
DEBUG("Error while SSL_accept(): %s\n", ERR_error_string(ERR_get_error(),NULL));
SSL_shutdown(ssl);
SSL_free(ssl);
ssl = NULL;
return 0;
}
usleep(1000);
err = SSL_accept(ssl);
}
return 1;
}
/*
 * Make sure OpenSSL's PRNG is seeded.
 * Returns 0 when the PRNG is (or is presumed to be) seeded, 1 when no seed
 * file name could be obtained, 2 when the PRNG is still not seeded after
 * the fallback (only detectable with OpenSSL >= 0.9.5).
 * Side effect: points tls_rand_file at the seed file when one is used, so
 * shutdown_ssl() can save the PRNG state back to it.
 */
int pollsocket::seed_PRNG(void)
{
	/* Deliberately left uninitialised and fed to RAND_seed() below as a
	 * last-resort "entropy" source.  NOTE(review): this is weak at best. */
	char stackdata[1024];
	static char rand_file[300];
	FILE *fh;
#if OPENSSL_VERSION_NUMBER >= 0x00905100
	if (RAND_status())
		return 0; /* PRNG already good seeded */
#endif
	/* if the device '/dev/urandom' is present, OpenSSL uses it by default.
	 * check if it's present, else we have to make random data ourselves.
	 */
	if ((fh = fopen("/dev/urandom", "r"))) {
		fclose(fh);
		return 0;
	}
	if (RAND_file_name(rand_file, sizeof(rand_file)))
		tls_rand_file = rand_file;
	else
		return 1;
	/* RAND_load_file() returns the number of bytes read.  Old OpenSSL
	 * returns 0 on failure and newer releases return -1, so treat anything
	 * non-positive as "no usable seed file". */
	if (RAND_load_file(rand_file, 1024) <= 0) {
		/* no .rnd file found, create new seed */
		unsigned int c;
		c = time(NULL);
		RAND_seed(&c, sizeof(c));
		c = getpid();
		RAND_seed(&c, sizeof(c));
		RAND_seed(stackdata, sizeof(stackdata));
	}
#if OPENSSL_VERSION_NUMBER >= 0x00905100
	if (!RAND_status())
		return 2; /* PRNG still badly seeded */
#endif
	return 0;
}
#endif
/* syntax highlighted by Code2HTML, v. 0.9.1 */