/*
 * Copyright (c) 2003-2005 Sendmail, Inc. and its suppliers.
 *	All rights reserved.
 *
 * By using this file, you agree to the terms and conditions set
 * forth in the LICENSE file which can be found at the top level of
 * the sendmail distribution.
 */

#include "sm/generic.h"
SM_RCSID("@(#)$Id: t-tls-0.c,v 1.15 2006/10/05 04:27:35 ca Exp $")

#include "sm/assert.h"
#include "sm/error.h"
#include "sm/memops.h"
#include "sm/test.h"
#include "sm/io.h"
#include "sm/ctype.h"
#include "sm/time.h"
#include "sm/net.h"
#include "sm/unixsock.h"
#include "sm/tls.h"
#include "sm/tlsbio.h"
#include "timing.h"

#include <stdio.h>
#include "t-net-common.c"
#include "st-net.h"

#if MTA_USE_TLS

/*
**  Stream types defined elsewhere in the library:
**  SmStThrIO is opened by client() on a plain socket descriptor (&sock),
**  SmStThrNetIO is opened by server() on an accepted st_netfd_t (&cli_nfd).
*/
extern sm_stream_T SmStThrIO, SmStThrNetIO;

/* TLS library context: set up in main() by sm_tls_init_library(), then
**  passed to sm_tls_init() by client() and server(). */
static tlsl_ctx_P tlsl_ctx;

/*
**  Certificate material handed to sm_tls_init() via tls_cnf_T.
**  NOTE(review): these are relative names, presumably resolved against
**  the test's working directory -- confirm against the test harness.
*/
#define TESTCERT	"smtestcert.pem"	/* own certificate */
#define TESTKEY		"smtestkey.pem"		/* own private key */
#define TESTCACERT	"CAcert.pem"		/* CA certificate file */
#define TESTCERTP	"certs"			/* CA certificate directory */

/* Size of the per-connection I/O buffer; rd/wr must stay below this. */
#define MAXBUFSZ	8192

/*
**  CLIENT -- connect to localhost:port, start TLS, and write wr characters
**
**	Parameters:
**		port -- port
**		wr -- number of chars to send per iteration
**			(must be >= 0 and < MAXBUFSZ, checked by SM_ASSERT)
**		timeout -- timeout; converted with SEC2USEC() for
**			st_connect(), i.e. given in seconds
**		both -- if >0: after each write, read the peer's reply
**			and increment every received byte before the next
**			write; the value is also passed as SM_IO_DOUBLE
**			setting of the TLS stream
**		iter -- number of iterations (writing data)
**
**	Returns:
**		none; failures are recorded via SM_TEST().
*/

static void
client(int port, int wr, int timeout, int both, int iter)
{
	sm_ret_T ret;
	ssize_t written;
	size_t u;
	int sock, r, addrlen;
	st_netfd_t rmt_nfd;
	SSL_CTX *ctx;
	SSL *con;
	sm_file_T *fp, *fptls;
	sm_sockaddr_T rmt_addr;
	uchar buf[MAXBUFSZ];
	tls_cnf_T tls_cnf;

	tls_cnf.tlscnf_certfile = TESTCERT;
	tls_cnf.tlscnf_keyfile = TESTKEY;
	tls_cnf.tlscnf_dsa_certfile = NULL;
	tls_cnf.tlscnf_dsa_keyfile = NULL;
	tls_cnf.tlscnf_cacertpath = TESTCERTP;
	tls_cnf.tlscnf_cacertfile = TESTCACERT;
	tls_cnf.tlscnf_dhparam = NULL;
	sock = INVALID_SOCKET;
	fp = fptls = NULL;
	ctx = NULL;
	ret = sm_tls_init(tlsl_ctx, &ctx, TLS_I_CLT, false, &tls_cnf);
	SM_TEST(ret == SM_SUCCESS);
	if (ret != SM_SUCCESS)
		return;
	if (Verbose > 1)
		fprintf(stderr, "clt: connect\n");

	sock = socket(AF_INET, SOCK_STREAM, 0);
	SM_TEST(is_valid_socket(sock));
	if (!is_valid_socket(sock))
		return;
	ret = sm_io_open(&SmStThrIO, (void *) &sock, SM_IO_RDWR, &fp,
			SM_IO_WHAT_END);
	SM_TEST(ret == SM_SUCCESS);
	if (ret != SM_SUCCESS)
	{
		close(sock);
		return;
	}

	/* fp is dereferenced right below; never do that with NULL */
	SM_TEST(fp != NULL);
	if (fp == NULL)
	{
		close(sock);
		return;
	}
	sm_io_clrblocking(fp);
	rmt_nfd = (st_netfd_t) f_cookie(*(fp));
	r = timeout;
	ret = sm_io_setinfo(fp, SM_IO_WHAT_TIMEOUT, &r);
	SM_TEST(ret == SM_SUCCESS);
	if (ret != SM_SUCCESS)
		goto fail;

	ret = sm_io_setinfo(fp, SM_IO_DOUBLE, NULL);
	SM_TEST(ret == SM_SUCCESS);
	if (ret != SM_SUCCESS)
		goto fail;

	/* zero everything so no uninitialized padding is passed to connect */
	sm_memzero(&rmt_addr, sizeof(rmt_addr));
	rmt_addr.sin.sin_addr.s_addr = htonl(0x7f000001);	/* 127.0.0.1 */
	rmt_addr.sin.sin_family = AF_INET;
	rmt_addr.sin.sin_port = htons(port);
	addrlen = sizeof(rmt_addr.sin);
	if (st_connect(rmt_nfd, (sockaddr_P) &(rmt_addr),
			 addrlen, SEC2USEC(timeout)) < 0)
		goto fail;

	/* fill the buffer with a repeating pattern of printable characters */
	for (u = 0; u < sizeof(buf); u++)
		buf[u] = (u % 64) + ' ';

	con = SSL_new(ctx);
	SM_TEST(con != NULL);
	if (con == NULL)
		goto fail;
	SSL_set_connect_state(con);
	ret = tls_open(fp, con, NULL, &fptls);
	SM_TEST(ret == SM_SUCCESS);

	/*
	**  Without a TLS stream nothing below can work; bail out instead of
	**  handing a (likely NULL) fptls to do_tls_operation().
	**  NOTE(review): con is not freed here because it is unclear whether
	**  tls_open() takes ownership on failure -- confirm.
	*/

	if (ret != SM_SUCCESS)
		goto fail;
	if (both > 0)
	{
		ret = sm_io_setinfo(fptls, SM_IO_DOUBLE, &both);
		SM_TEST(sm_is_success(ret));
	}
	ret = do_tls_operation(fptls,
			   SSL_connect, NULL, NULL, NULL, 0, &written);
	SM_TEST(ret >= 0);
	if (Verbose > 0)
		fprintf(stderr, "clt: tls_operation=%x\n", ret);
	if (sm_is_err(ret))
		goto fail;

	while (iter-- > 0)
	{
		/* a negative wr would turn into a huge size_t */
		SM_ASSERT(wr >= 0 && wr < (int) sizeof(buf));
		ret = sm_io_write(fptls, buf, wr, &written);
		if (Verbose > 0)
			fprintf(stderr, "clt: ret=%x, written=%d\n",
				ret, (int) written);
		SM_TEST(sm_is_success(ret));
		if (!sm_is_success(ret))
		{
			if (Verbose > 0)
				fprintf(stderr, "clt: error %x\n", ret);
			goto fail;
		}
		ret = sm_io_flush(fptls);
		SM_TEST(ret >= 0);

		/* same "> 0" convention as the SM_IO_DOUBLE setup above */
		if (both > 0)
		{
			ret = sm_io_read(fptls, buf, wr, &written);
			if (Verbose > 0)
			{
				fprintf(stderr,
					"clt: sm_io_read=%x, rrd=%d\n",
					ret, (int) written);
				for (r = 0; r < written; r++)
					putchar(buf[r]);
				putchar('\n');
			}
			if (ret == SM_IO_EOF)
			{
				break;
			}
			SM_TEST(!sm_is_err(ret));
			if (sm_is_err(ret))
			{
				if (Verbose > 0)
					fprintf(stderr,
						"clt: sm_io_read=%x, rrd=%d\n",
						ret, (int) written);
				goto fail;
			}

			/* change the data so each round trip is distinct */
			for (r = 0; r < written; r++)
				buf[r]++;
		}
	}
	ret = sm_io_flush(fptls);
	SM_TEST(ret >= 0);
	if (fptls != NULL)
		sm_io_close(fptls, SM_IO_CF_NONE);
	return;

  fail:
	SM_TEST(0);
	if (fptls != NULL)
		sm_io_close(fptls, SM_IO_CF_NONE);
/*
	if (fp != NULL)
		sm_io_close(fp, SM_IO_CF_NONE);
*/
	if (sock >= 0)
		close(sock);
}

/*
**  SERVER -- accept TLS connections on localhost:port and receive
**	rd characters per iteration
**
**	Parameters:
**		port -- port
**		rd -- number of chars to receive per iteration
**			(must be >= 0 and < MAXBUFSZ, checked by SM_ASSERT)
**		rep -- number of connections to accept; each one runs
**			the full read loop of iter iterations
**		backlog -- size of listen queue
**		timeout -- I/O timeout set via SM_IO_WHAT_TIMEOUT
**		both -- if >0: echo every received block back (each byte
**			incremented by 8); the value is also passed as
**			SM_IO_DOUBLE setting of the TLS stream
**		iter -- number of iterations (reading data) per connection
**
**	Returns:
**		none; failures are recorded via SM_TEST().
*/

static void
server(int port, int rd, int rep, int backlog, int timeout, int both, int iter)
{
	int sock, fromlen, on, r, i;
	sm_ret_T ret;
	ssize_t rrd;
	SSL_CTX *ctx;
	SSL *con;
	st_netfd_t srv_nfd, cli_nfd;
	sm_file_T *fp, *fptls;
	struct sockaddr_in serv_addr;
	struct sockaddr_in from;
	uchar buf[MAXBUFSZ];
	tls_cnf_T tls_cnf;

	tls_cnf.tlscnf_certfile = TESTCERT;
	tls_cnf.tlscnf_keyfile = TESTKEY;
	tls_cnf.tlscnf_dsa_certfile = NULL;
	tls_cnf.tlscnf_dsa_keyfile = NULL;
	tls_cnf.tlscnf_cacertpath = TESTCERTP;
	tls_cnf.tlscnf_cacertfile = TESTCACERT;
	tls_cnf.tlscnf_dhparam = NULL;
	fp = fptls = NULL;
	con = NULL;
	ctx = NULL;

	/* the fail: path inspects both of these; never leave them unset */
	srv_nfd = NULL;
	cli_nfd = NULL;
	ret = sm_tls_init(tlsl_ctx, &ctx, TLS_I_SRV, true, &tls_cnf);
	SM_TEST(ret == SM_SUCCESS);
	if (ret != SM_SUCCESS)
		return;

	sock = socket(PF_INET, SOCK_STREAM, 0);
	SM_TEST(is_valid_socket(sock));
	if (!is_valid_socket(sock))
		return;
	on = 1;
	r = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char *) &on,
			sizeof(on));
	SM_TEST(r >= 0);
	if (r < 0)
		goto fail;

	sm_memzero(&serv_addr, sizeof(serv_addr));
	serv_addr.sin_family = AF_INET;
	serv_addr.sin_port = htons(port);
	serv_addr.sin_addr.s_addr = htonl(0x7f000001);	/* 127.0.0.1 */

	/* Do bind and listen */
	r = bind(sock, (struct sockaddr *) &serv_addr, sizeof(serv_addr));
	SM_TEST(r >= 0);
	if (r < 0)
		goto fail;
	r = listen(sock, backlog);
	SM_TEST(r >= 0);
	if (r < 0)
		goto fail;
	srv_nfd = st_netfd_open_socket(sock);
	SM_TEST(srv_nfd != NULL);
	if (srv_nfd == NULL)
		goto fail;

	/*
	**  NOTE(review): when rep > 1, the streams of earlier connections
	**  are not closed before the next accept; only the last one is
	**  cleaned up below.
	*/

	while (rep-- > 0)
	{
		if (Verbose > 1)
			fprintf(stderr, "srv: accept\n");

		/* value-result argument: must hold the size of "from" */
		fromlen = (int) sizeof(from);
		cli_nfd = st_accept(srv_nfd, (struct sockaddr *) &from,
				&fromlen, -1);
		SM_TEST(cli_nfd != NULL);
		if (cli_nfd == NULL)
			goto fail;

		/* Save peer address, so we can retrieve it later */
		st_netfd_setspecific(cli_nfd, &from.sin_addr, NULL);

		ret = sm_io_open(&SmStThrNetIO, (void *) &cli_nfd, SM_IO_RDWR,
				&fp, SM_IO_WHAT_END);
		SM_TEST(ret == SM_SUCCESS);
		if (ret != SM_SUCCESS)
		{
			/* forget it, so it is not closed again at the end */
			st_netfd_close(cli_nfd);
			cli_nfd = NULL;
			continue;
		}

		/* switch to non-blocking */
		sm_io_clrblocking(fp);
		r = timeout;
		ret = sm_io_setinfo(fp, SM_IO_WHAT_TIMEOUT, &r);
		SM_TEST(ret == SM_SUCCESS);
		if (ret != SM_SUCCESS)
		{
			sm_io_close(fp, SM_IO_CF_NONE);
			continue;
		}
		ret = sm_io_setinfo(fp, SM_IO_DOUBLE, NULL);
		SM_TEST(ret == SM_SUCCESS);
		if (ret != SM_SUCCESS)
		{
			sm_io_close(fp, SM_IO_CF_NONE);
			continue;
		}

		con = SSL_new(ctx);
		SM_TEST(con != NULL);
		if (con == NULL)
			goto fail;
		SSL_set_accept_state(con);
		ret = tls_open(fp, con, NULL, &fptls);
		SM_TEST(ret == SM_SUCCESS);

		/* without a TLS stream the handshake below cannot work */
		if (ret != SM_SUCCESS)
			goto fail;
		if (both > 0)
		{
			ret = sm_io_setinfo(fptls, SM_IO_DOUBLE, &both);
			SM_TEST(sm_is_success(ret));
		}

		ret = do_tls_operation(fptls,
			   SSL_accept, NULL, NULL, NULL, 0, &rrd);
		SM_TEST(ret >= 0);
		if (Verbose > 0)
			fprintf(stderr, "server: tls_operation=%x\n", ret);
		if (sm_is_err(ret))
			goto fail;

		/* use a per-connection counter: iter applies to every rep */
		ret = 0;
		for (i = iter; i > 0; i--)
		{
			/* a negative rd would turn into a huge size_t */
			SM_ASSERT(rd >= 0 && rd < (int) sizeof(buf));
			ret = sm_io_read(fptls, buf, rd, &rrd);
			if (Verbose > 0)
			{
				fprintf(stderr,
					"srv: sm_io_read=%x, rrd=%d\n",
					ret, (int) rrd);
				for (r = 0; r < rrd; r++)
					putchar(buf[r]);
				putchar('\n');
			}

			/* change the data so the echo differs from the input */
			for (r = 0; r < rrd; r++)
				buf[r] += 8;
			if (ret == SM_IO_EOF)
			{
				break;
			}
			SM_TEST(!sm_is_err(ret));
			if (sm_is_err(ret))
			{
				if (Verbose > 0)
					fprintf(stderr,
						"srv: sm_io_read=%x, rrd=%d\n",
						ret, (int) rrd);
				goto fail;
			}
			if (Verbose > 0)
				fprintf(stderr, "srv: sm_io_read=%x\n", ret);
			if (both > 0)
			{
				ret = sm_io_write(fptls, buf, rd, &rrd);
				if (Verbose > 0)
				{
					fprintf(stderr, "srv: ret=%x, written=%d\n",
						ret, (int) rrd);
					for (r = 0; r < rrd; r++)
						putchar(buf[r]);
					putchar('\n');
				}
				SM_TEST(sm_is_success(ret));
				if (!sm_is_success(ret))
				{
					if (Verbose > 0)
						fprintf(stderr, "srv: error %x\n", ret);
					goto fail;
				}
				ret = sm_io_flush(fptls);
				SM_TEST(ret >= 0);
			}
		}
	}

	/*
	**  NOTE(review): if closing fptls also closes the underlying stream
	**  and with it cli_nfd, the explicit st_netfd_close() below is
	**  redundant -- confirm the ownership rules of tls_open() and
	**  SmStThrNetIO.
	*/

	if (fptls != NULL)
		sm_io_close(fptls, SM_IO_CF_NONE);
	if (cli_nfd != NULL)
		st_netfd_close(cli_nfd);

	/* st_netfd_close() also closes the descriptor it wraps */
	if (srv_nfd != NULL)
		st_netfd_close(srv_nfd);
	else if (sock >= 0)
		close(sock);
	return;

  fail:
	SM_TEST(0);
	if (cli_nfd != NULL)
		st_netfd_close(cli_nfd);
	if (srv_nfd != NULL)
		st_netfd_close(srv_nfd);
	else if (sock >= 0)
		close(sock);
	return;
}
#endif /* MTA_USE_TLS */

int
main(int argc, char *argv[])
{
	int opt;
	int port = SM_DEFPORT;
	int rd = 0, wr = 0;
	int rep = 1, backlog = 20, timeout = 5;
	int both = 0, iter = 1;
	bool clt = true, any = false;
#if MTA_USE_TLS
	sm_ret_T ret;
#endif

	/* keep getopt() quiet; unknown options end up in usage() */
	opterr = 0;
	Verbose = 0;

	/*
	**  Any option at all marks the run as "active"; without options
	**  only the sm_test_begin()/sm_test_end() bookkeeping happens.
	*/

	while ((opt = getopt(argc, argv, "b:c:d:i:l:r:s:t:p:BR:STV")) != -1)
	{
		any = true;
		switch (opt)
		{
		  case 'c':	/* act as client, write this many chars */
			clt = true;
			wr = atoi(optarg);
			break;
		  case 's':	/* act as server, read this many chars */
			clt = false;
			rd = atoi(optarg);
			break;
		  case 'p':
			port = atoi(optarg);
			break;
		  case 'i':
			iter = atoi(optarg);
			break;
		  case 'r':
			rep = atoi(optarg);
			break;
		  case 'l':
			backlog = atoi(optarg);
			break;
		  case 't':
			timeout = atoi(optarg);
			break;
		  case 'B':
			++both;
			break;
		  case 'R':
			both = atoi(optarg);
			break;
		  case 'T':
			Timing = true;
			break;
		  case 'V':
			Verbose++;
			break;
		  case 'S':	/* report structure sizes */
			fprintf(stderr, "sizeof sm_file_T: %d\n",
				(int) sizeof(sm_file_T));
			fprintf(stderr, "sizeof sm_stream_T: %d\n",
				(int) sizeof(sm_stream_T));
			fprintf(stderr, "sizeof smbuf_T: %d\n",
				(int) sizeof(smbuf_T));
			break;
		  default:
			usage(argv[0]);
			return 1;
		}
	}

	sm_test_begin(argc, argv, "test tls 0");
#if MTA_USE_TLS
	/* st_init() is attempted only when at least one option was given */
	if (any && st_init() >= 0)
	{
		ret = sm_tls_init_library(&tlsl_ctx);
		SM_TEST(sm_is_success(ret));
		if (!sm_is_err(ret))
		{
			if (clt)
				client(port, wr, timeout, both, iter);
			else
				server(port, rd, rep, backlog, timeout, both,
					iter);
		}
	}
#endif /* MTA_USE_TLS */
	return sm_test_end();
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */