# Twisted, the Framework of Your Internet
# Copyright (C) 2001 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
"""
Asynchronous client DNS
API Stability: Unstable
Future plans: Proper nameserver acquisition on Windows/MacOS,
better caching, respect timeouts
@author: U{Jp Calderone<mailto:exarkun@twistedmatrix.com>}
"""
from __future__ import nested_scopes
import socket
import os
import errno
import time
# Twisted imports
from twisted.python.runtime import platform
from twisted.internet import defer, protocol, interfaces, threads
from twisted.python import log, failure
from twisted.protocols import dns
import common
class Resolver(common.ResolverBase):
    """
    Asynchronous client-side DNS resolver.

    Queries go out over UDP to nameservers picked round-robin from
    C{servers} (given to the constructor) and C{dynServers} (read from the
    C{resolv} file).  Truncated UDP answers are retried over TCP.
    """
    __implements__ = (interfaces.IResolver,)

    index = 0
    timeout = None

    factory = None
    servers = None
    dynServers = ()
    pending = None
    protocol = None
    connections = None

    resolv = None
    _lastResolvTime = None
    _resolvReadInterval = 60

    def __init__(self, resolv = None, servers = None, timeout = (1, 3, 11, 45)):
        """
        @type servers: C{list} of C{(str, int)} or C{None}
        @param servers: If not None, interpreted as a list of addresses of
        domain name servers to attempt to use for this lookup. Addresses
        should be in dotted-quad form. These are used in rotation together
        with any nameservers read from C{resolv}.

        @type resolv: C{str}
        @param resolv: Filename to read and parse as a resolver(5)
        configuration file. Its mtime is re-checked every
        C{_resolvReadInterval} seconds, and it is reparsed when it changes.

        @type timeout: Sequence of C{int}
        @param timeout: Default number of seconds after which to reissue the query.
        When the last timeout expires, the query is considered failed.

        @raise ValueError: Raised if neither C{servers} nor C{resolv} is given.
        """
        common.ResolverBase.__init__(self)

        self.timeout = timeout

        if servers is None:
            self.servers = []
        else:
            self.servers = servers

        self.resolv = resolv

        if not len(self.servers) and not resolv:
            raise ValueError("No nameservers specified")

        self.factory = DNSClientFactory(self, timeout)
        self.factory.noisy = 0 # Be quiet by default

        self.protocol = dns.DNSDatagramProtocol(self)
        self.protocol.noisy = 0 # You too

        self.connections = []
        self.pending = []

        self.maybeParseConfig()

    def __getstate__(self):
        """
        Return picklable state. Open TCP connections and the scheduled
        config re-read call are dropped.
        """
        d = self.__dict__.copy()
        d['connections'] = []
        d['_parseCall'] = None
        return d

    def __setstate__(self, state):
        """Restore pickled state and restart the config re-read loop."""
        self.__dict__.update(state)
        self.maybeParseConfig()

    def maybeParseConfig(self):
        """
        Reparse C{self.resolv} if its mtime changed since the last parse.
        Then schedule another check in C{self._resolvReadInterval} seconds.

        A missing file is ignored, and the check is still rescheduled.  Any
        other I/O error while opening it is re-raised.
        """
        if self.resolv is None:
            # Don't try to parse it, don't set up a call loop
            return

        try:
            resolvConf = open(self.resolv)
        except IOError:
            # sys.exc_info() works with both old and new except syntax.
            import sys
            if sys.exc_info()[1].errno != errno.ENOENT:
                raise
        else:
            # Always release the file handle, even if fstat or parsing fails.
            try:
                mtime = os.fstat(resolvConf.fileno()).st_mtime
                if mtime != self._lastResolvTime:
                    log.msg('%s changed, reparsing' % (self.resolv,))
                    self._lastResolvTime = mtime
                    self.parseConfig(resolvConf)
            finally:
                resolvConf.close()

        # Check the mtime again in a little while
        from twisted.internet import reactor
        self._parseCall = reactor.callLater(self._resolvReadInterval, self.maybeParseConfig)

    def parseConfig(self, resolvConf):
        """
        Parse resolver(5)-style lines from the iterable C{resolvConf}.

        C{nameserver} lines build a new C{self.dynServers} list, each entry
        using port C{dns.PORT}.  A C{nameserver} line with no address is
        logged and skipped; raising here would stop the periodic re-read.
        A C{domain} line sets C{self.domain} (C{''} if no value is given)
        and clears C{self.search}.  A C{search} line sets C{self.search} to
        the list of remaining words and clears C{self.domain}.
        """
        servers = []
        for line in resolvConf:
            line = line.strip()
            args = line.split()[1:]
            if line.startswith('nameserver'):
                if not args:
                    log.msg("Ignoring nameserver line without an address: %r" % (line,))
                    continue
                resolver = (args[0], dns.PORT)
                servers.append(resolver)
                log.msg("Resolver added %r to server list" % (resolver,))
            elif line.startswith('domain'):
                if args:
                    self.domain = args[0]
                else:
                    self.domain = ''
                self.search = None
            elif line.startswith('search'):
                self.search = args
                self.domain = None
        self.dynServers = servers

    def pickServer(self):
        """
        Return the address of a nameserver, or C{None} if none are known.

        Entries of C{self.servers} and then C{self.dynServers} are used in
        round-robin order.

        TODO: Weight servers for response time so faster ones can be
        preferred.
        """
        if not self.servers and not self.dynServers:
            return None
        serverL = len(self.servers)
        dynL = len(self.dynServers)

        self.index += 1
        self.index %= (serverL + dynL)
        if self.index < serverL:
            return self.servers[self.index]
        else:
            return self.dynServers[self.index - serverL]

    def connectionMade(self, protocol):
        """
        Record a newly connected TCP protocol.  Then send every query that
        was queued in C{self.pending} while the connection was being set up.
        """
        self.connections.append(protocol)
        for (d, q, t) in self.pending:
            self.queryTCP(q, t).chainDeferred(d)
        del self.pending[:]

    def messageReceived(self, message, protocol, address = None):
        """Log a message that did not match any outstanding query."""
        log.msg("Unexpected message (%d) received from %r" % (message.id, address))

    def queryUDP(self, queries, timeout = None):
        """
        Make a number of DNS queries via UDP.

        @type queries: A C{list} of C{dns.Query} instances
        @param queries: The queries to make.

        @type timeout: Sequence of C{int}
        @param timeout: Number of seconds after which to reissue the query.
        When the last timeout expires, the query is considered failed.
        Defaults to C{self.timeout}.

        @rtype: C{Deferred}
        @raise C{twisted.internet.defer.TimeoutError}: When the query times
        out.
        """
        if timeout is None:
            timeout = self.timeout

        address = self.pickServer()
        if address is None:
            return defer.fail(IOError("No domain name servers available"))

        # The first attempt uses timeout[0]; _reissue handles the rest.
        return self.protocol.query(address, queries, timeout[0]
            ).addErrback(self._reissue, address, queries, timeout[1:]
            )

    def _reissue(self, reason, address, query, timeout):
        """
        Errback for UDP queries.

        On a timeout, resend the query with the next entry of C{timeout},
        passing the original message id, as long as an entry remains and
        the datagram protocol still has a transport.  Otherwise discard the
        resend bookkeeping for that id and fail with
        C{defer.TimeoutError(query)}.  Failures that are not timeouts
        propagate through C{trap}.
        """
        reason.trap(defer.TimeoutError)

        if timeout and self.protocol.transport:
            d = self.protocol.query(address, query, timeout[0], reason.value.id)
            d.addErrback(self._reissue, address, query, timeout[1:])
            return d

        # Best-effort cleanup: the entry may already be gone.  Only the
        # expected lookup errors are ignored, so unrelated bugs still surface.
        try:
            del self.protocol.resends[reason.value.id]
        except (KeyError, AttributeError):
            pass
        return failure.Failure(defer.TimeoutError(query))

    def queryTCP(self, queries, timeout = 10):
        """
        Make a number of DNS queries via TCP.

        If no TCP connection exists yet, one is started.  The query is
        queued in C{self.pending} and sent from C{connectionMade}.
        NOTE(review): nothing in this module fires queued Deferreds if the
        connection attempt fails (DNSClientFactory does not override
        clientConnectionFailed) -- confirm whether callers can hang.

        @type queries: Any non-zero number of C{dns.Query} instances
        @param queries: The queries to make.

        @type timeout: C{int}
        @param timeout: The number of seconds after which to fail.

        @rtype: C{Deferred}
        """
        if not len(self.connections):
            address = self.pickServer()
            if address is None:
                return defer.fail(IOError("No domain name servers available"))
            host, port = address
            from twisted.internet import reactor
            reactor.connectTCP(host, port, self.factory)
            self.pending.append((defer.Deferred(), queries, timeout))
            return self.pending[-1][0]
        else:
            return self.connections[0].query(queries, timeout)

    def filterAnswers(self, message):
        """
        Return C{(answers, authority, additional)} from C{message}.  If the
        message is truncated, repeat its queries over TCP instead and
        return a Deferred for the filtered TCP result.
        """
        if message.trunc:
            return self.queryTCP(message.queries).addCallback(self.filterAnswers)
        else:
            return (message.answers, message.authority, message.additional)

    def _lookup(self, name, cls, type, timeout):
        """Send a single C{(name, type, cls)} query over UDP and filter the answer."""
        return self.queryUDP(
            [dns.Query(name, type, cls)], timeout
        ).addCallback(self.filterAnswers)

    # This one doesn't ever belong on UDP
    def lookupZone(self, name, timeout = 10):
        """
        Perform an AXFR request. This is quite different from usual
        DNS requests. See http://cr.yp.to/djbdns/axfr-notes.html for
        more information.

        The result is C{(records, [], [])}.  The returned Deferred is timed
        out after C{timeout} seconds, or 10 if C{timeout} is false.
        """
        address = self.pickServer()
        if address is None:
            return defer.fail(IOError('No domain name servers available'))
        host,port = address

        from twisted.internet import reactor

        d = defer.Deferred()
        d.setTimeout(timeout or 10)
        controller = AXFRController(name, d)
        factory = DNSClientFactory(controller, timeout)
        factory.noisy = False # Be quiet, like the main resolver factory

        reactor.connectTCP(host, port, factory)
        return d.addCallback(lambda x: (x, [], []))
class AXFRController:
    """
    Drive one zone transfer (AXFR) over a single TCP connection.

    Answer records from every received message are collected into
    C{self.records}.  C{deferred} fires with that list once it holds more
    than one record and its last record is an SOA.
    """
    def __init__(self, name, deferred):
        self.name = name
        self.deferred = deferred
        self.soa = None
        self.records = []

    def connectionMade(self, protocol):
        """Send the AXFR query for C{self.name} on the new connection."""
        # dig sets recursion-desired to 0, so we do too
        message = dns.Message(protocol.pickID(), recDes=0)
        message.queries = [dns.Query(self.name, dns.AXFR, dns.IN)]
        protocol.writeMessage(message)

    def _disconnect(self, protocol):
        """Drop the TCP connection, if the protocol has a transport."""
        transport = getattr(protocol, 'transport', None)
        if transport is not None:
            transport.loseConnection()

    def messageReceived(self, message, protocol):
        """
        Accumulate the answers of C{message}, and fire C{self.deferred} when
        the closing SOA record has been seen.
        """
        if self.deferred.called:
            # The result was already delivered, or the Deferred timed out.
            # Firing it again would raise AlreadyCalledError, so just drop
            # the now-useless connection.
            self._disconnect(protocol)
            return

        # Caveat: We have to handle two cases: All records are in 1
        # message, or all records are in N messages.

        # According to http://cr.yp.to/djbdns/axfr-notes.html,
        # 'authority' and 'additional' are always empty, and only
        # 'answers' is present.
        self.records.extend(message.answers)
        if not self.records:
            return
        if self.soa is None and self.records[0].type == dns.SOA:
            # First SOA: start of the zone.
            self.soa = self.records[0]
        if len(self.records) > 1 and self.records[-1].type == dns.SOA:
            # Second SOA: the transfer is complete.
            self.deferred.callback(self.records)
            self._disconnect(protocol)
class ThreadedResolver:
    """
    Minimal resolver that runs the blocking C{socket.gethostbyname} through
    C{threads.deferToThread}.
    """
    __implements__ = (interfaces.IResolverSimple,)

    def __init__(self):
        # Not read by any method of this class.
        self.cache = {}

    def getHostByName(self, name, timeout = 10):
        """
        Look up C{name} with C{socket.gethostbyname} in a worker thread.

        @param timeout: seconds handed to the result's C{setTimeout}.
        NOTE(review): presumably only the Deferred is timed out and the
        worker thread keeps running -- confirm.
        @rtype: C{Deferred}
        """
        # XXX - Make this respect timeout
        pendingResult = threads.deferToThread(socket.gethostbyname, name)
        pendingResult.setTimeout(timeout)
        return pendingResult
class DNSClientFactory(protocol.ClientFactory):
    """
    Client factory that builds a C{dns.DNSProtocol} bound to C{controller}
    for each connection it makes.
    """
    def __init__(self, controller, timeout = 10):
        self.controller, self.timeout = controller, timeout

    def clientConnectionLost(self, connector, reason):
        # Lost connections are deliberately ignored.
        pass

    def buildProtocol(self, addr):
        """Create a protocol for C{self.controller} that points back at this factory."""
        proto = dns.DNSProtocol(self.controller)
        proto.factory = self
        return proto
def createResolver(servers = None, resolvconf = None, hosts = None):
    """
    Build the default resolver chain: hosts file, then cache, then network.

    @param servers: optional list of C{(host, port)} nameserver addresses.
    @param resolvconf: optional path to a resolver(5) configuration file.
        Defaults to C{/etc/resolv.conf} on POSIX.
    @param hosts: optional path to a hosts file.  Defaults to C{/etc/hosts}
        on POSIX and C{c:\\windows\\hosts} elsewhere.

    On non-POSIX platforms, C{servers} and C{resolvconf} are honoured when
    given.  Only when neither is supplied does this fall back to
    C{root.bootstrap} over a thread-based resolver.

    @raise ValueError: From C{Resolver} if it is given no usable
        nameserver source.
    """
    from twisted.names import resolve, cache, root, hosts as hostsModule
    if platform.getType() == 'posix':
        if resolvconf is None:
            resolvconf = '/etc/resolv.conf'
        if hosts is None:
            hosts = '/etc/hosts'
        theResolver = Resolver(resolvconf, servers)
        hostResolver = hostsModule.Resolver(hosts)
    else:
        if hosts is None:
            hosts = r'c:\windows\hosts'
        if servers or resolvconf:
            # Explicit nameserver configuration must not be silently ignored.
            theResolver = Resolver(resolvconf, servers)
            hostResolver = hostsModule.Resolver(hosts)
        else:
            bootstrap = ThreadedResolver()
            hostResolver = hostsModule.Resolver(hosts)
            theResolver = root.bootstrap(bootstrap)

    L = [hostResolver, cache.CacheResolver(), theResolver]
    return resolve.ResolverChain(L)
theResolver = None
def _makeLookup(method):
def lookup(*a, **kw):
global theResolver
if theResolver is None:
try:
theResolver = createResolver()
except ValueError:
theResolver = createResolver(servers=[('127.0.0.1', 53)])
return getattr(theResolver, method)(*a, **kw)
return lookup
def _installLookups(names):
    """
    Bind one module-level forwarding function (see _makeLookup) per method
    name.  The loop runs inside a function, so no loop variable leaks into
    the module namespace and no C{del} is needed afterwards.  A trailing
    C{del method} would raise NameError if C{names} were empty.
    """
    namespace = globals()
    for name in names:
        namespace[name] = _makeLookup(name)

_installLookups(common.typeToMethod.values())
del _installLookups

getHostByName = _makeLookup('getHostByName')
# syntax highlighted by Code2HTML, v. 0.9.1