# Twisted, the Framework of Your Internet # Copyright (C) 2001 Matthew W. Lefkowitz # # This library is free software; you can redistribute it and/or # modify it under the terms of version 2.1 of the GNU Lesser General Public # License as published by the Free Software Foundation. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public # License along with this library; if not, write to the Free Software # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA """ Asynchronous client DNS API Stability: Unstable Future plans: Proper nameserver acquisition on Windows/MacOS, better caching, respect timeouts @author: U{Jp Calderone} """ from __future__ import nested_scopes import socket import os import errno import time # Twisted imports from twisted.python.runtime import platform from twisted.internet import defer, protocol, interfaces, threads from twisted.python import log, failure from twisted.protocols import dns import common class Resolver(common.ResolverBase): __implements__ = (interfaces.IResolver,) index = 0 timeout = None factory = None servers = None dynServers = () pending = None protocol = None connections = None resolv = None _lastResolvTime = None _resolvReadInterval = 60 def __init__(self, resolv = None, servers = None, timeout = (1, 3, 11, 45)): """ @type servers: C{list} of C{(str, int)} or C{None} @param servers: If not None, interpreted as a list of addresses of domain name servers to attempt to use for this lookup. Addresses should be in dotted-quad form. If specified, overrides C{resolv}. @type resolv: C{str} @param resolv: Filename to read and parse as a resolver(5) configuration file. 
@type timeout: Sequence of C{int} @param timeout: Default number of seconds after which to reissue the query. When the last timeout expires, the query is considered failed. @raise ValueError: Raised if no nameserver addresses can be found. """ common.ResolverBase.__init__(self) self.timeout = timeout if servers is None: self.servers = [] else: self.servers = servers self.resolv = resolv if not len(self.servers) and not resolv: raise ValueError, "No nameservers specified" self.factory = DNSClientFactory(self, timeout) self.factory.noisy = 0 # Be quiet by default self.protocol = dns.DNSDatagramProtocol(self) self.protocol.noisy = 0 # You too self.connections = [] self.pending = [] self.maybeParseConfig() def __getstate__(self): d = self.__dict__.copy() d['connections'] = [] d['_parseCall'] = None return d def __setstate__(self, state): self.__dict__.update(state) self.maybeParseConfig() def maybeParseConfig(self): if self.resolv is None: # Don't try to parse it, don't set up a call loop return try: resolvConf = file(self.resolv) except IOError, e: if e.errno == errno.ENOENT: pass else: raise else: mtime = os.fstat(resolvConf.fileno()).st_mtime if mtime != self._lastResolvTime: log.msg('%s changed, reparsing' % (self.resolv,)) self._lastResolvTime = mtime self.parseConfig(resolvConf) # Check the mtime again in a little while from twisted.internet import reactor self._parseCall = reactor.callLater(self._resolvReadInterval, self.maybeParseConfig) def parseConfig(self, resolvConf): servers = [] for L in resolvConf: L = L.strip() if L.startswith('nameserver'): resolver = (L.split()[1], dns.PORT) servers.append(resolver) log.msg("Resolver added %r to server list" % (resolver,)) elif L.startswith('domain'): try: self.domain = L.split()[1] except IndexError: self.domain = '' self.search = None elif L.startswith('search'): try: self.search = L.split()[1:] except IndexError: self.search = '' self.domain = None self.dynServers = servers def pickServer(self): """ Return the 
address of a nameserver. TODO: Weight servers for response time so faster ones can be preferred. """ if not self.servers and not self.dynServers: return None serverL = len(self.servers) dynL = len(self.dynServers) self.index += 1 self.index %= (serverL + dynL) if self.index < serverL: return self.servers[self.index] else: return self.dynServers[self.index - serverL] def connectionMade(self, protocol): self.connections.append(protocol) for (d, q, t) in self.pending: self.queryTCP(q, t).chainDeferred(d) del self.pending[:] def messageReceived(self, message, protocol, address = None): log.msg("Unexpected message (%d) received from %r" % (message.id, address)) def queryUDP(self, queries, timeout = None): """ Make a number of DNS queries via UDP. @type queries: A C{list} of C{dns.Query} instances @param queries: The queries to make. @type timeout: Sequence of C{int} @param timeout: Number of seconds after which to reissue the query. When the last timeout expires, the query is considered failed. @rtype: C{Deferred} @raise C{twisted.internet.defer.TimeoutError}: When the query times out. """ if timeout is None: timeout = self.timeout address = self.pickServer() if address is None: return defer.fail(IOError("No domain name servers available")) return self.protocol.query(address, queries, timeout[0] ).addErrback(self._reissue, address, queries, timeout[1:] ) def _reissue(self, reason, address, query, timeout): reason.trap(defer.TimeoutError) if timeout and self.protocol.transport: d = self.protocol.query(address, query, timeout[0], reason.value.id) d.addErrback(self._reissue, address, query, timeout[1:]) return d try: del self.protocol.resends[reason.value.id] except: pass return failure.Failure(defer.TimeoutError(query)) def queryTCP(self, queries, timeout = 10): """ Make a number of DNS queries via TCP. @type queries: Any non-zero number of C{dns.Query} instances @param queries: The queries to make. 
@type timeout: C{int} @param timeout: The number of seconds after which to fail. @rtype: C{Deferred} """ if not len(self.connections): address = self.pickServer() if address is None: return defer.fail(IOError("No domain name servers available")) host, port = address from twisted.internet import reactor reactor.connectTCP(host, port, self.factory) self.pending.append((defer.Deferred(), queries, timeout)) return self.pending[-1][0] else: return self.connections[0].query(queries, timeout) def filterAnswers(self, message): if message.trunc: return self.queryTCP(message.queries).addCallback(self.filterAnswers) else: return (message.answers, message.authority, message.additional) def _lookup(self, name, cls, type, timeout): return self.queryUDP( [dns.Query(name, type, cls)], timeout ).addCallback(self.filterAnswers) # This one doesn't ever belong on UDP def lookupZone(self, name, timeout = 10): """ Perform an AXFR request. This is quite different from usual DNS requests. See http://cr.yp.to/djbdns/axfr-notes.html for more information. """ address = self.pickServer() if address is None: return defer.fail(IOError('No domain name servers available')) host,port = address from twisted.internet import reactor d = defer.Deferred() d.setTimeout(timeout or 10) controller = AXFRController(name, d) factory = DNSClientFactory(controller, timeout) factory.noisy = False #stfu reactor.connectTCP(host, port, factory) return d.addCallback(lambda x: (x, [], [])) class AXFRController: def __init__(self, name, deferred): self.name = name self.deferred = deferred self.soa = None self.records = [] def connectionMade(self, protocol): # dig saids recursion-desired to 0, so I will too message = dns.Message(protocol.pickID(), recDes=0) message.queries = [dns.Query(self.name, dns.AXFR, dns.IN)] protocol.writeMessage(message) def messageReceived(self, message, protocol): # Caveat: We have to handle two cases: All records are in 1 # message, or all records are in N messages. 
# According to http://cr.yp.to/djbdns/axfr-notes.html, # 'authority' and 'additional' are always empty, and only # 'answers' is present. self.records.extend(message.answers) if not self.records: return if not self.soa: if self.records[0].type == dns.SOA: #print "first SOA!" self.soa = self.records[0] if len(self.records) > 1 and self.records[-1].type == dns.SOA: #print "It's the second SOA! We're done." self.deferred.callback(self.records) class ThreadedResolver: __implements__ = (interfaces.IResolverSimple,) def __init__(self): self.cache = {} def getHostByName(self, name, timeout = 10): # XXX - Make this respect timeout d = threads.deferToThread(socket.gethostbyname, name) d.setTimeout(timeout) return d class DNSClientFactory(protocol.ClientFactory): def __init__(self, controller, timeout = 10): self.controller = controller self.timeout = timeout def clientConnectionLost(self, connector, reason): pass def buildProtocol(self, addr): p = dns.DNSProtocol(self.controller) p.factory = self return p def createResolver(servers = None, resolvconf = None, hosts = None): from twisted.names import resolve, cache, root, hosts as hostsModule if platform.getType() == 'posix': if resolvconf is None: resolvconf = '/etc/resolv.conf' if hosts is None: hosts = '/etc/hosts' theResolver = Resolver(resolvconf, servers) hostResolver = hostsModule.Resolver(hosts) else: if hosts is None: hosts = r'c:\windows\hosts' bootstrap = ThreadedResolver() hostResolver = hostsModule.Resolver(hosts) theResolver = root.bootstrap(bootstrap) L = [hostResolver, cache.CacheResolver(), theResolver] return resolve.ResolverChain(L) theResolver = None def _makeLookup(method): def lookup(*a, **kw): global theResolver if theResolver is None: try: theResolver = createResolver() except ValueError: theResolver = createResolver(servers=[('127.0.0.1', 53)]) return getattr(theResolver, method)(*a, **kw) return lookup for method in common.typeToMethod.values(): globals()[method] = _makeLookup(method) del method 
# Module-level convenience: resolve a hostname through the lazily created
# default resolver shared by the other module-level lookup functions.
getHostByName = _makeLookup(method='getHostByName')