/*
 * cpu.c - Make a connection to a cpu server
 *
 *	   Invoked by listen as 'cpu -R | -N service net netdir'
 *	    	   by users  as 'cpu [-h system] [-c cmd args ...]'
 */

#include <9pm/u.h>
#include <9pm/libc.h>
#include <9pm/auth.h>
#include <9pm/fcall.h>
#include <9pm/libsec.h>
#include <9pm/thread.h>
#include <9pm/draw.h>
#define NODEFINE
#include <9pm/ns.h>

/* data payload size used for the default msgsize and for the IOHDRSZ+Maxfdata 9P buffers below */
#define	Maxfdata 8192

/*
 *  Forward declarations.
 *  NOTE(review): remoteside and rmtnoteproc are declared here but no
 *  definition is visible in this file.
 */
static void	remoteside(int);
static void	fatal(int, char*, ...);
static void	lclnoteproc(int);
static void	rmtnoteproc(void);
static void	catcher(void*, char*);
static void	usage(void);
static void	writestr(int, char*, char*, int);
static int	readstr(int, char*, int);
static char	*rexcall(int*, char*, char*);
static int	setamalg(char*);

static int 	notechan;	/* not referenced by the code visible in this file */
static char	system[32];	/* cpu server to dial: taken from -h or from $cpu */
static int	cflag;		/* set by -c: a command string is sent to the server */
static int	hflag;		/* set by -h: the system was named on the command line */
static int	dbg;		/* set by -d: print 9P traffic and diagnostics on fd 2 */
static char	*user;		/* result of getuser() unless overridden with -u */

static char	*srvname = "ncpu";		/* service name dialed; -o selects "cpu" */
static char	*exportfs = "/bin/exportfs";	/* not referenced by the code visible in this file */
static char	*ealgs = "rc4_256 sha1";	/* algorithms sent during negotiation and given to pushssl; nil means none */

/* message size for exportfs; may be larger so we can do big graphics in CPU window */
static int	msgsize = Maxfdata+IOHDRSZ;

/* authentication mechanisms */
static int	netkeyauth(int);
static int	p9auth(int);
static int	noauth(int);

/*
 *  Table of selectable authentication methods, terminated by an entry
 *  whose name is nil.  setam() searches it by name; rexcall() sends the
 *  chosen name to the server and then calls the entry's cf function on
 *  the connection.  Every sf slot in this table is nil.
 */
typedef struct AuthMethod AuthMethod;
struct AuthMethod {
	char	*name;			/* name of method */
	int	(*cf)(int);		/* client side authentication */
	int	(*sf)(int, char*);	/* server side authentication */
} authmethod[] =
{
	{ "p9",		p9auth,		nil,},
	{ "netkey",	netkeyauth,	nil,},
	{ "none",	noauth,		nil,},
	{ nil,	nil}
};
static AuthMethod *am = authmethod;	/* default is p9 */

/* proto= value handed to auth_proxy by p9auth; -o switches it to "p9sk2" */
static char *p9authproto = "p9any";

static int setam(char*);

/*
 *  Print the command synopsis on fd 2 and exit with status "usage".
 */
static void
usage(void)
{
	static char synopsis[] = "usage: cpu [-h system] [-a authmethod] [-e 'crypt hash'] [-c cmd args ...]\n";

	fprint(2, "%s", synopsis);
	exits("usage");
}
static int fdd;	/* NOTE(review): not referenced by the code visible in this file */

/*
 *  Client entry point: parse the options, dial the cpu server, send it
 *  the command (if any) and our working directory, start the local
 *  note file server, wait for the remote side to announce "FS", and
 *  then serve our name space over the connection with export().
 *
 *  Options:
 *	-a method	authentication method (see authmethod[])
 *	-e 'crypt hash'	encryption/hash algorithms; "" or "clear" disables
 *	-d		debugging output
 *	-f		accepted and ignored
 *	-h system	cpu server to dial (otherwise $cpu is used)
 *	-c cmd args...	command to run remotely; consumes the remaining arguments
 *	-o		use proto p9sk2 and the "cpu" service
 *	-u user		user name given to auth_proxy
 *
 *  Over-long system names and commands are rejected with fatal()
 *  rather than being copied past the end of their fixed-size buffers.
 */
void
threadmain(int argc, char *argv[])
{
	char dat[128], buf[128], cmd[128], *p, *err;
	int fd, ms, data;

	pm_adddev(&devssl);
	initdraw(nil, nil, "drawterm");

	/* see if we should use a larger message size */
	fd = open("/dev/draw", OREAD);
	if(fd >= 0){	/* 0 is a valid descriptor too */
		ms = iounit(fd);
		if(msgsize < ms+IOHDRSZ)
			msgsize = ms+IOHDRSZ;
		close(fd);
	}

	user = getuser();
	if(user == nil)
		fatal(1, "can't read user name");
	ARGBEGIN{
	case 'a':
		p = EARGF(usage());
		if(setam(p) < 0)
			fatal(0, "unknown auth method %s", p);
		break;
	case 'e':
		ealgs = EARGF(usage());
		if(*ealgs == 0 || strcmp(ealgs, "clear") == 0)
			ealgs = nil;
		break;
	case 'd':
		dbg++;
		break;
	case 'f':
		/* ignored but accepted for compatibility */
		break;
	case 'h':
		hflag++;
		p = EARGF(usage());
		if(strlen(p) >= sizeof(system))
			fatal(0, "system name too long: %s", p);
		strcpy(system, p);
		break;
	case 'c':
		cflag++;
		cmd[0] = '!';
		cmd[1] = '\0';
		/* every remaining argument becomes part of the command */
		while((p = ARGF()) != nil) {
			/* room for the separating blank, the word and the NUL */
			if(strlen(cmd) + 1 + strlen(p) >= sizeof(cmd))
				fatal(0, "command too long");
			strcat(cmd, " ");
			strcat(cmd, p);
		}
		break;
	case 'o':
		p9authproto = "p9sk2";
		srvname = "cpu";
		break;
	case 'u':
		user = EARGF(usage());
		break;
	default:
		usage();
	}ARGEND;


	if(argc != 0)
		usage();

	if(hflag == 0) {
		p = getenv("cpu");
		if(p == 0)
			fatal(0, "set $cpu");
		if(strlen(p) >= sizeof(system))
			fatal(0, "system name too long: %s", p);
		strcpy(system, p);
	}

	if((err = rexcall(&data, system, srvname)) != nil)
		fatal(1, "%s: %s", err, system);

	/* Tell the remote side the command to execute and where our working directory is */
	if(cflag)
		writestr(data, cmd, "command", 0);
	if(getwd(dat, sizeof(dat)) == 0)
		writestr(data, "NO", "dir", 0);
	else
		writestr(data, dat, "dir", 0);

	/* start up a process to pass along notes */
	lclnoteproc(data);

	/* 
	 *  Wait for the other end to execute and start our file service
	 *  of /mnt/term
	 */
	if(readstr(data, buf, sizeof(buf)) < 0)
		fatal(1, "waiting for FS");
	if(strncmp("FS", buf, 2) != 0) {
		/* anything other than "FS" is taken as the remote side's error text */
		print("remote cpu: %s", buf);
		exits(buf);
	}
	if(readstr(data, buf, sizeof buf) < 0)
		fatal(1, "waiting for FS root");
	/* ignore fs root */

	write(data, "OK", 2);

	/* Begin serving the gnot namespace */
	export(data);
	/* reached only if export() returns */
	fatal(1, "starting exportfs");
}

/*
 *  Format a message, print it on fd 2 prefixed with "cpu: " and exit
 *  with the formatted text as the exit status.  When syserr is
 *  non-zero the %r error string is appended to the printed line.
 */
static void
fatal(int syserr, char *fmt, ...)
{
	char msg[ERRMAX];
	va_list ap;

	va_start(ap, fmt);
	doprint(msg, msg+sizeof(msg), fmt, ap);
	va_end(ap);

	fprint(2, syserr ? "cpu: %s: %r\n" : "cpu: %s\n", msg);
	exits(msg);
}

static char *negstr = "negotiating authentication method";

/*
 *  Dial `service' on `host', negotiate the authentication method and
 *  run the client half of the selected method (am->cf).
 *
 *  On success *fd holds the descriptor returned by the authentication
 *  routine and the result is 0.  On failure the result is a constant
 *  string naming the step that failed.  When the server rejects the
 *  negotiation, its reply is stored with werrstr.
 */
static char*
rexcall(int *fd, char *host, char *service)
{
	char *na;
	char dir[128];
	char err[ERRMAX];
	char msg[128];
	int n;

	na = netmkaddr(host, 0, service);
	if((*fd = dial(na, 0, dir, 0)) < 0)
		return "can't dial";

	/* negotiate authentication mechanism */
	if(ealgs != nil)
		snprint(msg, sizeof(msg), "%s %s", am->name, ealgs);
	else
		snprint(msg, sizeof(msg), "%s", am->name);
	writestr(*fd, msg, negstr, 0);
	n = readstr(*fd, err, sizeof err);
	if(n < 0)
		return negstr;
	if(*err){
		/* a non-empty reply is the server's complaint; never use remote text as a format */
		werrstr("%s", err);
		return negstr;
	}

	/* authenticate */
	*fd = (*am->cf)(*fd);
	if(*fd < 0)
		return "can't authenticate";
	return 0;
}

/*
 *  Write str, including its terminating NUL, to fd.  Unless `ignore'
 *  is set, a failed write is fatal and `thing' names what was being
 *  sent in the error message.
 */
static void
writestr(int fd, char *str, char *thing, int ignore)
{
	int wrote;

	/* the NUL goes over the wire too: readstr stops on it */
	wrote = write(fd, str, strlen(str)+1);
	if(ignore || wrote >= 0)
		return;
	fatal(1, "writing network: %s", thing);
}

/*
 *  Read a NUL-terminated string from fd into str, one byte at a time,
 *  storing at most len bytes (terminator included).
 *
 *  Returns 0 once the terminating NUL has been stored.  Returns -1 on
 *  a read error, on end of file before the terminator, or when len
 *  bytes were stored without seeing a terminator (str is then not
 *  NUL-terminated).
 */
static int
readstr(int fd, char *str, int len)
{
	int n;

	while(len) {
		n = read(fd, str, 1);
		/* 0 means the peer hung up: *str was not written, so don't test it */
		if(n <= 0)
			return -1;
		if(*str == '\0')
			return 0;
		str++;
		len--;
	}
	return -1;
}

/*
 *  Read one line from fd 0 into buf, which holds n bytes.  Reading
 *  stops at a newline or carriage return (not stored), when read()
 *  does not deliver exactly one byte, or when n-1 bytes have been
 *  stored.  The result is always NUL-terminated; the return value is
 *  the number of bytes stored before the terminator.
 */
static int
readln(char *buf, int n)
{
	int len, max;

	max = n - 1;	/* room for \0 */
	len = 0;
	while(len < max && read(0, buf+len, 1) == 1){
		if(buf[len] == '\n' || buf[len] == '\r')
			break;
		len++;
	}
	buf[len] = '\0';
	return len;
}

/*
 *  user level challenge/response
 */
static int
netkeyauth(int fd)
{
	char chall[32];
	char resp[32];

	strcpy(chall, getuser());
	print("user[%s]: ", chall);
	if(readln(resp, sizeof(resp)) < 0)
		return -1;
	if(*resp != 0)
		strcpy(chall, resp);
	writestr(fd, chall, "challenge/response", 1);

	for(;;){
		if(readstr(fd, chall, sizeof chall) < 0)
			break;
		if(*chall == 0)
			return fd;
		print("challenge: %s\nresponse: ", chall);
		if(readln(resp, sizeof(resp)) < 0)
			break;
		writestr(fd, resp, "challenge/response", 1);
	}
	return -1;
}

/*
 *  Server half of the netkey exchange: read the user name from fd
 *  into `user' (the caller must provide at least 32 bytes), then
 *  issue up to 10 p9cr challenges until one is answered correctly.
 *  On success an empty string is sent to the client, auth_chuid() is
 *  applied and fd is returned; otherwise -1.
 *
 *  Each challenge state is released as soon as its response has been
 *  checked, so failed attempts and the read-error path do not leak it.
 */
static int
netkeysrvauth(int fd, char *user)
{
	char response[32];
	Chalstate *ch;
	int tries;
	AuthInfo *ai;

	if(readstr(fd, user, 32) < 0)
		return -1;

	ai = nil;
	for(tries = 0; tries < 10; tries++){
		ch = auth_challenge("proto=p9cr role=server user=%q", user);
		if(ch == nil)
			return -1;
		writestr(fd, ch->chal, "challenge", 1);
		if(readstr(fd, response, sizeof response) < 0){
			auth_freechal(ch);
			return -1;
		}
		ch->resp = response;
		ch->nresp = strlen(response);
		ai = auth_response(ch);
		auth_freechal(ch);
		if(ai != nil)
			break;
	}
	if(ai == nil)
		return -1;
	/* an empty challenge tells the client it is done */
	writestr(fd, "", "challenge", 1);
	if(auth_chuid(ai, 0) < 0)
		fatal(1, "newns");
	auth_freeAI(ai);
	return fd;
}

/*
 *  Format the 10 bytes at f into t as text, one %2.2ux conversion per
 *  byte.  The callers in this file pass 21-byte buffers for t.
 */
static void
mksecret(char *t, uchar *f)
{
	static char tenhex[] = "%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux%2.2ux";

	sprint(t, tenhex,
		f[0], f[1], f[2], f[3], f[4],
		f[5], f[6], f[7], f[8], f[9]);
}

/*
 *  plan9 authentication followed by rc4 encryption
 *
 *  Runs auth_proxy() on fd as client.  If no encryption algorithms are
 *  selected (ealgs == nil) the plain fd is returned.  Otherwise 4
 *  random bytes are exchanged in each direction, mixed with the
 *  shared secret through sha1 into two 20-character secrets, and
 *  pushssl() is applied; its result is returned.
 *
 *  Returns -1 on authentication or I/O failure.
 */
static int
p9auth(int fd)
{
	uchar key[16];
	uchar digest[SHA1dlen];
	char fromclientsecret[21];
	char fromserversecret[21];
	int i, n;
	AuthInfo *ai;

	ai = auth_proxy(fd, auth_getkey, "proto=%q user=%q role=client", p9authproto, user);
	if(ai == nil)
		return -1;

	/*
	 *  key layout: [0..3] our random bytes, [4..11] shared secret,
	 *  [12..15] the server's random bytes.  Copy no more of the
	 *  secret than its slot holds and leave no byte undefined.
	 */
	memset(key, 0, sizeof key);
	n = ai->nsecret;
	if(n > 8)
		n = 8;
	memmove(key+4, ai->secret, n);
	auth_freeAI(ai);	/* only the secret was needed */
	if(ealgs == nil)
		return fd;

	/* exchange random numbers */
	srand(truerand());
	for(i = 0; i < 4; i++)
		key[i] = rand();
	if(write(fd, key, 4) != 4)
		return -1;
	if(readn(fd, key+12, 4) != 4)
		return -1;

	/* scramble into two secrets */
	sha1(key, sizeof(key), digest, nil);
	mksecret(fromclientsecret, digest);
	mksecret(fromserversecret, digest+10);

	/* set up encryption */
	i = pushssl(fd, ealgs, fromclientsecret, fromserversecret, nil);
	if(i < 0)
		werrstr("can't establish ssl connection: %r");
	return i;
}

/*
 *  "none" method: no authentication exchange is performed.  Clears the
 *  global ealgs as a side effect and returns fd unchanged.
 */
static int
noauth(int fd)
{
	ealgs = nil;
	return fd;
}

/*
 *  set authentication mechanism
 */
static int
setam(char *name)
{
	for(am = authmethod; am->name != nil; am++)
		if(strcmp(am->name, name) == 0)
			return 0;
	am = authmethod;
	return -1;
}

/*
 *  set authentication mechanism and encryption/hash algs
 *
 *  s has the form "method" or "method algs".  The string is cut in
 *  place at the first blank; ealgs is set to the text after it, or to
 *  nil when there is no blank.  Returns the result of setam() on the
 *  method name.
 */
static int
setamalg(char *s)
{
	char *blank;

	blank = strchr(s, ' ');
	if(blank == nil)
		ealgs = nil;
	else{
		*blank = 0;
		ealgs = blank+1;
	}
	return setam(s);
}

/*
 *  Files served by the note file system.  Qdir and Qcpunote are used
 *  both as qid paths and as indices into fstab[].
 */
enum
{
	Qdir,
	Qcpunote,

	Nfid = 32,	/* size of the fids[] table */
};

/* name, qid and permissions of each file, indexed by Qdir/Qcpunote */
static struct {
	char	*name;
	Qid	qid;
	ulong	perm;
} fstab[] =
{
	[Qdir]		{ ".",		{Qdir, 0, QTDIR},	DMDIR|0555	},
	[Qcpunote]	{ "cpunote",	{Qcpunote, 0},		0444		},
};

/* a queued note: kick() sends msg as the data of a read reply */
typedef struct Note Note;
struct Note
{
	Note *next;
	char msg[ERRMAX];
};

/* a pending read of cpunote, queued by fsread() until a note is available */
typedef struct Request Request;
struct Request
{
	Request *next;
	Fcall f;
};

/* one slot of the fid table */
typedef struct Fid Fid;
struct Fid
{
	int	fid;	/* fid number from the 9P message */
	int	file;	/* Qdir or Qcpunote; -1 marks the slot as free */
};
static Fid fids[Nfid];

/* note and request queues; the embedded Lock is taken with lock(&nfs) */
static struct {
	Lock;
	Note *nfirst, *nlast;
	Request *rfirst, *rlast;
} nfs;

/*
 *  Marshal f and write it to fd.  Returns 0 when the message was
 *  written, and also when convS2M produced nothing to write.  If the
 *  write comes up short, fd is closed and -1 is returned.
 */
static int
fsreply(int fd, Fcall *f)
{
	uchar msg[IOHDRSZ+Maxfdata];
	int len;

	if(dbg)
		fprint(2, "<-%F\n", f);

	len = convS2M(f, msg, sizeof msg);
	if(len <= 0)
		return 0;
	if(write(fd, msg, len) == len)
		return 0;

	close(fd);
	return -1;
}

/*
 *  Pair queued read requests with queued notes: while both queues are
 *  non-empty, take the head of each, answer the request with the
 *  note's text as an Rread, and free both.  Returns 0 when either
 *  queue runs dry, -1 as soon as a reply cannot be written.
 */
static int
kick(int fd)
{
	Request *req;
	Note *note;
	int status;

	for(;;){
		/* dequeue one request/note pair under the lock */
		lock(&nfs);
		req = nfs.rfirst;
		note = nfs.nfirst;
		if(req != nil && note != nil){
			nfs.rfirst = req->next;
			nfs.nfirst = note->next;
		}
		unlock(&nfs);

		if(req == nil || note == nil)
			return 0;

		req->f.type = Rread;
		req->f.count = strlen(note->msg);
		req->f.data = note->msg;
		status = fsreply(fd, &req->f);
		free(req);
		free(note);
		if(status < 0)
			return -1;
	}
}

/*
 *  Drop the first queued read request whose tag is `tag', if any.
 *
 *  When the request removed is the tail of the queue, nfs.rlast is
 *  moved back to its predecessor; fsread() appends through
 *  nfs.rlast->next and must never be left with a pointer to the freed
 *  request.
 */
static void
flushreq(int tag)
{
	Request *rp, *prev;

	lock(&nfs);
	prev = nil;
	for(rp = nfs.rfirst; rp != nil; prev = rp, rp = rp->next){
		if(rp->f.tag != tag)
			continue;
		if(prev == nil)
			nfs.rfirst = rp->next;
		else
			prev->next = rp->next;
		if(nfs.rlast == rp)
			nfs.rlast = prev;
		unlock(&nfs);
		free(rp);
		return;
	}
	unlock(&nfs);
}

static Fid*
getfid(int fid)
{
	int i, freefid;

	freefid = -1;
	for(i = 0; i < Nfid; i++){
		if(freefid < 0 && fids[i].file < 0)
			freefid = i;
		if(fids[i].fid == fid)
			return &fids[i];
	}
	if(freefid >= 0){
		fids[freefid].fid = fid;
		return &fids[freefid];
	}
	return nil;
}

/*
 *  Answer a stat request for the file `fid' refers to: build a Dir
 *  from fstab[] with the current user as owner and the current time,
 *  marshal it into the reply and send it.  A fid that refers to no
 *  file (file < 0) gets an Rerror instead of indexing fstab[] out of
 *  bounds.  Returns the result of fsreply().
 */
static int
fsstat(int fd, Fid *fid, Fcall *f)
{
	Dir d;
	uchar statbuf[256];

	if(fid->file < 0){
		f->type = Rerror;
		f->ename = "unknown fid";
		return fsreply(fd, f);
	}

	memset(&d, 0, sizeof(d));
	d.name = fstab[fid->file].name;
	d.uid = user;
	d.gid = user;
	d.muid = user;
	d.qid = fstab[fid->file].qid;
	d.mode = fstab[fid->file].perm;
	d.atime = d.mtime = time(0);
	/* statbuf is still live while fsreply marshals the message below */
	f->stat = statbuf;
	f->nstat = convD2M(&d, statbuf, sizeof statbuf);
	return fsreply(fd, f);
}

/*
 *  Handle a read request.
 *
 *  cpunote: the request is copied onto the tail of the pending-request
 *  queue and kick() is called to answer it if a note is waiting.
 *  directory: a read at offset 0 with a non-zero count returns the
 *  single entry for cpunote; any other read returns 0 bytes.
 *
 *  Returns -1 if fid refers to neither file or if allocation fails,
 *  otherwise the result of kick() or fsreply().
 */
static int
fsread(int fd, Fid *fid, Fcall *f)
{
	Dir ent;
	uchar dirbuf[256];
	Request *req;

	if(fid->file == Qcpunote){
		req = mallocz(sizeof(*req), 1);
		if(req == nil)
			return -1;
		req->f = *f;
		lock(&nfs);
		if(nfs.rfirst != nil)
			nfs.rlast->next = req;
		else
			nfs.rfirst = req;
		nfs.rlast = req;
		unlock(&nfs);
		return kick(fd);
	}
	if(fid->file != Qdir)
		return -1;

	if(f->offset == 0 && f->count >0){
		memset(&ent, 0, sizeof(ent));
		ent.name = fstab[Qcpunote].name;
		ent.uid = user;
		ent.gid = user;
		ent.muid = user;
		ent.qid = fstab[Qcpunote].qid;
		ent.mode = fstab[Qcpunote].perm;
		ent.atime = ent.mtime = time(0);
		f->count = convD2M(&ent, dirbuf, sizeof dirbuf);
		f->data = (char*)dirbuf;
	}else
		f->count = 0;
	return fsreply(fd, f);
}

/* error strings returned in Rerror replies by notefs() */
static char Eperm[] = "permission denied";
static char Enofile[] = "out of files";
static char Enotdir[] = "not a directory";

static void
notefs(int fd)
{
	uchar buf[IOHDRSZ+Maxfdata];
	int i, j, n;
	char err[ERRMAX];
	Fcall f;
	Fid *fid, *nfid;
	int doreply;

	rfork(RFNOTEG);
	fmtinstall('F', fcallconv);

	for(n = 0; n < Nfid; n++)
		fids[n].file = -1;

	for(;;){
		n = read9pmsg(fd, buf, sizeof(buf));
		if(n <= 0){
			if(dbg)
				fprint(2, "read9pmsg(%d) returns %d: %r\n", fd, n);
			break;
		}
		if(convM2S(buf, n, &f) < 0)
			break;
		if(dbg)
			fprint(2, "->%F\n", &f);
		doreply = 1;
		fid = getfid(f.fid);
		if(fid == nil){
nofids:
			f.type = Rerror;
			f.ename = Enofile;
			fsreply(fd, &f);
			continue;
		}
		switch(f.type++){
		default:
			f.type = Rerror;
			f.ename = "unknown type";
			break;
		case Tflush:
			flushreq(f.oldtag);
			break;
		case Tversion:
			if(f.msize > IOHDRSZ+Maxfdata)
				f.msize = IOHDRSZ+Maxfdata;
			break;
		case Tauth:
			f.type = Rerror;
			f.ename = "cpu: authentication not required";
			break;
		case Tattach:
			f.qid = fstab[Qdir].qid;
			fid->file = Qdir;
			break;
		case Twalk:
			nfid = nil;
			if(f.newfid != f.fid){
				nfid = getfid(f.newfid);
				if(nfid == nil)
					goto nofids;
				nfid->file = fid->file;
				fid = nfid;
			}

			f.ename = nil;
			for(i=0; i<f.nwname; i++){
				if(i > MAXWELEM){
					f.type = Rerror;
					f.ename = "too many name elements";
					break;
				}
				if(fid->file != Qdir){
					f.type = Rerror;
					f.ename = Enotdir;
					break;
				}
				if(strcmp(f.wname[i], "cpunote") == 0){
					fid->file = Qcpunote;
					f.wqid[i] = fstab[Qcpunote].qid;
					continue;
				}
				f.type = Rerror;
				f.ename = err;
				strcpy(err, "cpu: file \"");
				for(j=0; j<=i; j++){
					if(strlen(err)+1+strlen(f.wname[j])+32 > sizeof err)
						break;
					if(j != 0)
						strcat(err, "/");
					strcat(err, f.wname[j]);
				}
				strcat(err, "\" does not exist");
				break;
			}
			if(nfid != nil && (f.ename != nil || i < f.nwname))
				nfid ->file = -1;
			if(f.type != Rerror)
				f.nwqid = i;
			break;
		case Topen:
			if(f.mode != OREAD){
				f.type = Rerror;
				f.ename = Eperm;
			}
			f.qid = fstab[fid->file].qid;
			break;
		case Tcreate:
			f.type = Rerror;
			f.ename = Eperm;
			break;
		case Tread:
			if(fsread(fd, fid, &f) < 0)
				goto err;
			doreply = 0;
			break;
		case Twrite:
			f.type = Rerror;
			f.ename = Eperm;
			break;
		case Tclunk:
			fid->file = -1;
			break;
		case Tremove:
			f.type = Rerror;
			f.ename = Eperm;
			break;
		case Tstat:
			if(fsstat(fd, fid, &f) < 0)
				goto err;
			doreply = 0;
			break;
		case Twstat:
			f.type = Rerror;
			f.ename = Eperm;
			break;
		}
		if(doreply)
			if(fsreply(fd, &f) < 0)
				break;
	}
err:
	if(dbg)
		fprint(2, "notefs exiting: %r\n");
	close(fd);
}

/* text of the most recent note seen by catcher() */
static char 	notebuf[ERRMAX];

/*
 *  Note handler: save the note text in notebuf, truncated to fit and
 *  always NUL-terminated, then resume with noted(NCONT).
 */
static void
catcher(void*, char *text)
{
	int len;

	len = strlen(text);
	if(len > sizeof(notebuf)-1)
		len = sizeof(notebuf)-1;
	memmove(notebuf, text, len);
	notebuf[len] = '\0';
	noted(NCONT);
}

/*
 *  Body of the process started by lclnoteproc().  v points at the
 *  creator's argument structure (same layout as declared there).
 *  The descriptors this process does not need are closed, the pipe
 *  end to serve is copied into a local, and only then is the creator
 *  released through the channel; after the send, s is not touched
 *  again.  Finally the note file system is served on the pipe.
 */
static void
fsproc(void *v)
{
	int fd;
	struct { int pfd[2]; int netfd; Channel *c; } *s;

	s = v;
	rfork(RFFDG);
	close(s->pfd[1]);
	close(s->netfd);
	fd = s->pfd[0];
	/* everything needed from *s has been read; let lclnoteproc continue */
	sendul(s->c, 0);
	notefs(fd);
}

/*
 *  mount in /dev a note file for the remote side to read.
 */
static void
lclnoteproc(int netfd)
{
	struct { int pfd[2]; int netfd; Channel *c; } s;

	s.netfd = netfd;
	if(pipe(s.pfd) < 0){
		fprint(2, "cpu: can't start note proc: pipe: %r\n");
		return;
	}
	s.c = chancreate(sizeof(ulong), 0);

	proccreate(fsproc, &s, 32768);
	recvul(s.c);
	close(s.pfd[0]);
	mount(s.pfd[1], -1, "/dev", MBEFORE, "");
	close(s.pfd[1]);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */