#    nntpd.py - simple threaded nntp server classes for testing purposes.
#    Copyright (C) 2002-2004  Matthew Mueller <donut AT dakotacom.net>
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

from __future__ import nested_scopes

import os, select
import time
import threading
import SocketServer
import socket

#allow testing with ipv6.
if os.environ.get('TEST_NGET_IPv6'):
	# Any non-empty TEST_NGET_IPv6 value moves the test servers to the IPv6 loopback.
	serveraddr, serveraf = "::1", socket.AF_INET6
else:
	# Default: IPv4 loopback.
	serveraddr, serveraf = "127.0.0.1", socket.AF_INET

def addressstr(address):
	"""Format a socket address as "host:port", bracketing hosts that contain a colon.

	Only the first two items of address are used, so longer address
	tuples are accepted as well as plain (host, port) pairs.
	"""
	host, port = address[:2]
	if ':' in host:
		# Brackets keep the port separator unambiguous for colon-containing hosts.
		template = '[%s]:%s'
	else:
		template = '%s:%s'
	return template % (host, port)

def chomp(line):
	"""Return line without one trailing line terminator ("\\r\\n", "\\r" or "\\n")."""
	if line.endswith('\r\n'):
		return line[:-2]
	if line.endswith('\r') or line.endswith('\n'):
		return line[:-1]
	return line

import fnmatch
class WildMat:
	"""Callable predicate: WildMat(pat)(arg) tests arg against the pattern pat.

	Matching is delegated to fnmatch.fnmatchcase, so it is case-sensitive.
	"""
	def __init__(self, pat):
		####pat should really be massaged into a regex since the wildmat semantics are not the same as that of fnmatch
		self.pat = pat
	def __call__(self, arg):
		pattern = self.pat
		return fnmatch.fnmatchcase(arg, pattern)
def MatchAny(arg):
	"""Matcher that accepts every argument; always returns 1."""
	return 1

class NNTPError(Exception):
	"""An NNTP error response consisting of a numeric code and a text.

	str() of the exception is the response line: the code zero-padded to
	three digits, a space, and the text.
	"""
	def __init__(self, code, text):
		# Initialise the base class too, so args/repr carry the details.
		Exception.__init__(self, code, text)
		self.code=code
		self.text=text
	def __str__(self):
		return '%03i %s'%(self.code, self.text)
class NNTPNoSuchGroupError(NNTPError):
	"""411 response: the named newsgroup does not exist."""
	def __init__(self, g):
		message = "No such newsgroup %s"%g
		NNTPError.__init__(self, 411, message)
class NNTPNoGroupSelectedError(NNTPError):
	"""412 response: the command needs a selected newsgroup and none is selected."""
	def __init__(self):
		message = "No newsgroup currently selected"
		NNTPError.__init__(self, 412, message)
class NNTPNoSuchArticleNum(NNTPError):
	"""423 response: no article with the given number in the selected group."""
	def __init__(self, anum):
		message = "No such article %s in this newsgroup"%anum
		NNTPError.__init__(self, 423, message)
class NNTPNoSuchArticleMID(NNTPError):
	"""430 response: no article with the given message-id."""
	def __init__(self, mid):
		message = "No article found with message-id %s"%mid
		NNTPError.__init__(self, 430, message)
class NNTPBadCommand(NNTPError):
	"""500 response: unknown command; s, when non-empty, is appended in parentheses."""
	def __init__(self, s=''):
		detail = ''
		if s:
			detail = ' (%s)'%s
		NNTPError.__init__(self, 500, "Bad command" + detail)
class NNTPSyntaxError(NNTPError):
	"""501 response: bad command arguments; s, when non-empty, is appended in parentheses."""
	def __init__(self, s=''):
		detail = ''
		if s:
			detail = ' (%s)'%s
		NNTPError.__init__(self, 501, "Syntax error" + detail)

class NNTPAuthRequired(NNTPError):
	"""480 response: the command requires authorization."""
	def __init__(self):
		message = "Authorization required"
		NNTPError.__init__(self, 480, message)
class NNTPAuthPassRequired(NNTPError):
	"""381 response: a password must be supplied to continue authentication."""
	def __init__(self):
		message = "PASS required"
		NNTPError.__init__(self, 381, message)
class NNTPAuthError(NNTPError):
	"""502 response: authentication failed."""
	def __init__(self):
		message = "Authentication error"
		NNTPError.__init__(self, 502, message)

class NNTPDisconnect(Exception):
	"""Signals that the connection should be closed.

	err, when given, is kept in the err attribute and is also passed to
	the base Exception so that str() and args show it.
	"""
	def __init__(self, err=None):
		if err is None:
			Exception.__init__(self)
		else:
			Exception.__init__(self, err)
		self.err=err

class AuthInfo:
	"""Credentials and per-command permissions for one user.

	caps maps a command name to a true/false permission value.  The
	special key '*' sets the default for commands that are not listed;
	without it every unlisted command is allowed.
	"""
	def __init__(self, user, password, caps=None):
		self.user=user
		self.password=password
		if caps is None:
			caps = {}
		self.caps=caps
	def has_auth(self, cmd):
		"""Return a true value if this user may issue cmd."""
		if cmd in self.caps:
			return self.caps[cmd]
		if cmd in ('quit', 'authinfo'): #allow QUIT and AUTHINFO even if default has been set to no auth
			return 1
		return self.caps.get('*', 1) #default to full auth

def split_cmd(rcmd):
	"""Split a request line into (command, arguments).

	The command word is lower-cased; the arguments are returned as the
	rest of the line, or '' when there are none.  Command and arguments
	may be separated by any run of spaces or tabs.  A tuple is always
	returned, including ('', '') for an empty or blank line.
	"""
	parts = rcmd.split(None, 1)
	if not parts:
		return '', ''
	cmd = parts[0].lower()
	if len(parts) == 1:
		return cmd, ''
	return cmd, parts[1]

class NNTPRequestHandler(SocketServer.StreamRequestHandler):
	def nwrite(self, s):
		self.wfile.write(s+"\r\n")
	def call_command(self, cmd, args):
		func = getattr(self, 'cmd_'+cmd, None)
		if func and callable(func):
			if not self.authed.has_auth(cmd):
				raise NNTPAuthRequired
			self.server.incrcount(cmd)
			func(args)
		else:
			raise NNTPBadCommand, cmd
	def handle(self):
		self.server.incrcount("_conns")
		readline = self.rfile.readline
		self.nwrite("200 Hello World, %s"%addressstr(self.client_address))
		self.group = None
		self._tmpuser = None
		self.authed = self.server.auth['']
		while 1:
			rcmd = readline()
			if not rcmd: break
			rcmd = rcmd.strip()
			cmd,args = split_cmd(rcmd)
			try:
				self.call_command(cmd, args)
			except NNTPDisconnect, d:
				if d.err:
					self.nwrite(str(d.err))
				return
			except NNTPError, e:
				self.nwrite(str(e))

	def cmd_authinfo(self, args):
		cmd,arg = split_cmd(args)
		if cmd=='user':
			self._tmpuser=arg
			raise NNTPAuthPassRequired
		elif cmd=='pass':
			if not self._tmpuser:
				raise NNTPAuthError
			a = self.server.auth.get(self._tmpuser)
			if not a:
				raise NNTPAuthError
			if arg != a.password:
				raise NNTPAuthError
			self.authed = a
			self.nwrite("281 Authentication accepted")
		else:
			raise NNTPSyntaxError, args
	
	def cmd_date(self, args):
		self.nwrite("111 "+time.strftime("%Y%m%d%H%M%S",time.gmtime()))
	
	def cmd_list(self, args):
		subcmd, args = split_cmd(args)
		self.call_command('list_'+subcmd, args)
	def cmd_list_newsgroups(self, args):
		self.nwrite("215 information follows")
		if args:
			matcher = WildMat(args)
		else:
			matcher = MatchAny
		for name,group in self.server.groups.items():
			if group.description and matcher(name):
				self.nwrite("%s %s"%(name, group.description))
		self.nwrite(".")
	def cmd_list_(self, args):
		if args:
			matcher = WildMat(args)
		else:
			matcher = MatchAny
		self.nwrite("215 list of newsgroups follows")
		for name,group in self.server.groups.items():
			if matcher(name):
				self.nwrite("%s %i %i %s"%(name, group.low, group.high, "y"))
		self.nwrite(".")
	cmd_list_active = cmd_list_

	def cmd_newgroups(self, args):
		#since = time.mktime(time.strptime(args,"%Y%m%d %H%M%S %Z"))
		since = ''.join(args.split()[0:2])
		if len(since)!=14:
			raise NNTPSyntaxError, args
		self.nwrite("231 list of new newsgroups follows")
		for name,group in self.server.groups.items():
			#if since < group.creationtime:
			if since < time.strftime("%Y%m%d%H%M%S",time.gmtime(group.creationtime)): #just do a lexicographical comparison. stupid c library. blah.
				self.nwrite("%s %i %i %s"%(name, group.low, group.high, "y"))
		self.nwrite(".")
		
	def cmd_listgroup(self, args):
		if args:
			group = self.server.groups.get(args)
			if not group:
				raise NNTPNoSuchGroupError, args
			self.group = group
		if not self.group:
			raise NNTPNoGroupSelectedError
		self.nwrite("211 list follows")
		anums = self.group.articles.keys()
		anums.sort()
		for an in anums:
			self.nwrite(str(an))
		self.nwrite(".")

	def cmd_group(self, args):
		self.group = self.server.groups.get(args)
		if not self.group:
			raise NNTPNoSuchGroupError, args
		self.nwrite("211 %i %i %i group %s selected"%(self.group.high-self.group.low+1, self.group.low, self.group.high, args))
	def cmd_xover(self, args):
		if not self.group:
			raise NNTPNoGroupSelectedError
		rng = args.split('-')
		if len(rng)>1:
			low,high = map(long, rng)
		else:
			low = high = long(rng[0])
		keys = [k for k in self.group.articles.keys() if k>=low and k<=high]
		keys.sort()
		self.nwrite("224 Overview information follows "+str(rng))
		for anum in keys:
			article = self.group.articles[anum]
			self.nwrite(str(anum)+'\t%(subject)s\t%(author)s\t%(date)s\t%(mid)s\t%(references)s\t%(bytes)s\t%(lines)s'%vars(article))
		self.nwrite('.')
	def cmd_xpat(self, args):
		if not self.group:
			raise NNTPNoGroupSelectedError
		field,rng,pat = args.split(' ', 2)
		field = field.lower()
		if field == 'message-id':
			field = 'mid'
		####doesn't handle specifing message-id instead of range (nget doesn't use this, though)
		rng = rng.split('-')
		if len(rng)>1:
			low,high = map(long, rng)
		else:
			low = high = long(rng[0])
		keys = [k for k in self.group.articles.keys() if k>=low and k<=high]
		keys.sort()
		matcher = WildMat(pat)
		self.nwrite("221 %s fields follow"%field)
		for anum in keys:
			article = self.group.articles[anum]
			val = getattr(article, field, "")
			if matcher(val):
				self.nwrite(str(anum)+self.server.xpat_field_sep+val)
		self.nwrite('.')
	def cmd_article(self, args):
		if args[0]=='<':
			try:
				article = self.server.midindex[args]
			except KeyError:
				raise NNTPNoSuchArticleMID, args
			anum=0
		else:
			if not self.group:
				raise NNTPNoGroupSelectedError
			try:
				anum=long(args)
				article = self.group.articles[anum]
			except KeyError:
				raise NNTPNoSuchArticleNum, args
		self.nwrite("220 %i %s Article follows"%(anum,article.mid))
		self.nwrite(article.text)
		self.nwrite('.')
	def cmd_mode(self, args):
		if args=='reader':
			self.nwrite("200 MODE READER enabled")
		else:
			raise NNTPSyntaxError, args
	def cmd_quit(self, args):
		raise NNTPDisconnect("205 Goodbye")


class _TimeToQuit(Exception):
	"""Raised by get_request() to leave the serving loop once a stop was requested."""

class StoppableThreadingTCPServer(SocketServer.ThreadingTCPServer):
	"""ThreadingTCPServer whose serve_forever() can be stopped from another thread.

	A control channel (an os.pipe(), or a connected loopback socket pair
	on Windows) is watched together with the listening socket in
	get_request().  stop_serving() writes to that channel, get_request()
	then raises _TimeToQuit, and serve_forever() cleans up and returns.
	"""
	def __init__(self, addr, handler):
		if os.name == "nt":
			# select() on Windows only works with sockets, so a connected
			# socket pair stands in for the pipe there.
			listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
			try:
				listener.bind(('127.0.0.1',0))
				listener.listen(1)
				writer = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
				writer.connect(listener.getsockname())
				self.controlr = listener.accept()[0]
				self.controlw = writer
			finally:
				# The temporary listener is only needed to pair the two ends.
				listener.close()
		else:
			self.controlr, self.controlw = os.pipe()
		# Must be set before the base constructor creates the listening socket.
		self.address_family=serveraf
		SocketServer.ThreadingTCPServer.__init__(self, addr, handler)

	def stop_serving(self):
		"""Ask the thread running serve_forever() to stop."""
		if hasattr(self.controlw, 'send'):
			self.controlw.send('FOO')
		else:
			os.write(self.controlw, 'FOO')
		# SocketServer in Python 2.6+ waits on the listening socket alone
		# before it calls get_request(), so the control channel would only
		# be noticed when the next client connects.  A throwaway connection
		# makes that happen right away.  Failure is ignored on purpose: the
		# server may already have seen the request and closed its socket.
		try:
			waker = socket.socket(self.address_family, socket.SOCK_STREAM)
			try:
				waker.connect(self.socket.getsockname())
			finally:
				waker.close()
		except socket.error:
			pass

	def get_request(self):
		"""Wait for a client connection; raise _TimeToQuit if a stop was requested."""
		readfds = [self.socket, self.controlr]
		while 1:
			ready = select.select(readfds, [], [])
			# The control channel is checked first so that a stop request
			# wins over a pending connection.
			if self.controlr in ready[0]:
				raise _TimeToQuit()
			if self.socket in ready[0]:
				return SocketServer.ThreadingTCPServer.get_request(self)

	def serve_forever(self):
		"""Serve requests until stop_serving() is called, then release all sockets."""
		try:
			SocketServer.ThreadingTCPServer.serve_forever(self)
		except _TimeToQuit:
			if hasattr(self.controlw, 'close'):
				self.controlr.close()
				self.controlw.close()
			else:
				os.close(self.controlr)
				os.close(self.controlw)
			self.server_close() # Clean up before we leave

class NNTPTCPServer(StoppableThreadingTCPServer):
	"""Stoppable NNTP server holding the groups, articles, users and counters.

	Attributes:
	  groups         -- group name -> Group
	  midindex       -- message-id -> article
	  auth           -- user name -> AuthInfo (the '' user always exists)
	  counts         -- counter name -> number of increments, see incrcount()
	  xpat_field_sep -- separator between article number and value in XPAT output
	"""
	def __init__(self, addr, RequestHandlerClass=NNTPRequestHandler):
		StoppableThreadingTCPServer.__init__(self, addr, RequestHandlerClass)
		self.groups = {}
		self.midindex = {}
		self.auth = {}
		self.adduser('','')
		self.lock = threading.Lock()
		self.counts = {}
		self.xpat_field_sep = ' '

	def count(self, key):
		"""Return how often key was counted so far (0 if never)."""
		return self.counts.get(key, 0)

	def incrcount(self, key):
		"""Increment the counter for key; safe to call from handler threads."""
		self.lock.acquire()
		try:
			self.counts[key] = self.count(key) + 1
		finally:
			# Release even if the update raises, so other threads are not blocked forever.
			self.lock.release()

	def adduser(self, user, password, caps=None):
		"""Register (or replace) a user; see AuthInfo for the meaning of caps."""
		self.auth[user]=AuthInfo(user, password, caps)

	def addarticle(self, groups, article, anum=None):
		"""Index article by message-id and add it to each named group, creating groups as needed."""
		self.midindex[article.mid]=article
		for g in groups:
			if g not in self.groups:
				self.groups[g]=Group()
			self.groups[g].addarticle(article, anum)

	def rmarticle(self, mid):
		"""Remove the article with message-id mid from the index and from all groups.

		Raises KeyError if no such article is indexed.
		"""
		article = self.midindex[mid]
		del self.midindex[mid]
		for g in self.groups.values():
			g.rmarticle(article)

	def addgroup(self, name, desc=None):
		"""Create group name with description desc, or update the description of an existing group."""
		if name in self.groups:
			self.groups[name].description=desc
		else:
			self.groups[name]=Group(description=desc)

class NNTPD_Master:
	"""Owns a set of servers and runs each one in its own thread."""
	def __init__(self, servers_num):
		"""servers_num is either the number of servers to create or a sequence of existing servers."""
		self.threads = []
		if type(servers_num)==type(1):
			# Port 0 lets the system pick a free port for every new server.
			self.servers = [NNTPTCPServer((serveraddr, 0)) for unused in range(servers_num)]
		else:
			self.servers = list(servers_num)

	def start(self):
		"""Start one serve_forever thread per server."""
		for server in self.servers:
			worker = threading.Thread(target=server.serve_forever)
			# The threads are not marked as daemon threads; call stop() to end them.
			worker.start()
			self.threads.append(worker)

	def stop(self):
		"""Ask every server to stop, then wait for all serving threads to finish."""
		for server in self.servers:
			server.stop_serving()
		for worker in self.threads:
			worker.join()
		self.threads = []


class Group:
	"""A newsgroup: numbered articles plus the low/high article numbers.

	Attributes:
	  low, high    -- reported article number range (1 and 0 for a new group)
	  articles     -- article number -> article
	  creationtime -- time.time() at construction
	  description  -- optional description text
	"""
	def __init__(self, description=None):
		self.low = 1
		self.high = 0
		self.articles = {}
		self.creationtime = time.time()
		self.description = description

	def addarticle(self, article, anum=None):
		"""Store article under number anum (default: high + 1).

		Raises Exception if that number is already taken.
		"""
		if anum is None:
			anum = self.high + 1
		if anum in self.articles:
			raise Exception("already have article %s"%anum)
		self.articles[anum] = article
		if anum > self.high:
			self.high = anum
		if anum < self.low:
			self.low = anum

	def rmarticle(self, article):
		"""Remove the first entry equal to article, if any, and update low.

		high is left unchanged; an emptied group gets low = high + 1.
		"""
		for k,v in list(self.articles.items()):
			if v==article:
				del self.articles[k]
				if self.articles:
					self.low = min(self.articles.keys())
				else:
					self.low = self.high + 1
				return

class FakeArticle:
	"""A synthetic article built from a few parameters.

	mid        -- message-id, used verbatim in the Message-ID header
	name       -- base of the subject
	partno, totalparts -- when totalparts > 0 the subject becomes
	              "name [partno/totalparts]", otherwise just name
	groups     -- sequence of group names for the Newsgroups header
	body       -- sequence of body lines without line terminators

	Sets mid, references, subject, author, date, lines (number of body
	lines), text (headers, blank line and dot-stuffed body joined with
	CRLF) and bytes (length of text).
	"""
	def __init__(self, mid, name, partno, totalparts, groups, body):
		self.mid=mid
		self.references=''
		if totalparts>0:
			self.subject="%s [%i/%i]"%(name, partno, totalparts)
		else:
			# The header name is added once below; the subject itself is just the name.
			self.subject="%s"%(name,)
		self.author = "<noone@nowhere> (test)"
		self.lines = len(body)
		self.date=time.ctime()
		a = [
			"Newsgroups: "+' '.join(groups),
			"Subject: "+self.subject,
			"From: "+self.author,
			"Date: "+self.date,
			"Lines: %i"%self.lines,
			"Message-ID: "+mid,
			"",
		]
		for l in body:
			# Dot-stuff lines starting with '.'; the slice keeps empty lines from raising IndexError.
			if l[:1]=='.':
				a.append('.'+l)
			else:
				a.append(l)
		self.text = '\r\n'.join(a)
		self.bytes = len(self.text)

import rfc822
class FileArticle:
	"""An article read from a file-like object containing headers, a blank line and a body.

	fobj must support what rfc822.Message needs for reading the headers
	and must be iterable line by line for the body.

	Sets author, subject, date, mid (None when the header is missing),
	references ('' when missing), text (right-stripped header lines, a
	blank line and the dot-stuffed body, joined with CRLF), lines (number
	of body lines) and bytes (length of text).
	"""
	def __init__(self, fobj):
		msg = rfc822.Message(fobj)
		self.author = msg.get("From")
		self.subject = msg.get("Subject")
		self.date = msg.get("Date")
		self.mid = msg.get("Message-ID")
		self.references = msg.get("References", '')
		a = [l.rstrip() for l in msg.headers]
		a.append('')
		# rfc822.Message has consumed the headers; what remains is the body.
		for l in fobj:
			l = chomp(l)
			if l[:1]=='.':
				l = '.'+l # dot-stuffing
			a.append(l)
		self.text = '\r\n'.join(a)
		# Everything in a except the headers and the blank separator is body.
		self.lines = len(a) - 1 - len(msg.headers)
		self.bytes = len(self.text)


# syntax highlighted by Code2HTML, v. 0.9.1