/*
 * $Id: srv2.c,v 1.3 2006/02/16 19:19:40 ca Exp $
 */

#include "sm/generic.h"
#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <fcntl.h>
#include <signal.h>
#include <pwd.h>
#include "st.h"
#include "error.h"

#if !HAVE_SNPRINTF
# define snprintf sm_snprintf
# include "sm/string.h"
#endif /* !HAVE_SNPRINTF */

/* Default server port */
#define SERV_PORT_DEFAULT 8000

/* Messages the writer sends per connection; the last one is "QUIT <n>" */
#define REPS	10

/* Size of the per-thread read/write buffer, in bytes */
#define IOBUFSIZE (16*1024)

/* Socket listen queue size */
#define LISTENQ_SIZE_DEFAULT 256

/* Convert seconds to microseconds (ST I/O timeouts are in microseconds) */
#define SEC2USEC(s) ((s)*1000000LL)

/* Timeout for each st_read()/st_write() on a client: 30 seconds */
#define REQUEST_TIMEOUT SEC2USEC(30)

/* Description of the server's listening endpoint. */
struct socket_info
{
	st_netfd_t	 nfd;	/* ST descriptor of the listening socket */
	char		*addr;	/* address to bind to (host name or dotted quad) */
	int		 port;	/* port the socket is bound to */
};

/* The one listening socket this server uses. */
struct socket_info srv_socket;

/* Settings that parse_arguments() may override. */
static int serialize_accept = 0;			/* -S */
static int listenq_size = LISTENQ_SIZE_DEFAULT;		/* -q */
static int debug = 0;					/* -d */

/* Descriptor error messages are reported on. */
static int errfd = STDERR_FILENO;

/*
 * Result slots of the reader (rdsrv) and writer (wrsrv); each returns
 * the address of its slot as its result.
 */
static int retrd = -1;
static int retwr = -1;

/* Fallback for systems whose headers do not define INADDR_NONE. */
#ifndef INADDR_NONE
# define INADDR_NONE 0xffffffff
#endif



/******************************************************************/

/*
 * usage -- print a summary of the command-line options to stderr and
 * terminate the program with exit status 1.
 *
 * progname: name shown in the synopsis (normally argv[0]).
 * Does not return.
 */
static void usage(const char *progname)
{
	fprintf(stderr, "Usage: %s [<options>]\n\n"
	  "Possible options:\n\n"
	  "\t-b <host>:<port>        Bind to specified address.\n"
	  "\t-d <level>              Set debug level (> 3 dumps received"
	  " data).\n"
	  "\t-q <backlog>            Set max length of pending connections"
	  " queue.\n"
	  "\t-S                      Serialize all accept() calls.\n"
	  "\t-h                      Print this message.\n",
	  progname);
	exit(1);
}


/******************************************************************/

static void parse_arguments(int argc, char *argv[])
{
	extern char *optarg;
	int opt;
	char *c;

	while ((opt = getopt(argc, argv, "b:d:D:p:l:t:u:q:aiShw:")) != -1)
	{
		switch (opt)
		{
		  case 'b':
			if ((c = strdup(optarg)) == NULL)
				err_sys_quit(errfd, "ERROR: strdup");
			srv_socket.addr = c;
			break;
		  case 'd':
			debug = atoi(optarg);
			break;
		  case 'q':
			listenq_size = atoi(optarg);
			if (listenq_size < 1)
				err_quit(errfd, "ERROR: invalid listen queue size: %s", optarg);
			break;
		  case 'S':
			serialize_accept = 1;
			break;
		  case 'h':
		  case '?':
			usage(argv[0]);
		}
	}
}

/******************************************************************/

static void
create_listeners(void)
{
	int n, sock;
	char *c;
	struct sockaddr_in serv_addr;
	struct hostent *hp;
	short port;

	port = 0;
	if (srv_socket.addr != NULL &&
	    (c = strchr(srv_socket.addr, ':')) != NULL)
	{
		*c++ = '\0';
		port = (short) atoi(c);
	}
	if (srv_socket.addr == NULL || srv_socket.addr[0] == '\0')
		srv_socket.addr = "0.0.0.0";
	if (port == 0)
		port = SERV_PORT_DEFAULT;

	/* Create server socket */
	if ((sock = socket(PF_INET, SOCK_STREAM, 0)) < 0)
		err_sys_quit(errfd, "ERROR: can't create socket: socket");
	n = 1;
	if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char *)&n, sizeof(n))
	    < 0)
		err_sys_quit(errfd,
			"ERROR: can't set SO_REUSEADDR: setsockopt");
	memset(&serv_addr, 0, sizeof(serv_addr));
	serv_addr.sin_family = AF_INET;
	serv_addr.sin_port = htons(port);
	serv_addr.sin_addr.s_addr = inet_addr(srv_socket.addr);
	if (serv_addr.sin_addr.s_addr == INADDR_NONE)
	{
		/* not dotted-decimal */
		if ((hp = gethostbyname(srv_socket.addr)) == NULL)
			err_quit(errfd, "ERROR: can't resolve address: %s",
				srv_socket.addr);
		memcpy(&serv_addr.sin_addr, hp->h_addr, hp->h_length);
	}
	srv_socket.port = port;

	/* Do bind and listen */
	if (bind(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0)
		err_sys_quit(errfd, "ERROR: can't bind to address %s, port %d",
			srv_socket.addr, port);
	if (listen(sock, listenq_size) < 0)
		err_sys_quit(errfd, "ERROR: listen");

	/* Create file descriptor object from OS socket */
	if ((srv_socket.nfd = st_netfd_open_socket(sock)) == NULL)
		err_sys_quit(errfd, "ERROR: st_netfd_open_socket");
	if (serialize_accept && st_netfd_serialize_accept(srv_socket.nfd) < 0)
		err_sys_quit(errfd, "ERROR: st_netfd_serialize_accept");
}


/******************************************************************/

static void *
wrsrv(void *arg)
{
	int n, r, l;
	st_netfd_t rmt_nfd;
	char buf[IOBUFSIZE];

	rmt_nfd = (st_netfd_t) arg;
	for (n = 0; n < REPS; n++)
	{
		if (n < REPS - 1)
			snprintf(buf, sizeof(buf), "SRV %d\n", n);
		else
			snprintf(buf, sizeof(buf), "QUIT %d\n", n);
		l = strlen(buf);
		r = st_write(rmt_nfd, buf, l, REQUEST_TIMEOUT);
		if (r != l)
		{
			fprintf(stderr,
				"error write s='%s' n=%d, l=%d, r=%d, errno=%d\n",
				buf, n, l, r, errno);
			retwr = -1;
			return (void *) &retwr;
		}
		st_usleep(1);
	}
	retwr = 1;
	return (void *) &retwr;
}

static void *
rdsrv(void *arg)
{
	int n;
	st_netfd_t rmt_nfd;
	char buf[IOBUFSIZE];
	st_utime_t t1, t2;

	rmt_nfd = (st_netfd_t) arg;
	while (1)
	{
		t1 = st_utime();
		n = st_read(rmt_nfd, buf, IOBUFSIZE, REQUEST_TIMEOUT);
		t2 = st_utime();
		if (n == 0)
		{
			retrd = 0;
			return (void *) &retrd;
		}
		if (n < 0)
		{
			fprintf(stderr,
				"error read n=%d, errno=%d\n", n, errno);
			fprintf(stderr,
				"error read t1=%llu, t2=%llu, t2-t1=%llu\n",
				t1, t2, t2-t1);
			goto fail;
		}
		if (debug > 3)
		{
			fprintf(stderr, "rcvd: \"");
			write(STDERR_FILENO, buf, n);
			fprintf(stderr, "\"\n");
		}
		while (n > 0)
		{
			if (buf[--n] == 'Q')
			{
				retrd = 1;
				return (void *) &retrd;
			}
		}
		st_usleep(1);
	}

fail:
	retrd = -1;
	return (void *) &retrd;
}

/*
 * handle_conn -- accept and serve client connections one at a time.
 *
 * For each accepted connection:
 *   - a joinable ST thread runs rdsrv() while this thread runs wrsrv();
 *   - the reader is then joined and the connection is closed.
 * The loop ends, and NULL is returned, when any of these happens:
 *   - st_accept() fails;
 *   - the reader thread cannot be created (the program exits);
 *   - the reader cannot be joined;
 *   - the reader reports -1.
 * arg is unused.
 */
static void *
handle_conn(void *arg)
{
	int fromlen, *rv;
	void *retval;
	st_netfd_t srv_nfd, cli_nfd;
	st_thread_t thr1;
	struct sockaddr_in from;

	(void) arg;
	srv_nfd = srv_socket.nfd;

	while (1)
	{
		/* st_accept() overwrites fromlen, so reset it for every call */
		fromlen = sizeof(from);
		cli_nfd = st_accept(srv_nfd, (struct sockaddr *)&from,
				&fromlen, -1);
		if (NULL == cli_nfd)
		{
			err_sys_report(errfd,
				"ERROR: can't accept connection: st_accept");
			break;
		}
		/* Save peer address, so we can retrieve it later */
		st_netfd_setspecific(cli_nfd, &from.sin_addr, NULL);

		thr1 = st_thread_create(rdsrv, (void *) cli_nfd, 1, 0);
		if (NULL == thr1)
		{
			err_sys_report(errfd, "st_thread_create");
			exit(1);
		}
		wrsrv((void *)cli_nfd);

		/*
		 * Join through a real void * (not a cast int **) and never
		 * dereference the result unless the join succeeded.
		 */
		retval = NULL;
		if (st_thread_join(thr1, &retval) < 0 || retval == NULL)
		{
			err_sys_report(errfd, "ERROR: st_thread_join");
			st_netfd_close(cli_nfd);
			break;
		}
		rv = (int *) retval;
		st_netfd_close(cli_nfd);
		if (*rv == -1)
			break;
	}
	return NULL;
}

int
main(int argc, char *argv[])
{
	if (getuid() == 0 || geteuid() == 0)
	{
		err_report(errfd, "WARNING: running as super-user!");
		exit(1);
	}

	/* Parse command-line options */
	parse_arguments(argc, argv);

	/* Initialize the ST library */
	if (st_init() < 0)
		err_sys_quit(errfd, "ERROR: initialization failed: st_init");

	/* Create listening sockets */
	create_listeners();

	/* Turn time caching on */
/*
	st_timecache_set(1);
*/

	handle_conn((void *)1);

	return 0;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */