# -*- Mode: Python; tab-width: 4 -*-
#
#   Author: Sam Rushing
#

VERSION_STRING = '$Id: corodns.py,v 1.1 2000/04/11 00:50:22 hassan Exp $'

# async resolver with synchronous interface

# Copyright 1999 by eGroups, Inc.
#
#                         All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software and
# its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of
# eGroups not be used in advertising or publicity pertaining to
# distribution of the software without specific, written prior
# permission.
#
# EGROUPS DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
# NO EVENT SHALL EGROUPS BE LIABLE FOR ANY SPECIAL, INDIRECT OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import socket
import string
import sys
import time
import types

import coro
import dnsclass
import dnslib
import dnsopcode
import dnstype
import fifo
import named_cache
from log import *

socket_socket = coro.make_socket
time_sleep = coro.sleep_relative

# XXX: need to support timeouts.
try:
    # Python 2 only.  Kept so that `<this module>.exceptions` still resolves
    # for any caller that relied on it; the classes below use the builtin.
    import exceptions
except ImportError:
    pass


class DNS_Exception(Exception):
    """Base class for the errors raised by this module."""
    pass


class HostNotFound(DNS_Exception):
    """Raised when a name cannot be resolved.

    This used to be a string exception.  `except HostNotFound` clauses keep
    working unchanged; the message is available through str(exc).
    """
    pass


# Local address every outgoing socket is bound to.  initialize() can
# override it before the resolver is created.
gInterface = '63.201.227.2'

rootServers = [
    ('a.root-servers.net', '198.41.0.4'),
    ('b.root-servers.net', '128.9.0.107'),
    ('f.root-servers.net', '192.5.5.241'),
    ('l.root-servers.net', '198.32.64.12'),
]

_DIGITS = "0123456789"


def isIPAddress(hostname):
    """Return 1 if hostname is a dotted-quad IPv4 address, else 0.

    All four parts must be non-empty decimal numbers in 0..255.  The
    previous check only looked at the first character of the last part, so
    it raised IndexError on names with a trailing dot and accepted names
    such as 'a.b.c.4d'.
    """
    parts = hostname.split(".")
    if len(parts) != 4:
        return 0
    for part in parts:
        # strip() leaves something behind iff part has a non-digit character
        if part == "" or part.strip(_DIGITS) != "":
            return 0
        if int(part) > 255:
            return 0
    return 1


def _reply_id(reply):
    """Return the 16-bit id held in the first two bytes of reply.

    Returns None when reply is None or shorter than two bytes.
    """
    if reply is None or len(reply) < 2:
        return None
    hi, lo = reply[0], reply[1]
    if not isinstance(hi, int):
        # indexing a str yields 1-character strings; bytes on Python 3
        # yields ints already
        hi, lo = ord(hi), ord(lo)
    return (hi << 8) | lo


def _forget_requests(request_map, waiter):
    """Delete every request_map entry whose waiting coroutine is waiter."""
    # iterate over a copy: entries are deleted while looping
    for rid, (k, starttime) in list(request_map.items()):
        if k == waiter:
            del request_map[rid]


class dns_reply:
    """Unpacked DNS message: question, answer, authority, additional."""

    def __init__(self):
        self.q = []
        self.an = []
        self.ns = []
        self.ar = []

    def __repr__(self):
        # The format string was empty, so this raised TypeError.
        return '<dns_reply q:%s an:%s ns:%s ar:%s>' % (
            self.q, self.an, self.ns, self.ar)


class TCP_Handler:
    """Multiplexes queries for one domain over a single TCP connection.

    Two coroutines are spawned per handler: recvThread connects and then
    dispatches replies, sendThread drains the request fifo.
    """

    def __init__(self, servername, servers):
        self.servername = servername
        self.servers = servers
        self.fifo = fifo.fifo()
        self.request_map = {}      # request id -> (waiting coroutine, send time)
        self.id = 0
        self.max_outstanding = 100
        self.socket = None
        coro.spawn(self.recvThread)
        coro.spawn(self.sendThread)

    def open(self):
        """Connect to the first reachable server; return 1 on success, 0 if none.

        self.socket is only set once a connection is established, so the
        coroutines polling `self.socket is None` never see a half-open
        socket.  Sockets whose connect fails are closed.
        """
        self.socket = None
        for server in self.servers:
            if isinstance(server, list):
                if not server:
                    continue
                server = server[0]
            if not isIPAddress(server):
                # gethostbyname() returns a list of answers; this used to
                # reference an undefined `q` and pass the list to connect().
                try:
                    addresses = gethostbyname(
                        server, orig_queryname=self.servername)
                except HostNotFound:
                    continue
                if not addresses:
                    continue
                server = addresses[0]
            sock = socket_socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.bind((gInterface, 0))
            try:
                sock.connect((server, 53), timeout=20)
            except (socket.error, coro.TimeoutError):
                sock.close()
                continue
            self.socket = sock
            return 1
        logred("could not connect to", self.servername, self.servers)
        return 0

    def close(self):
        """Close the connection, if any."""
        if self.socket:
            self.socket.close()
        self.socket = None

    def lookup(self, q):
        """Queue query q and block the calling coroutine until its reply.

        Returns whatever value the coroutine is resumed with (the raw reply
        scheduled by handle_reply).  An exception raised by coro.yield_
        (callers catch coro.TimeoutError) propagates after the request has
        been dropped from request_map.
        """
        # NOTE(review): this waits forever if open() never succeeds.
        while self.socket is None:
            time_sleep(1)
        t = coro.current()
        self.fifo.push((q, t))
        try:
            return coro.yield_(timeout=30)
        finally:
            _forget_requests(self.request_map, t)

    def recvThread(self):
        """Connect, then read replies until the connection cannot be kept."""
        if self.open() == 0:
            return
        log(self.servername, "CONNECTED")
        # recv() leaves self.socket as None when a reconnect fails
        while self.socket is not None:
            reply = self.recv()
            if reply:
                self.handle_reply(reply)

    def handle_reply(self, reply):
        """Resume the coroutine waiting for reply's id, or log an orphan.

        Replies that are None or too short to carry an id are ignored.
        """
        rid = _reply_id(reply)
        if rid is None:
            return
        if rid in self.request_map:
            k, starttime = self.request_map[rid]
            del self.request_map[rid]
            coro.schedule(k, reply)
        else:
            log(self.servername, '*** orphaned TCP DNS reply "%s"' % rid)

    def sendThread(self):
        """Send queued requests, keeping at most max_outstanding in flight."""
        while self.socket is None:
            time_sleep(1)
        while 1:
            # nothing is sent while the connection is down
            while (self.socket is not None
                   and len(self.fifo)
                   and len(self.request_map) < self.max_outstanding):
                q, k = self.fifo.pop()
                self.send(q, k)
            time_sleep(.1)

    def send(self, q, k):
        """Send q with a fresh id and record coroutine k as its waiter."""
        self.id = (self.id + 1) % 65536
        r = q.build_request(self.id)
        self.request_map[self.id] = (k, time.time())
        # each message is preceded by its length, packed as 16 bits
        self.socket.send(dnslib.pack16bit(len(r)) + r)

    def _reconnect(self):
        """Handle EOF: close, try to reopen, and return None (no reply)."""
        log(self.servername, "*** EOF ***")
        self.close()
        self.open()
        return None

    def recv(self):
        """Read one length-prefixed reply.

        Returns the reply, or None if the peer closed the connection, in
        which case a reconnect has been attempted.
        """
        header = self.socket.recv(2)
        if len(header) < 2:
            return self._reconnect()
        count = dnslib.unpack16bit(header)
        blocks = []
        while count > 0:
            block = self.socket.recv(count)
            if not block:
                # EOF in the middle of a message used to spin forever
                return self._reconnect()
            count = count - len(block)
            blocks.append(block)
        # header[:0] is an empty string of the same type recv() returns
        return header[:0].join(blocks)


class UDP_Handler:
    """Sends every UDP query through one shared socket.

    Three coroutines are spawned: udp_recv, udp_send and statusThread.
    """

    def __init__(self):
        self.request_map = {}      # request id -> (waiting coroutine, send time)
        self.id = 0
        self.fifo = fifo.fifo()
        self.max_outstanding = 1000
        self.socket = None
        coro.spawn(self.udp_recv)
        coro.spawn(self.udp_send)
        coro.spawn(self.statusThread)

    def lookup(self, server, q, timeout=4):
        """Queue q for server and block the calling coroutine for the reply.

        Raises socket.error if the send failed.  An exception raised by
        coro.yield_ (callers catch coro.TimeoutError) propagates after the
        request has been dropped from request_map.
        """
        t = coro.current()
        self.fifo.push(((server, q), t))
        try:
            ret = coro.yield_(timeout=timeout)
        finally:
            _forget_requests(self.request_map, t)
        # udp_send resumes us with the socket.error class itself on failure
        if ret is socket.error:
            raise socket.error("socket error")
        if ret is None:
            logred("udp lookup returned None")
        return ret

    def send_udp_request(self, request, k):
        """Send one query; request is a (server, query) pair, k the waiter.

        The request is registered before the datagram is sent, as the TCP
        handler does, and unregistered again if the send fails.

        Raises socket.error on a failed or short send.
        """
        server, q = request
        self.id = (self.id + 1) % 65536
        rid = self.id
        r = q.build_request(rid)
        if isinstance(server, list):
            server = server[0]
        if not isIPAddress(server):
            logred("UDP: Not an ip address", server)
        self.request_map[rid] = (k, time.time())
        try:
            n = self.socket.sendto(r, (server, 53))
            if n != len(r):
                raise socket.error("sendto() underperformed")
        except socket.error:
            if rid in self.request_map:
                del self.request_map[rid]
            raise

    def handle_reply(self, reply):
        """Resume the coroutine waiting for reply's id.

        Replies that are None, too short, or have no waiter are dropped.
        """
        rid = _reply_id(reply)
        if rid is None:
            return
        if rid in self.request_map:
            k, starttime = self.request_map[rid]
            del self.request_map[rid]
            coro.schedule(k, reply)
        # else: orphaned UDP reply, ignored

    def setup_udpsocket(self):
        """(Re)create the shared UDP socket, closing any previous one."""
        old = self.socket
        if old is not None:
            # best effort: the old socket is being discarded after an error
            try:
                old.close()
            except socket.error:
                pass
        self.socket = coro.coroutine_socket()
        self.socket.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.socket.bind((gInterface, 0))
        # these can be set really high on linux, but freebsd 2 default
        # limit is 256KB
        self.socket.socket.setsockopt(
            socket.SOL_SOCKET, socket.SO_SNDBUF, 5 * 1024 * 1024)
        self.socket.socket.setsockopt(
            socket.SOL_SOCKET, socket.SO_RCVBUF, 5 * 1024 * 1024)

    def udp_send(self):
        """Send queued requests, keeping at most max_outstanding in flight."""
        while self.socket is None:
            time_sleep(1)
        while 1:
            while (len(self.fifo)
                   and len(self.request_map) < self.max_outstanding):
                q, k = self.fifo.pop()
                try:
                    self.send_udp_request(q, k)
                except socket.error:
                    # lookup() turns this sentinel into a raised socket.error
                    coro.schedule(k, socket.error)
            time_sleep(.1)

    def udp_recv(self):
        """Create the socket, then dispatch incoming datagrams forever."""
        self.setup_udpsocket()
        while 1:
            try:
                reply, whence = self.socket.recvfrom(4096)
                if reply is not None:
                    self.handle_reply(reply)
            except socket.error:
                self.setup_udpsocket()

    def statusThread(self):
        """Log queue sizes every 30 seconds."""
        while 1:
            time_sleep(30)
            log("Size of Fifo: %s" % len(self.fifo))
            log("# of Outstanding Requests: %s" % len(self.request_map))


class resolver:
    """Iterative resolver with an answer cache and a name-server cache."""

    # When true, FindNearestServer() always answers with the 'com' pack
    # (the root servers stored by __init__) instead of consulting nscache.
    # This keeps the behaviour of the unconditional early return the
    # method used to start with.
    always_start_at_root = 1

    def __init__(self):
        self.cache = {}            # (name, qtype, qclass) -> list of answers
        self.nscache = {}          # domain -> list of name servers
        self.id = 0
        self.udp = UDP_Handler()
        self.connections = {}      # domain -> TCP_Handler
        self.connectionsTime = {}  # domain -> time of last use
        coro.spawn(self.statusThread)
        coro.spawn(self.cleanConnections)
        servers = []
        for name, ip in rootServers:
            self.StoreName((name, 'A', 'IN'), ip)
            servers.append(ip)
        self.nscache['com'] = servers
        self.debugfp = open("/tmp/corodns.log", "w")

    def HasKeyNameUsingCache(self, key):
        """Return true if key, a (hostname, qtype, qclass) triple, is cached."""
        hostname, qtype, qclass = key
        return (hostname, qtype, qclass) in self.cache

    def GetNameUsingCache(self, key):
        """Return the cached answer list for key; raises KeyError if absent."""
        return self.cache[key]

    def StoreName(self, key, ip):
        """Add ip to the cached answers for key unless already present."""
        answers = self.cache.setdefault(key, [])
        if ip not in answers:
            answers.append(ip)

    def cleanConnections(self):
        """Every 10 seconds, forget idle TCP handlers unused for over 10s."""
        # NOTE(review): the handler's socket is not closed here.  Calling
        # close() would make its recv() see EOF and reconnect; confirm the
        # intended shutdown path before changing this.
        while 1:
            now = time.time()
            for server, t in list(self.connectionsTime.items()):
                c = self.connections[server]
                if (now - t > 10
                        and len(c.request_map) == 0
                        and len(c.fifo) == 0):
                    del self.connections[server]
                    del self.connectionsTime[server]
            time_sleep(10)

    def GetNewTCPHandler(self, domain, servers):
        """Create and register a TCP handler, waiting while over 20 exist."""
        while len(self.connections) > 20:
            time_sleep(1)
        t = TCP_Handler(domain, servers)
        self.connections[domain] = t
        return t

    def getTCPHandler(self, domain, servers):
        """Return the TCP handler for domain, creating it if necessary."""
        try:
            t = self.connections[domain]
        except KeyError:
            t = self.GetNewTCPHandler(domain, servers)
        self.connectionsTime[domain] = time.time()
        return t

    def _tcp_lookup(self, q, label):
        """Run q over TCP; on timeout log it under label and return None."""
        t = self.getTCPHandler(q.domain, q.servers)
        try:
            q.incAttempts()
            return t.lookup(q)
        except coro.TimeoutError:
            # sys.exc_info() keeps this valid on both Python 2 and 3
            logred(label, sys.exc_info()[1])
            return None

    def _lookup(self, q):
        """Send q to its servers and return the raw reply, or None.

        An existing TCP connection for q.domain is preferred.  Otherwise
        each server is tried over UDP in three rounds with growing timeouts.
        """
        if q.domain in self.connections:
            return self._tcp_lookup(q, "Timeout1")
        if q.domain is not None:
            t = self.udp
            for n in range(3):
                timeout = .4 + (2 * n)
                for server in q.servers:
                    if not isIPAddress(server):
                        try:
                            server = gethostbyname(
                                server, orig_queryname=q.queryname)
                        except HostNotFound:
                            continue
                    if server is None:
                        continue
                    if len(server) == 0:
                        continue
                    try:
                        q.incAttempts()
                        return t.lookup(server, q, timeout=timeout)
                    except coro.TimeoutError:
                        pass
                    except socket.error:
                        pass
        if q.domain is None:
            return self._tcp_lookup(q, "Timeout2")
        logred("Error...None returned", q.domain, q.servers)
        return None

    def FindNearestServer(self, name, q=None):
        """Return a (domain, servers) pack to start resolving name from.

        With always_start_at_root set (the default) this is always the
        'com' pack.  Otherwise nscache is searched for the longest cached
        parent domain of name and its servers are resolved to addresses,
        falling back to the 'com' pack.
        """
        if self.always_start_at_root:
            return ('com', self.nscache['com'])
        name = name.lower()
        parts = name.split('.')
        hosts = None
        for i in range(1, len(parts)):
            domain = '.'.join(parts[i:])
            try:
                hosts = self.nscache[domain]
                break
            except KeyError:
                pass
        if hosts:
            servers = []
            for hostname in hosts:
                hostname = hostname.lower()
                if isIPAddress(hostname):
                    ip = [hostname]
                else:
                    try:
                        ip = self.GetNameUsingCache((hostname, 'A', 'IN'))
                    except KeyError:
                        if hostname == name:
                            # resolving the server would need the server
                            log(name, "skipping", hostname)
                            continue
                        ip = query_with_cname(hostname, orig_queryname=name)
                servers = servers + ip
            return (domain, servers)
        return ('com', self.nscache['com'])

    def query(self, q):
        """Perform one lookup step for q and return the unpacked reply.

        NS records from the reply are added to nscache and additional
        records to the answer cache.  If q's key is already cached, the
        cached answer list is returned instead of a dns_reply.

        Raises HostNotFound when no reply arrives or it cannot be unpacked.
        """
        key = q.getKey()
        if q.servers is None:
            domain, servers = self.FindNearestServer(q.name, q=q)
            q.addServerPack(domain, servers)
        if self.HasKeyNameUsingCache(key):
            return self.GetNameUsingCache(key)
        reply = self._lookup(q)
        if reply is None:
            raise HostNotFound(
                "no reply: %s %s" % (q.queryname, q.attempts))
        try:
            result = unpack_reply(reply)
        except TypeError:
            raise HostNotFound(
                "unable to unpack reply: %s %s" % (q.queryname, q.attempts))
        for t, ns, hostname in result.ns:
            if t == "SOA":
                continue
            hostname = hostname.lower()
            known = self.nscache.setdefault(ns.lower(), [])
            if hostname not in known:
                known.append(hostname)
        for t, hostname, ip in result.ar:
            self.StoreName((hostname.lower(), t, 'IN'), ip)
        return result

    def Query(self, q):
        """Resolve q, following CNAMEs and referrals; answers go in q.answers.

        Returns None in every case.  Raises HostNotFound after more than
        five CNAME/referral steps or when a lookup step fails.
        """
        if q.depth > 5:
            logred(q.queryname, q.name, "LOOP depth", q.depth, q.lookups)
            raise HostNotFound('CNAME loop')
        key = q.getKey()
        try:
            ips = self.GetNameUsingCache(key)
        except KeyError:
            pass
        else:
            q.addAnswers(ips)
            return
        result = self.query(q)
        if result is None:
            raise HostNotFound("couldn't lookup host %s" % q.queryname)
        if len(result.an) > 0:
            for r in result.an:
                if r[0] == 'CNAME':
                    # restart the lookup under the canonical name
                    q.setName(r[2])
                    q.incDepth()
                    self.Query(q)
                else:
                    q.addAnswer(r[2])
            if q.answers:
                return
        if len(result.ns) > 0:
            # no answer, but a referral: retry against the listed servers
            servers = []
            for aType, aDomain, aNS in result.ns:
                aDomain = aDomain.lower()
                if aType == "SOA":
                    aNS = aNS[0]
                servers.append(aNS.lower())
            if aDomain == q.domain:
                # referred back to the domain we just asked
                logred("looping", q.domain, q.queryname)
                return
            q.addServerPack(aDomain, servers)
            q.incDepth()
            self.Query(q)
            return
        return

    def statusThread(self):
        """Log the number of TCP connections every 30 seconds."""
        while 1:
            time_sleep(30)
            log("Number of TCP Connections: %s" % len(self.connections))


the_resolver = None


def initialize(interface=None):
    """Create the module-wide resolver.

    interface, when given, replaces gInterface (the local address sockets
    are bound to) before the resolver is created.
    """
    global the_resolver, gInterface
    if interface is not None:
        gInterface = interface
    the_resolver = resolver()


class NSQuery:
    """State of one resolution: current name, server pack, answers so far."""

    def __init__(self, name, qtype='A', qclass='IN', recursion=1):
        self.orig_queryname = None
        self.queryname = name      # the name originally asked for
        self.name = name           # the name currently looked up
        self.lookups = [name]      # every name and domain visited
        self.qtype = qtype
        self.qclass = qclass
        self.recursion = recursion
        self.domain = None
        self.servers = None
        self.answers = []
        self.depth = 0
        self.attempts = 0

    def setName(self, name):
        """Switch the query to name and record it in lookups."""
        self.name = name
        self.lookups.append(name)

    def addServerPack(self, domain, servers):
        """Set the domain and servers to ask; record domain in lookups."""
        self.domain = domain
        self.servers = servers
        self.lookups.append(domain)

    def incDepth(self):
        """Count one more CNAME/referral step."""
        self.depth = self.depth + 1

    def incAttempts(self):
        """Count one more request sent."""
        self.attempts = self.attempts + 1

    def getKey(self):
        """Return the (name, qtype, qclass) cache key."""
        return (self.name, self.qtype, self.qclass)

    def addAnswer(self, answer):
        """Append one answer."""
        self.answers.append(answer)

    def addAnswers(self, answers):
        """Append a list of answers without mutating the list passed in."""
        self.answers = self.answers + answers

    def build_request(self, rid):
        """Return the packed query message carrying request id rid."""
        m = dnslib.Mpacker()
        m.addHeader(
            rid,
            0, dnsopcode.QUERY, 0, 0, self.recursion, 0, 0, 0,
            1, 0, 0, 0
        )
        m.addQuestion(
            self.name,
            getattr(dnstype, self.qtype),
            getattr(dnsclass, self.qclass))
        return m.getbuf()


def query_with_cname(name, qtype='A', qclass='IN', recursion=1,
                     orig_queryname=None):
    """Resolve name and return the list of answers.

    The module-wide resolver is created on first use.  Raises HostNotFound
    when resolution fails.
    """
    if the_resolver is None:
        initialize()
    q = NSQuery(name, qtype, qclass, recursion)
    q.orig_queryname = orig_queryname
    the_resolver.Query(q)
    return q.answers


def gethostbyname(host, orig_queryname=None):
    """Return the list of 'A' answers for host."""
    return query_with_cname(host, 'A', orig_queryname=orig_queryname)


def get_rr(u):
    """Read one resource record from unpacker u as (typename, name, data).

    data comes from u's get<TYPE>data method when it has one, otherwise it
    is the raw rdata bytes.
    """
    name, rrtype, klass, ttl, rdlength = u.getRRheader()
    typename = dnstype.typestr(rrtype)
    mname = 'get%sdata' % typename
    if hasattr(u, mname):
        return (typename, name, getattr(u, mname)())
    return (typename, name, u.getbytes(rdlength))


def unpack_reply(reply):
    """Unpack a raw DNS message into a dns_reply."""
    u = dnslib.Munpacker(reply)
    (rid, qr, opcode, aa, tc, rd, ra, z, rcode,
     qdcount, ancount, nscount, arcount) = u.getHeader()
    r = dns_reply()
    for i in range(qdcount):
        r.q.append(u.getQuestion())
    for i in range(ancount):
        r.an.append(get_rr(u))
    for i in range(nscount):
        r.ns.append(get_rr(u))
    for i in range(arcount):
        r.ar.append(get_rr(u))
    return r


# To test:
# start this server in one window.
# In another window, telnet to port 8023, and type
# something like this:
#
# >>> query ('yoyodyne.com')
# >>> query ('yoyodyne.com')
#
# >>>

if __name__ == '__main__':
    import backdoor

    # An optional argument selects the local address to bind to.  The
    # resolver spawns its own worker coroutines from its constructor, so
    # there is no separate run() to start.
    if len(sys.argv) > 1:
        initialize(sys.argv[1])
    else:
        initialize()

    def query(*args):
        """Resolve a name; same arguments as query_with_cname()."""
        return query_with_cname(*args)

    def q2(*args):
        """Resolve a name and print each answer on its own line."""
        for answer in query(*args):
            print(answer)

    coro.spawn(backdoor.serve)
    coro.event_loop(30.0)