/*
 * SMTP client test program using sm I/O.
 * Generates "random" protocol data in an attempt to crash the server.
 *
 * Based on proxy.c.
 */

#include "sm/generic.h"
SM_RCSID("@(#)$Id: smtpcr.c,v 1.6 2006/07/16 02:07:40 ca Exp $")
#include "sm/assert.h"
#include "sm/error.h"
#include "sm/str.h"
#include "sm/test.h"
#include "sm/io.h"
#include "sm/ctype.h"

#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>

#include "st.h"
#include "common.h"	/* HACK for debugging, otherwise structs are hidden */

/*
**  sm I/O stream type passed to sm_io_open() in handle_request();
**  the cookie of a file opened with it is used as an st_netfd_t.
*/
extern sm_stream_T SmStThrIO;

/*
**  Source of the message body sent by smtpdata().  main() fills it with
**  MAIL_HEADER followed by printable filler and a CRLF every 80 bytes;
**  smtpdata() writes it from the start as often as needed to send
**  datalen bytes.
*/
#define IOBUFSIZE (5*1024)
static	uchar databuf[IOBUFSIZE];

#ifndef INADDR_NONE
# define INADDR_NONE 0xffffffff
#endif

/* MAXTC and SC_MAXADDRS are not used anywhere in this file. */
#define MAXTC	65536L		/* max. number of total connections */
#define SC_MAXADDRS	256		/* max. number of addresses */
#define REQUEST_TIMEOUT 5	/* default I/O timeout in seconds (-O) */
#define SEC2USEC(s) ((s)*1000000LL)	/* seconds -> microseconds, for st_connect() */

static char *prog;                     /* Program name   */
static struct sockaddr_in rmt_addr;    /* Remote address (-r) */
static unsigned int transactions;	/* transactions per session (-c/-T) */
static unsigned int sessions;		/* sessions left to run; shared by all threads (-s) */
static int debug = 0;			/* debug level (-d) */
static int stoponerror = 0;		/* -S: set sessions to 0 on errors */
static int count = 0;			/* -C: print running total */
static int busy = 0;			/* number of client threads still running */
static int concurrent = 0;		/* number of currently connected sessions */
static int waitdata = 0;		/* -w: seconds to sleep after DATA */
static unsigned int datalen = 0;	/* -l: body length; 0: send a short fixed body */
static unsigned int rcpts = 0;		/* -n: RCPT commands per transaction */
static long unsigned int total = 0;	/* number of completed transactions */
static unsigned int request_timeout = REQUEST_TIMEOUT;	/* -O: I/O timeout (seconds) */
static int showerror = 0;		/* -E: print SMTP dialogue errors */
static int usesize = 0;			/* -Z: add SIZE= to MAIL */

static void read_address(const char *str, struct sockaddr_in *sin);
static void *handle_request(void *arg);
static void print_sys_error(const char *msg);

/* random() % x: a pseudo-random value in [0, x) for x > 0 */
#define RAND(x) (random() % (x))

/*
**  MAIN -- parse options, build the DATA buffer, start the client
**	threads and wait for all of them to finish.
**
**	Parameters:
**		argc -- argument count
**		argv -- argument vector; see the usage message for options
**
**	Returns:
**		0 once no client thread is busy any more; exits with
**		status 1 on invalid arguments or if a State Threads call
**		fails.
*/

int
main(int argc, char *argv[])
{
	extern char    *optarg;
	int             opt, n, threads, v;
	int             raddr;
	long int        tc;
	size_t		j;

	prog = argv[0];
	raddr = 0;
	sessions = transactions = threads = 1;

	/*
	**  Parse arguments.
	**  Counts kept in unsigned globals are parsed into the signed int v
	**  first: assigning atoi() directly to an unsigned variable turns a
	**  negative argument into a huge value that passes the "< 1" check.
	*/

	while ((opt = getopt(argc, argv, "Cc:d:Ehl:m:n:O:r:s:St:T:w:Z"))
		!= -1)
	{
		switch (opt)
		{
		case 'C':
			count++;
			break;
		case 'T':
		case 'c':
			v = atoi(optarg);
			if (v < 1)
			{
				sm_io_fprintf(smioerr,
					"%s: invalid number of connections: %s\n",
					prog, optarg);
				exit(1);
			}
			transactions = (unsigned int) v;
			break;
		case 'd':
			debug = atoi(optarg);
			if (debug < 0)
			{
				sm_io_fprintf(smioerr,
					"%s: invalid number for debug: %s\n",
					prog, optarg);
				exit(1);
			}
			break;
		case 'E':
			showerror = 1;
			break;
		case 'l':
			datalen = (unsigned int) atoi(optarg);
			break;
		case 'm':
			/* accepted by getopt() but not implemented */
			break;
		case 'n':
			v = atoi(optarg);
			if (v < 1)
			{
				sm_io_fprintf(smioerr,
					"%s: invalid number of rcpts: %s\n",
					prog, optarg);
				exit(1);
			}
			rcpts = (unsigned int) v;
			break;
		case 'O':
			v = atoi(optarg);
			if (v <= 0)
			{
				sm_io_fprintf(smioerr,
					"%s: invalid timeout: %s\n",
					prog, optarg);
				exit(1);
			}
			request_timeout = (unsigned int) v;
			break;
		case 'r':
			read_address(optarg, &rmt_addr);
			if (rmt_addr.sin_addr.s_addr == INADDR_ANY)
			{
				sm_io_fprintf(smioerr,
					"%s: invalid remote address: %s\n",
					prog, optarg);
				exit(1);
			}
			raddr = 1;
			break;
		case 's':
			v = atoi(optarg);
			if (v < 1)
			{
				sm_io_fprintf(smioerr,
					"%s: invalid number of sessions: %s\n",
					prog, optarg);
				exit(1);
			}
			sessions = (unsigned int) v;
			break;
		case 'S':
			stoponerror++;
			break;
		case 't':
			threads = atoi(optarg);
			if (threads < 1)
			{
				sm_io_fprintf(smioerr,
					"%s: invalid number of threads: %s\n",
					prog, optarg);
				exit(1);
			}
			break;
		case 'w':
			waitdata = atoi(optarg);
			break;
		case 'Z':
			usesize = true;
			break;
		case 'h':
		case '?':
			sm_io_fprintf(smioerr,
				"Usage: %s [options] -r host:port\n"
				"-C          show counter\n"
				"-d n        set debug level\n"
				"-E          show SMTP dialogue errors\n"
				"-l n        data length\n"
				"-m n        number of messages to send\n"
				"-n n        number of recipients per transaction\n"
				"-O timeout  I/O timeout [%u].\n"
				"-s n        n total sessions\n"
				"-S          stop on errors\n"
				"-t n        concurrent threads\n"
				"-T n        transactions per thread\n"
				"-w n        wait for n seconds after DATA\n"
				"-Z          use SIZE= for MAIL\n"
				, prog, request_timeout);
			exit(1);
		}
	}
	if (!raddr)
	{
		sm_io_fprintf(smioerr, "%s: remote address required\n", prog);
		exit(1);
	}

	/* number of recipients not set? */
	if (rcpts == 0)
		rcpts = 1;

	srandom(st_time());

	/* DATA source: header, then printable filler with CRLF every 80 bytes */
#define MAIL_HEADER "From: me\r\nTo: you\r\nSubject: test\r\n\r\n"
	(void) strlcpy((char *)databuf, MAIL_HEADER, sizeof(databuf));
	for (j = strlen(MAIL_HEADER); j < sizeof(databuf); j++)
		databuf[j] = ' ' + (j % 64);
	for (j = 80; j < sizeof(databuf) - 1; j += 80)
	{
		databuf[j] = '\r';
		databuf[j + 1] = '\n';
	}

	tc = (long) sessions * (long) transactions;
	if (debug)
		sm_io_fprintf(smioerr, "%s: starting client [%d]\n",
			prog, threads);

	/* Initialize the ST library */
	if (st_init() < 0)
	{
		print_sys_error("st_init");
		exit(1);
	}
	for (n = 0; n < threads; n++)
	{
		if (debug)
			sm_io_fprintf(smioerr, "%s: starting client %d/%d\n",
				prog, n, threads);

		/* pass the thread number through a pointer-sized integer */
		if (st_thread_create(handle_request, (void *) (long) n, 0, 0)
		    == NULL)
		{
			print_sys_error("st_thread_create");
			exit(1);
		}
	}

	/* wait for them... */
	st_sleep(1);
	while (busy > 0)
		st_sleep(1);

	sm_io_fprintf(smioerr, "%s: total=%lu (should be %lu)\n",
			prog, total, (unsigned long) tc);
	return 0;
}

/*
**  READ_ADDRESS -- convert "host:port" into an IPv4 socket address.
**
**	Parameters:
**		str -- "host:port"; host may be empty (yields INADDR_ANY),
**			a dotted-decimal address, or a name looked up with
**			gethostbyname().
**		sin -- (out) address to fill in
**
**	Returns:
**		none; prints an error and exits on invalid input.
*/

static void
read_address(const char *str, struct sockaddr_in * sin)
{
	char            host[128], *p;
	struct hostent *hp;
	long            port;

	/* reject instead of silently truncating (which could drop ":port") */
	if (strlcpy(host, str, sizeof(host)) >= sizeof(host))
	{
		sm_io_fprintf(smioerr, "%s: invalid address: %s\n", prog, str);
		exit(1);
	}
	if ((p = strchr(host, ':')) == NULL)
	{
		sm_io_fprintf(smioerr, "%s: invalid address: %s\n", prog, host);
		exit(1);
	}
	*p++ = '\0';

	/*
	**  Keep the port in a long for the range check: narrowing to short
	**  makes valid ports above 32767 negative.
	*/

	port = strtol(p, NULL, 10);
	if (port < 1 || port > 65535)
	{
		sm_io_fprintf(smioerr, "%s: invalid port: %s\n", prog, p);
		exit(1);
	}
	memset(sin, 0, sizeof(struct sockaddr_in));
	sin->sin_family = AF_INET;
	sin->sin_port = htons((unsigned short) port);
	if (host[0] == '\0')
	{
		sin->sin_addr.s_addr = INADDR_ANY;
		return;
	}
	sin->sin_addr.s_addr = inet_addr(host);
	if (sin->sin_addr.s_addr == INADDR_NONE)
	{
		/* not dotted-decimal */
		if ((hp = gethostbyname(host)) == NULL)
		{
			sm_io_fprintf(smioerr,
				"%s: can't resolve address: %s\n", prog, host);
			exit(1);
		}

		/* only copy an address that has exactly the size of sin_addr */
		if (hp->h_addrtype != AF_INET
		    || hp->h_length != (int) sizeof(sin->sin_addr))
		{
			sm_io_fprintf(smioerr,
				"%s: not an IPv4 address: %s\n", prog, host);
			exit(1);
		}
		memcpy(&sin->sin_addr, hp->h_addr, sizeof(sin->sin_addr));
	}
}

/*
**  Result codes of smtpcommand() and smtpdata().
**  Before changing these, check the macros below: SMTP_IO_ERR() and
**  SMTP_FATAL() rely on the numeric order (I/O errors have the largest
**  values, directly preceded by SMTP_SSD).
*/
#define SMTP_OK		0
#define SMTP_AN		1	/* SMTP reply type isn't 2 or 3 */
#define SMTP_SSD	2	/* 421 */
#define SMTP_RD		3	/* read error */
#define SMTP_WR		4	/* write error */
#define SMTP_IO_ERR(r)	((r) >= SMTP_RD)	/* read or write error */
#define SMTP_FATAL(r)	((r) >= SMTP_SSD)	/* 421 or I/O error */

/*
**  SMTPDATA -- send l bytes of message body to the server.
**
**	The bytes are taken from databuf, restarting at its beginning
**	for every chunk, so bodies longer than databuf repeat it.
**
**	Parameters:
**		l -- number of body bytes to send
**		fp -- connection to the server
**		tid -- thread number (for messages)
**		i -- transaction number (for messages)
**
**	Returns:
**		SMTP_OK if everything was written, SMTP_WR otherwise.
*/

static int
smtpdata(int l, sm_file_T *fp, int tid, int i)
{
	int remaining, chunk;

	for (remaining = l; remaining > 0; remaining -= chunk)
	{
		ssize_t written;
		sm_ret_T res;

		chunk = SM_MIN(remaining, (int) sizeof(databuf));
		res = sm_io_write(fp, databuf, chunk, &written);
		if (written != chunk || res != SM_SUCCESS)
		{
			sm_io_fprintf(smioerr, "[%d] data write error i=%d, n=%d, r=%d, ret=%#x\n",
				      tid, i, remaining, (int) written, res);
			return SMTP_WR;
		}
	}
	return SMTP_OK;
}

/*
**  SMTPCOMMAND -- send one SMTP command and read the complete reply.
**
**	Parameters:
**		str -- NUL-terminated command text; when a read error is
**			shown (-E), str[16] is temporarily set to NUL to
**			shorten the printed command.
**		l -- number of bytes of str to send
**		fp -- connection to the server
**		out -- reply buffer; holds the last reply line on return
**		tid -- thread number (for messages)
**		i -- transaction number (for messages)
**
**	Returns:
**		SMTP_OK  -- final reply line starts with '2' or '3'
**		SMTP_SSD -- 421 reply
**		SMTP_AN  -- empty reply line or any other reply code
**		SMTP_WR  -- write or flush error
**		SMTP_RD  -- read error
**
**	Side Effects:
**		sets sessions to 0 on an I/O error if -S was given.
*/

static int
smtpcommand(char *str, int l, sm_file_T * fp, sm_str_P out, int tid, int i)
{
	sm_ret_T        ret;
	ssize_t         b;

	if (debug > 3)
	{
		sm_io_fprintf(smioerr, "[%d] send: ", tid);
		sm_io_write(smioerr, (uchar *) str, l, &b);
		sm_io_flush(smioerr);
	}

	b = 0;	/* don't depend on sm_io_write() setting b on failure */
	ret = sm_io_write(fp, (uchar *) str, l, &b);
	if (b != l)
	{
		sm_io_fprintf(smioerr, "[%d] write error i=%d, n=%d, r=%d, ret=%#x\n",
			      tid, i, l, (int) b, ret);
		if (stoponerror)
			sessions = 0;
		return SMTP_WR;
	}
	ret = sm_io_flush(fp);
	if (sm_is_error(ret))
	{
		sm_io_fprintf(smioerr, "[%d] flush error i=%d, n=%d, ret=%#x\n",
			      tid, i, (int) b, ret);
		if (stoponerror)
			sessions = 0;
		return SMTP_WR;
	}

	/*
	**  Read every line of a (possibly multi-line) reply before looking
	**  at the reply code: returning after the first line of a negative
	**  multi-line reply would leave the remaining lines unread, and the
	**  next command would take those stale lines as its own reply.
	*/

	do
	{
		time_t before, after;

		sm_str_clr(out);
		before = st_time();
		ret = sm_fgetline0(fp, out);
		after = st_time();

		if (debug > 3)
		{
			sm_io_fprintf(smioerr, "[%d] rcvd [len=%d, res=%#x]: ", tid,
				      sm_str_getlen(out), ret);
			sm_io_write(smioerr, sm_str_getdata(out),
				sm_str_getlen(out), &b);
			sm_io_flush(smioerr);
		}
		if (sm_is_error(ret))
		{
			if (showerror)
			{
				uchar r;

				r = '\0'; /* avoid bogus compiler warning */
				if (l > 16)
				{
					r = str[16];
					str[16] = 0;
				}
				sm_io_fprintf(smioerr,
					"[%d] cmd=%s, error=read, i=%d, b=%d, ret=%#x, after-before=%ld\n",
					tid, str, i, (int) b, ret,
					(long) (after - before));
				if (l > 16)
					str[16] = r;
			}
			if (stoponerror)
				sessions = 0;
			return SMTP_RD;
		}
	} while (sm_str_getlen(out) > 4 && sm_str_rd_elem(out, 3) == '-');

	if (sm_str_getlen(out) == 0)
		return SMTP_AN;

	/* must be tested before the generic check below, which yields SMTP_AN */
	if (sm_str_getlen(out) > 2
	    && sm_str_rd_elem(out, 0) == '4'
	    && sm_str_rd_elem(out, 1) == '2'
	    && sm_str_rd_elem(out, 2) == '1')
		return SMTP_SSD;

	if (sm_str_rd_elem(out, 0) != '2' && sm_str_rd_elem(out, 0) != '3')
		return SMTP_AN;

	return SMTP_OK;
}

/*
**  GEN_ADDR -- append a randomly generated (often malformed) address
**	to buf.
**
**	The address is opened with '{', '(' or '[' (1/16 each) or '<'
**	(otherwise) and closed, independently, with '}', ')', ']' or '>'
**	with the same odds.  In between, random characters are appended
**	(arbitrary bytes, '@', '%', 'a', printable characters) up to a
**	randomly shortened length; generation stops early at the first
**	step that produces no character.
**
**	Parameters:
**		buf -- NUL-terminated buffer to append to
**		buflen -- total size of buf
**
**	Returns:
**		0 (always).
*/

static int
gen_addr(char *buf, size_t buflen)
{
	char c;
	size_t i, len, cut;
	bool has_at;

	SM_ASSERT(buf != 0);
	SM_ASSERT(buflen > 0);

	switch (RAND(16))
	{
	  case 1:
		strlcat(buf, "{", buflen);
		break;
	  case 2:
		strlcat(buf, "(", buflen);
		break;
	  case 3:
		strlcat(buf, "[", buflen);
		break;
	  default:
		strlcat(buf, "<", buflen);
		break;
	}

	/*
	**  Randomly shorten the usable part of buf.  The amount to cut off
	**  is computed separately and checked: "buflen - RAND(...)" can
	**  wrap around size_t near the low end of each range, which would
	**  let the loop below run far past the end of buf.
	*/

	if (buflen < 256)
		cut = 0;
	else if (buflen < 1024)
		cut = (size_t) RAND(256);
	else if (buflen < 2 * 1024)
		cut = (size_t) (RAND(1024) + RAND(256));
	else if (buflen < 3 * 1024)
		cut = (size_t) (RAND(2 * 1024) + RAND(256));
	else
		cut = (size_t) (RAND(3 * 1024) + RAND(256));
	len = (cut < buflen) ? buflen - cut : 0;

	/* the loop writes buf[i + 1], which must stay inside buf */
	if (len > buflen - 1)
		len = buflen - 1;

	has_at = false;
	for (i = strlen(buf); i < len; i++)
	{
		c = '\0';
		switch (RAND(8))
		{
		  case 0:
			c = RAND(256);
			break;
		  case 1:
			if (!has_at || RAND(8) < 5)
			{
				c = '@';
				has_at = true;
			}
			break;
		  case 2:
			c = '%';
			break;
		  case 3:
			c = 'a';
			break;
		  case 4:
		  case 5:
			if (!has_at)
			{
				c = '@';
				has_at = true;
			}
			break;
		  default:
			c = RAND(64) + ' ';
			break;
		}
		if (c == '\0')
			break;
		buf[i] = c;
		buf[i + 1] = '\0';
	}

	switch (RAND(16))
	{
	  case 1:
		strlcat(buf, "}", buflen);
		break;
	  case 2:
		strlcat(buf, ")", buflen);
		break;
	  case 3:
		strlcat(buf, "]", buflen);
		break;
	  default:
		strlcat(buf, ">", buflen);
		break;
	}

	return 0;
}

/*
**  GEN_MAIL -- generate a (randomly malformed) MAIL command into buf.
**
**	The verb is "Mail From:" 11 times in 16, otherwise one of five
**	variants (odd case, missing colon, doubled blank, literal "%s").
**	The address comes from gen_addr().  " SIZE=<datalen>" is added if
**	-Z was given and datalen > 0.  The line terminator is "\r\n"
**	29 times in 32, otherwise "\n\r", "\r\n\r" or "\n\r\n".
**	All pieces are appended with strlcat(), so a long address can
**	cause the parameters and/or the terminator to be truncated.
**
**	Parameters:
**		buf -- output buffer (overwritten)
**		buflen -- size of buf
**
**	Returns:
**		length of the generated command.
*/

static int
gen_mail(char *buf, size_t buflen)
{
	SM_ASSERT(buf != 0);
	SM_ASSERT(buflen > 0);
	buf[0] = '\0';
	switch (RAND(16))
	{
	  case 0:
		strlcat(buf, "mAIL fROM:", buflen);
		break;
	  case 1:
		strlcat(buf, "Mail From", buflen);
		break;
	  case 2:
		strlcat(buf, "Mail  From", buflen);
		break;
	  case 3:
		strlcat(buf, "Mail  From:", buflen);
		break;
	  case 4:
		strlcat(buf, "Mail  %s:", buflen);
		break;
	  default:
		strlcat(buf, "Mail From:", buflen);
		break;
	}

	gen_addr(buf, buflen);

	if (usesize && datalen > 0)
	{
		char sbuf[32];

		/* datalen is unsigned: %u, not %d */
		(void) sm_snprintf(sbuf, sizeof(sbuf), " SIZE=%u", datalen);
		strlcat(buf, sbuf, buflen);
	}
	else
	{
		/* generate random args? */
	}
	switch (RAND(32))
	{
	  case 1:
		strlcat(buf, "\n\r", buflen);
		break;
	  case 2:
		strlcat(buf, "\r\n\r", buflen);
		break;
	  case 3:
		strlcat(buf, "\n\r\n", buflen);
		break;
	  default:
		strlcat(buf, "\r\n", buflen);
		break;
	}
	return strlen(buf);
}

/*
**  GEN_RCPT -- generate a (randomly malformed) RCPT command into buf.
**
**	The verb is "Rcpt to:" 3 times in 8, otherwise one of five
**	variants (odd case, missing colon, misspelling, doubled blank,
**	literal "%s").  The address comes from gen_addr().  The line
**	terminator is "\r\n" only one time in 4; otherwise it is
**	"\n\r", "\r\n\r" or "\n\r\n".
**
**	Parameters:
**		buf -- output buffer (overwritten)
**		buflen -- size of buf
**
**	Returns:
**		length of the generated command.
*/

static int
gen_rcpt(char *buf, size_t buflen)
{
	/* indexed by RAND(8); the last three entries are the common case */
	static const char *const verbs[8] =
	{
		"rCPT tO:",
		"Rcpt To",
		"Rcpt Too",
		"Rcpt  To:",
		"Rcpt  %s:",
		"Rcpt to:",
		"Rcpt to:",
		"Rcpt to:"
	};

	/* indexed by RAND(4) */
	static const char *const eols[4] =
	{
		"\n\r",
		"\r\n\r",
		"\n\r\n",
		"\r\n"
	};

	SM_ASSERT(buf != 0);
	SM_ASSERT(buflen > 0);

	buf[0] = '\0';
	strlcat(buf, verbs[RAND(8)], buflen);
	gen_addr(buf, buflen);

	/* generate random args? */

	strlcat(buf, eols[RAND(4)], buflen);
	return strlen(buf);
}

/*
**  HANDLE_REQUEST -- client thread: run SMTP sessions against rmt_addr
**	until the shared session counter reaches 0.
**
**	Each session connects, reads the greeting, sends EHLO, runs up to
**	`transactions` transactions (generated MAIL, `rcpts` generated
**	RCPTs, DATA, body, RSET between transactions) and sends QUIT.
**	The thread ends early on a socket/open/setup/connect failure, an
**	unreadable or non-2xx greeting, an EHLO failure, or a 421 reply.
**
**	Parameters:
**		arg -- thread number, passed as a pointer-sized integer
**
**	Returns:
**		NULL.
*/

static void *
handle_request(void *arg)
{
	st_netfd_t      rmt_nfd;
	int             sock, n, tid, r, rok;
	unsigned int    i, j;
	ssize_t         b;
	sm_file_T      *fp;
	sm_str_P        out;
	sm_ret_T        ret;
	char            buf[IOBUFSIZE];

	++busy;
	i = 0;

	/*
	**  fp is examined at "fail" even if the session loop never runs
	**  (sessions may already be 0 when this thread starts).
	*/

	fp = NULL;

	/* go through long to avoid a truncating pointer-to-int cast */
	tid = (int) (long) arg;
	if (debug)
		sm_io_fprintf(smioerr, "client[%d]: transactions=%u\n",
			tid, transactions);

	out = sm_str_new(NULL, IOBUFSIZE, IOBUFSIZE);
	if (NULL == out)
	{
		sm_snprintf(buf, sizeof(buf), "[%d] str new failed i=%u",
			tid, i);
		print_sys_error(buf);
		if (stoponerror)
			sessions = 0;
		goto done;
	}

	/*
	**  Only run a certain number of sessions;
	**  this is a global variable but we can manipulate it without locking.
	*/

	while (sessions > 0)
	{
		--sessions;

		/* Connect to remote host */
		if ((sock = socket(PF_INET, SOCK_STREAM, 0)) < 0)
		{
			sm_snprintf(buf, sizeof(buf), "[%d] socket i=%u",
				tid, i);
			print_sys_error(buf);
			if (stoponerror)
				sessions = 0;
			goto done;
		}
		ret = sm_io_open(&SmStThrIO, (void *) &sock, SM_IO_RDWR, &fp,
				NULL);
		if (ret != SM_SUCCESS)
		{
			sm_snprintf(buf, sizeof(buf),
				"[%d] sm_io_open()=%d, i=%u", tid, ret, i);
			print_sys_error(buf);
			close(sock);
			if (stoponerror)
				sessions = 0;
			goto done;
		}
		sm_io_clrblocking(fp);

		r = request_timeout;
		ret = sm_io_setinfo(fp, SM_IO_WHAT_TIMEOUT, &r);
		if (ret != SM_SUCCESS)
		{
			sm_snprintf(buf, sizeof(buf),
				"[%d] set timeout()=%d, i=%u", tid, ret, i);
			print_sys_error(buf);
			sm_io_close(fp, SM_IO_CF_NONE);
			if (stoponerror)
				sessions = 0;
			goto done;
		}
		ret = sm_io_setinfo(fp, SM_IO_DOUBLE, NULL);
		if (ret != SM_SUCCESS)
		{
			sm_snprintf(buf, sizeof(buf),
				"[%d] set double()=%d, i=%u", tid, ret, i);
			print_sys_error(buf);
			sm_io_close(fp, SM_IO_CF_NONE);
			if (stoponerror)
				sessions = 0;
			goto done;
		}

		/* fetch the ST descriptor after all sm_io_setinfo() calls */
		rmt_nfd = (st_netfd_t) f_cookie(*fp);
		if (st_connect(rmt_nfd, (struct sockaddr *) & rmt_addr,
			       sizeof(rmt_addr), SEC2USEC(request_timeout)) < 0)
		{
			sm_snprintf(buf, sizeof(buf), "[%d] connect i=%u",
				tid, i);
			print_sys_error(buf);
			sm_io_close(fp, SM_IO_CF_NONE);
			if (stoponerror)
				sessions = 0;
			goto done;
		}
		++concurrent;
		if (debug > 2)
			sm_io_fprintf(smioerr,
				"client[%d]: connected, i=%u, ta=%u, conc=%d\n",
				tid, i, transactions, concurrent);

		/* read the (possibly multi-line) greeting */
		do
		{
			sm_str_clr(out);
			ret = sm_fgetline0(fp, out);

			if (debug > 3)
			{
				sm_io_fprintf(smioerr,
					"[%d] greet [len=%d, res=%#x]: ", tid,
					sm_str_getlen(out), ret);
				sm_io_write(smioerr, sm_str_getdata(out),
					sm_str_getlen(out), &b);
				sm_io_flush(smioerr);
			}
			if (sm_is_error(ret))
			{
				sm_io_fprintf(smioerr,
					"[%d] read greet i=%u, ret=%#x\n",
					tid, i, ret);
				if (stoponerror)
					sessions = 0;
				goto fail;
			}
		} while (!sm_is_error(ret) && sm_str_getlen(out) > 4 &&
			 sm_str_rd_elem(out, 3) == '-');
		if (!sm_is_error(ret)
		    && sm_str_getlen(out) > 2
		    && sm_str_rd_elem(out, 0) == '4'
		    && sm_str_rd_elem(out, 1) == '2'
		    && sm_str_rd_elem(out, 2) == '1')
			goto fail;
		if (sm_is_error(ret) || sm_str_getlen(out) <= 0
		    || sm_str_rd_elem(out, 0) != '2')
			goto fail;

		n = strlcpy(buf, "EHLO me.local\r\n", sizeof(buf));
		r = smtpcommand(buf, n, fp, out, tid, 0);
		if (SMTP_SSD == r)
			goto fail;
		if (r != SMTP_OK)
		{
			if (showerror)
				sm_io_fprintf(smioerr,
					"GREETING=error, se=%u\n",
					sessions);
			goto fail;
		}

		for (i = 0; i < transactions; i++)
		{
			n = gen_mail(buf, sizeof(buf));
			r = smtpcommand(buf, n, fp, out, tid, i);
			if (SMTP_SSD == r)
				goto fail;
			if (r != SMTP_OK)
			{
				if (showerror)
					sm_io_fprintf(smioerr,
						"MAIL=error, se=%u, ta=%u\n",
						sessions, i);
				break;
			}

			/* continue on rejected RCPTs, stop on I/O errors */
			rok = 0;
			for (j = 0; j < rcpts; j++)
			{
				n = gen_rcpt(buf, sizeof(buf));
				r = smtpcommand(buf, n, fp, out, tid, i);
				if (SMTP_SSD == r)
					goto fail;
				if (r != SMTP_OK)
				{
					if (showerror)
						sm_io_fprintf(smioerr,
							"RCPT=error, se=%u, ta=%u, rcpt=%u\n",
							sessions, i, j);
					if (SMTP_IO_ERR(r))
						break;
				}
				else
					++rok;
			}
			if (SMTP_FATAL(r) || rok == 0)
				break;

			n = strlcpy(buf, "DATA\r\n", sizeof(buf));
			r = smtpcommand(buf, n, fp, out, tid, i);
			if (SMTP_SSD == r)
				goto fail;
			if (r != SMTP_OK)
			{
				if (showerror)
					sm_io_fprintf(smioerr,
						"DATA=error, se=%u, ta=%u, r=%d\n",
						sessions, i, r);
				break;
			}
			if (waitdata > 0)
				st_sleep(waitdata);

			if (datalen == 0)
			{
				n = strlcpy(buf, "From: me\r\nTo: you\r\nSubject: test\r\n\r\nbody\r\n.\r\n", sizeof(buf));
			}
			else
			{
				r = smtpdata(datalen, fp, tid, i);
				if (r != SMTP_OK)
					break;
				n = strlcpy(buf, "\r\n.\r\n", sizeof(buf));
			}
			r = smtpcommand(buf, n, fp, out, tid, i);
			if (SMTP_SSD == r)
				goto fail;
			if (r != SMTP_OK)
			{
				if (showerror)
					sm_io_fprintf(smioerr,
						"DOT=error, se=%u, ta=%u\n",
						sessions, i);
				break;
			}

			if (i < transactions - 1)
			{
				n = strlcpy(buf, "RSET\r\n", sizeof(buf));
				r = smtpcommand(buf, n, fp, out, tid, i);
			}
			if (r != SMTP_OK)
				break;

			/*
			**  Count the transaction and go on with the next one.
			**  (An unconditional "sessions = 0; break;" used to
			**  follow here, stopping all threads after the first
			**  successful transaction.)
			*/

			++total;
			if (count)
				sm_io_fprintf(smioerr, "%lu\r", total);
		}
		n = strlcpy(buf, "QUIT\r\n", sizeof(buf));
		(void) smtpcommand(buf, n, fp, out, tid, i);

		sm_io_close(fp, SM_IO_CF_NONE);
		fp = NULL;
		--concurrent;
	}

  fail:
	if (fp != NULL)
	{
		sm_io_close(fp, SM_IO_CF_NONE);
		fp = NULL;
		--concurrent;
	}

  done:
	--busy;
	return NULL;
}

/*
**  PRINT_SYS_ERROR -- report a failed system call on smioerr as
**	"prog: msg: <strerror(errno)>".
**
**	Parameters:
**		msg -- description of the failed operation
**
**	Returns:
**		none.
*/

static void
print_sys_error(const char *msg)
{
	const char *reason;

	reason = strerror(errno);
	sm_io_fprintf(smioerr, "%s: %s: %s\n", prog, msg, reason);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */