#include <9pm/u.h>
#include <9pm/libc.h>
#include <9pm/ns.h>
#include <9pm/thread.h>
#include <9pm/threadimpl.h>

// Central channel access lock.  Every function in this file that reads
// or modifies a Channel's wait queue (qentry/nentry), buffer indices
// (f, n), freed flag, or the tag word of a queued Alt does so while
// holding this single lock.
static Lock chanlock;

/*
 * Find a free (nil) slot in c->qentry and return its index.
 * When every slot is occupied the array is grown, by one slot if it is
 * empty, otherwise by its current size capped at 16 slots per call.
 * The new slots are zeroed and the first of them is returned.
 * Allocation failure is fatal.  Every caller in this file holds chanlock.
 */
static int
emptyentry(Channel *c)
{
	int slot, grow;

	assert((c->nentry == 0 && c->qentry == nil) || (c->nentry && c->qentry));
	for (slot = 0; slot < c->nentry; slot++) {
		if (c->qentry[slot] == nil)
			return slot;
	}

	// All slots taken: slot == c->nentry, the first index of the new space.
	grow = (slot == 0) ? 1 : (slot > 16 ? 16 : slot);
	c->nentry += grow;
	c->qentry = realloc(c->qentry, c->nentry*sizeof(c->qentry[0]));
	if (c->qentry == nil)
		sysfatal("realloc channel entries: %r");
	memset(&c->qentry[slot], 0, grow*sizeof(c->qentry[0]));
	return slot;
}

/*
 * Release channel c.  If any Alt is still queued on c, the channel is
 * only marked freed; recv, send and alt call chanfree again after they
 * dequeue from a channel whose freed flag is set.  Otherwise the wait
 * queue and the channel itself are freed.  A nil channel is ignored.
 */
void
chanfree(Channel *c)
{
	int i, inuse;

	if (c == nil)
		return;
	lock(&chanlock);
	inuse = 0;
	for (i = 0; i < c->nentry; i++) {
		if (c->qentry[i]) {
			inuse = 1;
			break;
		}
	}
	if (inuse)
		c->freed = 1;
	else {
		free(c->qentry);	// free(nil) is a no-op
		free(c);
	}
	unlock(&chanlock);
}

/*
 * Initialise a caller-provided channel for elements of elemsize bytes
 * with room for elemcnt buffered elements (0 = unbuffered).  The wait
 * queue is reset too, so c need not have been zeroed beforehand.
 * NOTE(review): for elemcnt > 0 the caller must already have reserved
 * buffer space for c->v, as chancreate does by allocating it after the
 * Channel header; confirm against the Channel definition.
 * Returns 1 on success, -1 if c is nil or a size is out of range.
 */
int
chaninit(Channel *c, int elemsize, int elemcnt)
{
	if(elemcnt < 0 || elemsize <= 0 || c == nil)
		return -1;
	c->f = 0;
	c->n = 0;
	c->freed = 0;
	c->s = elemcnt;
	c->e = elemsize;
	// emptyentry() asserts nentry and qentry agree; without this reset a
	// channel in uninitialised memory would start with a garbage queue.
	c->nentry = 0;
	c->qentry = nil;
	_threaddebug(DBGCHAN, "chaninit %lux", c);
	return 1;
}

/*
 * Allocate a channel for elements of elemsize bytes, with room for
 * elemcnt buffered elements (0 = unbuffered).  The element buffer is
 * allocated in the same block as the Channel header.  Fields other than
 * s and e are not set here.  NOTE(review): this presumably relies on
 * _threadmalloc(..., 1) returning zeroed memory; confirm.
 * Returns nil if elemsize <= 0 or elemcnt < 0.
 */
Channel *
chancreate(int elemsize, int elemcnt)
{
	Channel *ch;

	if(elemcnt >= 0 && elemsize > 0){
		ch = _threadmalloc(sizeof(Channel) + elemcnt * elemsize, 1);
		ch->e = elemsize;
		ch->s = elemcnt;
		_threaddebug(DBGCHAN, "chancreate %lux", ch);
		return ch;
	}
	return nil;
}

/*
 * Multi-way channel operation.  alts is an array of Alt entries
 * terminated by an entry whose op is 0.
 * - CHANNOP entries are skipped.
 * - A CHANNOBLK entry makes the call non-blocking: its index is
 *   returned if no earlier entry can proceed, and entries after it
 *   are ignored.
 * - Otherwise one ready CHANSND/CHANRCV entry is chosen at random and
 *   performed, or, if none is ready, the thread queues on every
 *   channel and blocks.
 * Returns the index of the entry that completed, or -1 if a
 * rendezvous fails.  A nil channel in an active entry is fatal.
 */
int
alt(Alt *alts)
{
	Alt *a, *xa;
	Channel *c;
	uchar *v;
	int i, n, entry;
	Thread *t;

	lock(&chanlock);
	// Record which operation this thread is blocked in
	// (presumably for inspection/debugging; confirm).
	t = _threadgetthr();
	t->alt = alts;
	t->call = Callalt;
repeat:
	// Reached with chanlock held, initially and after a buffered wakeup.

	// Test which channels can proceed.
	// n-1 is the number of ready entries seen so far; replacing the
	// choice when nrand(n) == 0 picks uniformly among them, assuming
	// nrand(n) is uniform on [0, n).
	n = 1;
	a = nil;
	entry = -1;	// -1: use the buffer; >= 0: index of a waiting peer in c->qentry
	for (xa = alts; xa->op; xa++) {
		xa->entryno = -1;
		if (xa->op == CHANNOP) continue;
		if (xa->op == CHANNOBLK) {
			if (a == nil) {
				t->call = Callnil;
				unlock(&chanlock);
				return xa - alts;
			} else
				break;
		}

		c = xa->c;
		if (c == nil)
			sysfatal("alt: nil channel in entry %ld\n", xa - alts);
		if ((xa->op == CHANSND && c->n < c->s) ||
			(xa->op == CHANRCV && c->n)) {
				// The buffer has room (send) or holds an element (receive)
				if (nrand(n) == 0) {
					a = xa;
					entry = -1;
				}
				n++;
		} else {
			// Test for blocked senders or receivers
			for (i = 0; i < c->nentry; i++) {
				// Is there an unclaimed waiter with the complementary op?
				if (
					c->qentry[i]
					&& xa->op == (CHANSND+CHANRCV) - c->qentry[i]->op
							// complementary op
					&& *c->qentry[i]->tag == nil
				) {
					// Unclaimed (its tag word is still nil)
					if (nrand(n) == 0) {
						a = xa;
						entry = i;
					}
					n++;
					break;
				}
			}
		}
	}

	if (a == nil) {
		// Nothing can proceed, enqueue on all channels.
		// The local c doubles as the wakeup word: every queued entry's
		// tag points at it, and a peer stores the channel it completed on
		// into *tag, then rendezvouses on the tag address.
		c = nil;
		for (a = alts; a->op; a++) {
			Channel *cp;

			if (a->op == CHANNOP || a->op == CHANNOBLK) continue;
			cp = a->c;
			a->tag = &c;
			i = emptyentry(cp);
			cp->qentry[i] = a;
			a->entryno = i;
		}

		// And wait for the rendez vous
		unlock(&chanlock);
		if (_threadrendezvous((ulong)&c, 0) == ~0) {
			// NOTE(review): on this path the entries queued above stay in
			// the channels' qentry arrays with tag == &c, a local of this
			// frame, so a later peer would dereference a dead stack slot.
			// Consider dequeuing here as the success path below does.
			t->call = Callnil;
			return -1;
		}
		
		lock(&chanlock);

		/* We rendezed-vous on channel c, dequeue from all channels
		 * and find the Alt struct to which c belongs
		 */
		a = nil;
		for (xa = alts; xa->op; xa++) {
			Channel *xc;

			if (xa->op == CHANNOP || xa->op == CHANNOBLK) continue;
			xc = xa->c;
			threadassert(xa->entryno >= 0 && xa->entryno < xc->nentry && xc->qentry[xa->entryno]);
			xc->qentry[xa->entryno] = nil;
			xa->entryno = -1;
			if (xc == c)
				a = xa;	// if c appears more than once, the last entry wins
			
		}

		if (c->s) {
			// Buffered channel, try again: the waker presumably only
			// changed the buffer state, so re-run the selection
			// (still holding chanlock).
			sleep(0);
			goto repeat;
		}

		unlock(&chanlock);

		if (c->freed) chanfree(c);

		if (t->exiting)
			threadexits(nil);
		t->call = Callnil;
		return a - alts;
	}

	c = a->c;
	// Channel c can proceed

	if (c->s) {
		// Send or receive via the buffered channel (a ring of c->s
		// elements of c->e bytes; c->f is the head, c->n the count).
		// NOTE(review): this also runs when entry >= 0 (buffer full for
		// a send, empty for a receive), in which case it overwrites the
		// oldest element or drives c->n negative, and then the direct
		// copy below runs as well. Confirm whether that state is reachable.
		if (a->op == CHANSND) {
			v = c->v + ((c->f + c->n) % c->s) * c->e;
			if (a->v)
				memmove(v, a->v, c->e);
			else
				memset(v, 0, c->e);	// nil value sends zeroes
			c->n++;
		} else {
			if (a->v) {
				v = c->v + (c->f % c->s) * c->e;
				memmove(a->v, v, c->e);
			}
			c->n--;
			c->f++;
		}
	}
	if (entry < 0)
		// Buffered transfer done: wake one unclaimed complementary waiter,
		// if any, by publishing c in its tag word.
		for (i = 0; i < c->nentry; i++) {
			if (
				(xa = c->qentry[i])
				&& xa ->op == (CHANSND+CHANRCV) - a->op
				&& *xa ->tag == nil
			) {
				// Unblock peer process
				*xa->tag = c;

				unlock(&chanlock);
				if (_threadrendezvous((ulong)xa->tag, 0) == ~0) {
					t->call = Callnil;
					return -1;
				}
				t->call = Callnil;
				return a - alts;
			}
		}
	if (entry >= 0) {
		// Direct hand-off with the waiting peer xa: copy the value
		// (zeroes if the sender's value is nil), then claim and wake it.
		xa = c->qentry[entry];
		if (a->op == CHANSND) {
			if (xa->v) {
				if (a->v)
					memmove(xa->v, a->v, c->e);
				else
					memset(xa->v, 0, c->e);
			}
		} else {
			if (a->v) {
				if (xa->v)
					memmove(a->v, xa->v, c->e);
				else
					memset(a->v, 0, c->e);
			}
		}
		*xa->tag = c;

		unlock(&chanlock);
		if (_threadrendezvous((ulong)xa->tag, 0) == ~0) {
			t->call = Callnil;
			return -1;
		}
		t->call = Callnil;
		return a - alts;
	}
	// Buffered transfer with no waiter to wake: give other threads a turn.
	unlock(&chanlock);
	yield();
	t->call = Callnil;
	return a - alts;
}

/*
 * Shared non-blocking receive path for recv and nbrecv: try to take one
 * element from c into v (v may be nil to discard it).
 *
 * Lock protocol (asymmetric):
 *   returns 1 (received) or -1 (rendezvous with a woken sender failed)
 *     with chanlock RELEASED;
 *   returns 0 (nothing available) with chanlock STILL HELD, so the
 *     caller can enqueue itself without a window for a sender to slip in.
 */
static int
recvcommon(Channel *c, void *v)
{
	Alt *a;
	int i;
	
	lock(&chanlock);
	for (i = 0; i < c->nentry; i++) {
		// Look for a queued sender whose tag word is still nil (unclaimed)
		if (
			(a = c->qentry[i])
			&& a->op == CHANSND
			&& *a->tag == nil
		) {
			*a->tag = c;	// claim it; the sender sees which channel woke it
			if (c->n) {
				// There's an item to receive in the buffered channel
				if (v)
					memmove(v, c->v + (c->f % c->s) * c->e, c->e);
				c->n--;
				c->f++;
			} else {
				// Copy straight from the sender (zeroes if its value is nil).
				// NOTE(review): on a buffered channel (c->s != 0) the woken
				// sender retries its send (send's `goto retry`), so this value
				// looks like it is delivered twice; confirm.
				if (v) {
					if (a->v)
						memmove(v, a->v, c->e);
					else
						memset(v, 0, c->e);
				}
			}

			unlock(&chanlock);
			if (_threadrendezvous((ulong)a->tag, 0) == ~0)
				return -1;
			return 1;
		}
	}
	if (c->n) {
		// There's an item to receive in the buffered channel
		if (v)
			memmove(v, c->v + (c->f % c->s) * c->e, c->e);
		c->n--;
		c->f++;
		unlock(&chanlock);
		return 1;
	}
	// Nothing to receive: chanlock is intentionally left held.
	return 0;
}

/*
 * Non-blocking receive of one element from c into v (v may be nil).
 * Returns 1 on success, -1 if the rendezvous with a waiting sender
 * fails, and 0 if nothing was available.
 */
int
nbrecv(Channel *c, void *v)
{
	int got;

	got = recvcommon(c, v);
	if (got != 0)
		return got;
	// recvcommon returns 0 with chanlock still held; release it here.
	unlock(&chanlock);
	return 0;
}

/*
 * Receive one element from c into v (v may be nil to discard it).
 * Blocks until a sender hands over a value or, for a buffered channel,
 * an element becomes available.  Returns 1 on success and -1 if a
 * rendezvous fails.  If the thread was marked exiting while blocked,
 * it exits via threadexits.
 */
int
recv(Channel *c, void *v)
{
	Alt a;
	Channel *tag;
	int i;
	Thread *t;

retry:
	if ((i = recvcommon(c, v)) != 0)
		// chanlock has been released
		return i;
	// chanlock is still held: nothing to receive, so queue up
	tag = nil;
	a.c = c;
	a.v = v;
	a.tag = &tag;
	a.op = CHANRCV;
	t = _threadgetthr();
	t->alt = &a;
	t->call = Callrcv;

	// enqueue on the channel
	i = emptyentry(c);
	c->qentry[i] = &a;
	a.entryno = i;
	unlock(&chanlock);
	if (_threadrendezvous((ulong)&tag, 0) == ~0) {
		// a and tag live in this stack frame: take a off the queue so no
		// later sender dereferences them after we return.
		lock(&chanlock);
		if (c->qentry[a.entryno] == &a)
			c->qentry[a.entryno] = nil;
		unlock(&chanlock);
		if (c->freed) chanfree(c);
		t->call = Callnil;
		return -1;
	}
	lock(&chanlock);

	// dequeue from the channel
	threadassert(a.entryno >= 0 && a.entryno < c->nentry && c->qentry[a.entryno]);
	c->qentry[a.entryno] = nil;
	unlock(&chanlock);
	if (c->s) goto retry;	// Buffered channels: try the queue again
	// Unbuffered channels: the sender has already copied the value into v
	if (c->freed) chanfree(c);
	t->call = Callnil;
	if (t->exiting)
		threadexits(nil);
	return 1;
}

/*
 * Shared non-blocking send path for send and nbsend: try to deliver the
 * element at v (nil sends zeroes) on c.
 *
 * Lock protocol (asymmetric):
 *   returns 1 (sent) or -1 (rendezvous with a woken receiver failed)
 *     with chanlock RELEASED;
 *   returns 0 (no receiver and no buffer space) with chanlock STILL
 *     HELD, so the caller can enqueue itself atomically.
 */
static int
sendcommon(Channel *c, void *v)
{
	Alt *a;
	int i;

	lock(&chanlock);
	for (i = 0; i < c->nentry; i++) {
		// Look for a queued receiver whose tag word is still nil (unclaimed)
		if (
			(a = c->qentry[i])
			&& a->op == CHANRCV
			&& *a->tag == nil
		) {
			*a->tag = c;	// claim it; the receiver sees which channel woke it
			if (c->n < c->s) {
				// There's room to send in the buffered channel
				if (v)
					memmove(c->v + ((c->f + c->n) % c->s) * c->e, v, c->e);
				else
					memset(c->v + ((c->f + c->n) % c->s) * c->e, 0, c->e);
				c->n++;
			} else {
				// Copy straight into the receiver (unless it discards values).
				// NOTE(review): on a buffered channel (c->s != 0) the woken
				// receiver retries its receive (recv's `goto retry`) and
				// overwrites this value with the next buffered element, so
				// the value looks lost; confirm.
				if (a->v) {
					if (v)
						memmove(a->v, v, c->e);
					else
						memset(a->v, 0, c->e);
				}
			}
	
			unlock(&chanlock);
			if (_threadrendezvous((ulong)a->tag, 0) == ~0)
				return -1;
			return 1;
		}
	}
	if (c->n < c->s) {
		// There's room to send in the buffered channel
		if (v)
			memmove(c->v + ((c->f + c->n) % c->s) * c->e, v, c->e);
		else
			memset(c->v + ((c->f + c->n) % c->s) * c->e, 0, c->e);
		c->n++;
		unlock(&chanlock);
		yield();
		return 1;
	}
	// No receiver and no room: chanlock is intentionally left held.
	return 0;
}

/*
 * Non-blocking send of the element at v (nil sends zeroes) on c.
 * Returns 1 on success, -1 if the rendezvous with a waiting receiver
 * fails, and 0 if neither a receiver nor buffer space was available.
 */
int
nbsend(Channel *c, void *v)
{
	int sent;

	sent = sendcommon(c, v);
	if (sent != 0)
		return sent;
	// sendcommon returns 0 with chanlock still held; release it here.
	unlock(&chanlock);
	return 0;
}

/*
 * Send the element at v (nil sends zeroes) on c.  Blocks until a
 * receiver takes it or, for a buffered channel, buffer space frees up.
 * Returns 1 on success and -1 if a rendezvous fails.  If the thread was
 * marked exiting while blocked, it exits via threadexits.
 */
int
send(Channel *c, void *v)
{
	Alt a;
	Channel *tag;
	int i;
	Thread *t;

retry:
	if ((i = sendcommon(c, v)) != 0)
		// chanlock has been released
		return i;
	// chanlock is still held: no receiver and no room, so queue up
	tag = nil;
	a.c = c;
	a.v = v;
	a.tag = &tag;
	a.op = CHANSND;
	t = _threadgetthr();
	t->alt = &a;
	t->call = Callsnd;

	// enqueue on the channel
	i = emptyentry(c);
	c->qentry[i] = &a;
	a.entryno = i;
	unlock(&chanlock);
	if (_threadrendezvous((ulong)&tag, 0) == ~0) {
		// a and tag live in this stack frame: take a off the queue so no
		// later receiver dereferences them after we return.
		lock(&chanlock);
		if (c->qentry[a.entryno] == &a)
			c->qentry[a.entryno] = nil;
		unlock(&chanlock);
		if (c->freed) chanfree(c);
		t->call = Callnil;
		return -1;
	}
	lock(&chanlock);
	// dequeue from the channel
	threadassert(a.entryno >= 0 && a.entryno < c->nentry && c->qentry[a.entryno]);
	c->qentry[a.entryno] = nil;
	unlock(&chanlock);
	if (c->s)
		goto retry;	// Buffered channels: try the queue again
	// Unbuffered channels: data is already transferred
	if (c->freed) chanfree(c);
	t->call = Callnil;
	if (t->exiting)
		threadexits(nil);
	return 1;
}

/*
 * Send the ulong v on c; c's element size must be sizeof(ulong).
 * Returns whatever send returns.
 */
int
sendul(Channel *c, ulong v)
{
	ulong *elem;

	threadassert(c->e == sizeof(ulong));
	elem = &v;
	return send(c, elem);
}

/*
 * Receive a ulong from c; c's element size must be sizeof(ulong).
 * Returns ~0 if recv reports failure.
 */
ulong
recvul(Channel *c)
{
	ulong val;

	threadassert(c->e == sizeof(ulong));
	return recv(c, &val) < 0 ? ~0 : val;
}

/*
 * Send the pointer v on c; c's element size must be sizeof(void *).
 * Returns whatever send returns.
 */
int
sendp(Channel *c, void *v)
{
	void **elem;

	threadassert(c->e == sizeof(void *));
	elem = &v;
	return send(c, elem);
}

/*
 * Receive a pointer from c; c's element size must be sizeof(void *).
 * Returns nil if recv reports failure.
 */
void *
recvp(Channel *c)
{
	void *val;

	threadassert(c->e == sizeof(void *));
	return recv(c, &val) < 0 ? nil : val;
}

/*
 * Non-blocking send of the ulong v on c; c's element size must be
 * sizeof(ulong).  Returns whatever nbsend returns (0 if not sent).
 */
int
nbsendul(Channel *c, ulong v)
{
	ulong *elem;

	threadassert(c->e == sizeof(ulong));
	elem = &v;
	return nbsend(c, elem);
}

/*
 * Non-blocking receive of a ulong from c; c's element size must be
 * sizeof(ulong).  Returns 0 when nothing was available, which cannot
 * be told apart from a received 0.
 */
ulong
nbrecvul(Channel *c)
{
	ulong val;

	threadassert(c->e == sizeof(ulong));
	return nbrecv(c, &val) == 0 ? 0 : val;
}

/*
 * Non-blocking send of the pointer v on c; c's element size must be
 * sizeof(void *).  Returns whatever nbsend returns (0 if not sent).
 */
int
nbsendp(Channel *c, void *v)
{
	void **elem;

	threadassert(c->e == sizeof(void *));
	elem = &v;
	return nbsend(c, elem);
}

/*
 * Non-blocking receive of a pointer from c; c's element size must be
 * sizeof(void *).  Returns nil when nothing was available, which cannot
 * be told apart from a received nil.
 */
void *
nbrecvp(Channel *c)
{
	void *val;

	threadassert(c->e == sizeof(void *));
	return nbrecv(c, &val) == 0 ? nil : val;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */