# Twisted, the Framework of Your Internet
# Copyright (C) 2001-2004 Matthew W. Lefkowitz
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of version 2.1 of the GNU Lesser General Public
# License as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
#
from twisted.conch.error import ConchError
from twisted.conch.ssh import channel, connection
from twisted.internet import defer, protocol, reactor
from twisted.python import log
from twisted.spread import banana
import os, stat, pickle
import types # this is for evil
class SSHUnixClientFactory(protocol.ClientFactory):
    """
    Client factory used to attach to a cached connection through a socket
    file on disk.

    Attributes:
        d: Deferred that is errbacked when the connection attempt fails; it
            is reset to None once it has been fired.
        options: the options object given to the constructor (stored only).
        userAuthObject: object whose ``instance`` attribute is converted
            into the protocol by buildProtocol().
    """
    noisy = 1

    def __init__(self, d, options, userAuthObject):
        self.d = d
        self.options = options
        self.userAuthObject = userAuthObject

#    def clientConnectionLost(self, connector, reason):
#        stopConnection()

    def clientConnectionFailed(self, connector, reason):
        """
        Remove the socket file and report the failure through self.d.

        The Deferred is fired at most once: it is cleared before errback()
        is called, and later calls return early.
        """
        try:
            os.unlink(connector.transport.addr)
        except Exception:
            # Best-effort cleanup: the file (or the transport) may already
            # be gone, and that must not prevent the errback below.
            pass
        if not self.d:
            return
        d = self.d
        self.d = None
        d.errback(reason)
        #reactor.connectTCP(options['host'], options['port'], SSHClientFactory())

    def startedConnecting(self, connector):
        """
        Validate the socket file before the connection proceeds.

        The attempt is stopped unless the file at connector.transport.addr
        can be stat()ed, has mode 0600, and is owned by the current uid and
        gid.
        """
        fd = connector.transport.fileno()
        # Only needed by the disabled creation-time comparison below.
        stats = os.fstat(fd)
        try:
            filestats = os.stat(connector.transport.addr)
        except Exception:
            # Fail closed: if the socket file cannot be inspected, do not
            # connect to it.
            connector.stopConnecting()
            return
        mode = stat.S_IMODE(filestats[stat.ST_MODE])
        owner = filestats[stat.ST_UID]
        group = filestats[stat.ST_GID]
        # Report the values that were actually tested (those of the socket
        # file), not the fstat() of our own connecting descriptor.
        if mode != (stat.S_IRUSR | stat.S_IWUSR):
            log.msg("socket mode is not 0600: %s" % oct(mode))
        elif owner != os.getuid():
            log.msg("socket not owned by us: %s" % owner)
        elif group != os.getgid():
            log.msg("socket not owned by our group: %s" % group)
        # XXX reenable this when i can fix it for cygwin
        #elif filestats[-3:] != stats[-3:]:
        #    log.msg("socket doesn't have same create times")
        else:
            log.msg('conecting OK')
            return
        connector.stopConnecting()

    def buildProtocol(self, addr):
        """
        Turn self.userAuthObject.instance into the protocol and return it.

        A new class is built with the same name and attributes as the
        instance's class, but with connection.SSHConnection replaced by
        SSHUnixClientProtocol among its direct bases; the instance is then
        switched to that class and initialised as an SSHUnixClientProtocol.
        """
        # here comes the EVIL
        obj = self.userAuthObject.instance
        bases = []
        for base in obj.__class__.__bases__:
            if base == connection.SSHConnection:
                bases.append(SSHUnixClientProtocol)
            else:
                bases.append(base)
        newClass = types.ClassType(obj.__class__.__name__, tuple(bases),
                                   obj.__class__.__dict__)
        obj.__class__ = newClass
        SSHUnixClientProtocol.__init__(obj)
        log.msg('returning %s' % obj)
        return obj
class SSHUnixServerFactory(protocol.Factory):
    """
    Factory that hands out SSHUnixServerProtocol objects sharing one
    connection object.
    """

    def __init__(self, conn):
        # Connection object passed to every protocol this factory builds.
        self.conn = conn

    def buildProtocol(self, addr):
        """Create a server protocol around self.conn; addr is not used."""
        serverProtocol = SSHUnixServerProtocol(self.conn)
        return serverProtocol
class SSHUnixProtocol(banana.Banana):
    """
    Banana-based message protocol shared by the client and server sides.

    Every received expression is a list whose first item names the message;
    it is dispatched to the method ``msg_<name>`` with the remaining items.
    The class also relays Deferred results between the two ends: the side
    that owns the real Deferred announces it with returnDeferredWire(), and
    the other side pairs the announcement with a placeholder created by
    returnDeferredLocal().
    """
    knownDialects = ['none']

    def __init__(self):
        banana.Banana.__init__(self)
        # Counter used to label Deferreds announced by this side.
        self.deferredID = 0
        # Placeholders created locally, not yet matched to a wire id.
        self.deferredQueue = []
        # Wire id -> placeholder Deferred awaiting its result.
        self.deferreds = {}

    def connectionMade(self):
        log.msg('connection made %s' % self)
        banana.Banana.connectionMade(self)

    def expressionReceived(self, lst):
        """Dispatch a received expression to its msg_* handler."""
        handler = getattr(self, "msg_%s" % lst[0])
        handler(lst[1:])

    def sendMessage(self, vocabName, *tup):
        """Send vocabName followed by the given arguments as one list."""
        message = [vocabName]
        message.extend(tup)
        self.sendEncoded(message)

    def returnDeferredLocal(self):
        """Queue and return a placeholder Deferred for a result from the peer."""
        placeholder = defer.Deferred()
        self.deferredQueue.append(placeholder)
        return placeholder

    def returnDeferredWire(self, d):
        """
        Announce d to the peer under a fresh id and arrange for its result
        or failure to be sent over once d fires.
        """
        wireID = self.deferredID
        self.deferredID = wireID + 1
        self.sendMessage('returnDeferred', wireID)
        d.addCallback(self._cbDeferred, wireID)
        d.addErrback(self._ebDeferred, wireID)

    def _cbDeferred(self, result, di):
        # Results travel as pickles.
        payload = pickle.dumps(result)
        self.sendMessage('callbackDeferred', di, payload)

    def _ebDeferred(self, reason, di):
        # Failures travel as pickles too.
        payload = pickle.dumps(reason)
        self.sendMessage('errbackDeferred', di, payload)

    def _claimDeferred(self, deferredID):
        # Look up the placeholder for deferredID and forget it.
        waiting = self.deferreds[deferredID]
        del self.deferreds[deferredID]
        return waiting

    def msg_returnDeferred(self, lst):
        """Bind the oldest queued placeholder to the id chosen by the peer."""
        wireID = lst[0]
        self.deferreds[wireID] = self.deferredQueue.pop(0)

    def msg_callbackDeferred(self, lst):
        """Fire the matching placeholder with the unpickled result."""
        deferredID, result = lst
        waiting = self._claimDeferred(deferredID)
        # NOTE(security): pickle.loads runs on data read from the socket;
        # this is only acceptable if the peer is trusted.
        waiting.callback(pickle.loads(result))

    def msg_errbackDeferred(self, lst):
        """Fail the matching placeholder with the unpickled reason."""
        deferredID, result = lst
        waiting = self._claimDeferred(deferredID)
        # NOTE(security): pickle.loads runs on data read from the socket;
        # this is only acceptable if the peer is trusted.
        waiting.errback(pickle.loads(result))
class SSHUnixClientProtocol(SSHUnixProtocol):
    """
    Client side of the protocol: connection-style calls (openChannel,
    sendData, ...) are turned into messages for the peer, and channel
    events received from the peer are delivered to the local channel
    objects.

    Attributes:
        channelQueue: channels passed to openChannel() that have not yet
            been assigned an id by the peer.
        channels: mapping of channel id to channel object.
    """

    def __init__(self):
        SSHUnixProtocol.__init__(self)
        self.isClient = 1
        self.channelQueue = []
        self.channels = {}

    def connectionReady(self):
        log.msg('connection ready')
        self.serviceStarted()

    def connectionLost(self, reason):
        self.serviceStopped()

    def requestRemoteForwarding(self, remotePort, hostport):
        self.sendMessage('requestRemoteForwarding', remotePort, hostport)

    def cancelRemoteForwarding(self, remotePort):
        self.sendMessage('cancelRemoteForwarding', remotePort)

    def sendGlobalRequest(self, request, data, wantReply = 0):
        """
        Forward a global request to the peer.

        Returns a Deferred for the reply when wantReply is true, otherwise
        None.
        """
        self.sendMessage('sendGlobalRequest', request, data, wantReply)
        if wantReply:
            return self.returnDeferredLocal()

    def openChannel(self, channel, extra = ''):
        """
        Ask the peer to open channel; the channel is queued until the peer
        answers with a 'channelID' message.
        """
        self.channelQueue.append(channel)
        channel.conn = self
        self.sendMessage('openChannel', channel.name,
                         channel.localWindowSize,
                         channel.localMaxPacket, extra)

    def sendRequest(self, channel, requestType, data, wantReply = 0):
        """
        Forward a channel request to the peer.

        Returns a Deferred for the reply when wantReply is true, otherwise
        None.
        """
        self.sendMessage('sendRequest', channel.id, requestType, data, wantReply)
        if wantReply:
            # The Deferred must be handed back, otherwise the caller can
            # never observe the reply.
            return self.returnDeferredLocal()

    def adjustWindow(self, channel, bytesToAdd):
        self.sendMessage('adjustWindow', channel.id, bytesToAdd)

    def sendData(self, channel, data):
        self.sendMessage('sendData', channel.id, data)

    def sendExtendedData(self, channel, dataType, data):
        # The server-side handler is msg_sendExtended and expects
        # (channel id, dataType, data), so use that message name and
        # include dataType.
        self.sendMessage('sendExtended', channel.id, dataType, data)

    def sendEOF(self, channel):
        self.sendMessage('sendEOF', channel.id)

    def sendClose(self, channel):
        self.sendMessage('sendClose', channel.id)

    def msg_channelID(self, lst):
        """Assign the id chosen by the peer to the oldest queued channel."""
        channelID = lst[0]
        self.channels[channelID] = self.channelQueue.pop(0)
        self.channels[channelID].id = channelID

    def msg_channelOpen(self, lst):
        """Record the remote window settings and notify the channel."""
        channelID, remoteWindow, remoteMax, specificData = lst
        channel = self.channels[channelID]
        channel.remoteWindowLeft = remoteWindow
        channel.remoteMaxPacket = remoteMax
        channel.channelOpen(specificData)

    def msg_openFailed(self, lst):
        """Report the failure to the channel and forget the channel."""
        channelID, reason = lst
        # NOTE(security): pickle.loads runs on data read from the socket;
        # this is only acceptable if the peer is trusted.
        self.channels[channelID].openFailed(pickle.loads(reason))
        del self.channels[channelID]

    def msg_addWindowBytes(self, lst):
        channelID, byteCount = lst
        self.channels[channelID].addWindowBytes(byteCount)

    def msg_requestReceived(self, lst):
        """Run the channel's request handler and relay its outcome."""
        channelID, requestType, data = lst
        d = defer.maybeDeferred(self.channels[channelID].requestReceived,
                                requestType, data)
        self.returnDeferredWire(d)

    def msg_dataReceived(self, lst):
        channelID, data = lst
        self.channels[channelID].dataReceived(data)

    def msg_extReceived(self, lst):
        channelID, dataType, data = lst
        self.channels[channelID].extReceived(dataType, data)

    def msg_eofReceived(self, lst):
        channelID = lst[0]
        self.channels[channelID].eofReceived()

    def msg_closed(self, lst):
        """Notify the channel that it is closed and forget it."""
        channelID = lst[0]
        self.channels[channelID].closed()
        del self.channels[channelID]
class SSHUnixServerProtocol(SSHUnixProtocol):
    """
    Server side of the protocol: messages received from the peer are
    translated into calls on the connection object given to the
    constructor.

    Attributes:
        conn: the connection object the calls are made on.
    """

    def __init__(self, conn):
        SSHUnixProtocol.__init__(self)
        self.isClient = 0
        self.conn = conn

    def haveChannel(self, channelID):
        """Return whether self.conn currently has a channel with this id."""
        return channelID in self.conn.channels

    def getChannel(self, channelID):
        """
        Return the channel with the given id.

        Raises ConchError if that channel is not an SSHUnixChannel, so a
        peer can only operate on channels opened through this protocol
        class.
        """
        channel = self.conn.channels[channelID]
        if not isinstance(channel, SSHUnixChannel):
            raise ConchError('nice try bub')
        return channel

    def msg_requestRemoteForwarding(self, lst):
        remotePort, hostport = lst
        hostport = tuple(hostport)
        self.conn.requestRemoteForwarding(remotePort, hostport)

    def msg_cancelRemoteForwarding(self, lst):
        [remotePort] = lst
        self.conn.cancelRemoteForwarding(remotePort)

    def msg_sendGlobalRequest(self, lst):
        """Issue a global request and, if wanted, relay its reply."""
        requestName, data, wantReply = lst
        d = self.conn.sendGlobalRequest(requestName, data, wantReply)
        if wantReply:
            # returnDeferredWire is the relay method defined on the base
            # class; no method named returnDeferred exists.
            self.returnDeferredWire(d)

    def msg_openChannel(self, lst):
        """Open a proxy channel on self.conn and tell the peer its id."""
        name, windowSize, maxPacket, extra = lst
        channel = SSHUnixChannel(self, name, windowSize, maxPacket)
        self.conn.openChannel(channel, extra)
        self.sendMessage('channelID', channel.id)

    def msg_sendRequest(self, lst):
        """
        Issue a channel request and, if wanted, relay its reply.

        For an unknown channel the request is dropped; when a reply was
        wanted, a ConchError failure is relayed instead.
        """
        cn, requestType, data, wantReply = lst
        if not self.haveChannel(cn):
            if wantReply:
                self.returnDeferredWire(defer.fail(ConchError("no channel")))
            # Do not fall through to getChannel(), which would raise
            # KeyError for the missing channel.
            return
        channel = self.getChannel(cn)
        d = self.conn.sendRequest(channel, requestType, data, wantReply)
        if wantReply:
            self.returnDeferredWire(d)

    def msg_adjustWindow(self, lst):
        cn, bytesToAdd = lst
        if not self.haveChannel(cn):
            return
        channel = self.getChannel(cn)
        self.conn.adjustWindow(channel, bytesToAdd)

    def msg_sendData(self, lst):
        cn, data = lst
        if not self.haveChannel(cn):
            return
        channel = self.getChannel(cn)
        self.conn.sendData(channel, data)

    def msg_sendExtended(self, lst):
        cn, dataType, data = lst
        if not self.haveChannel(cn):
            return
        channel = self.getChannel(cn)
        self.conn.sendExtendedData(channel, dataType, data)

    def msg_sendEOF(self, lst):
        (cn, ) = lst
        if not self.haveChannel(cn):
            return
        channel = self.getChannel(cn)
        self.conn.sendEOF(channel)

    def msg_sendClose(self, lst):
        (cn, ) = lst
        if not self.haveChannel(cn):
            return
        channel = self.getChannel(cn)
        self.conn.sendClose(channel)
class SSHUnixChannel(channel.SSHChannel):
    """
    Channel whose events are not handled locally: each one is reported to
    the peer as a message sent through the ``unix`` protocol object.
    """

    def __init__(self, unix, name, windowSize, maxPacket):
        channel.SSHChannel.__init__(self, windowSize, maxPacket,
                                    conn = unix.conn)
        self.name = name
        self.unix = unix

    def _relay(self, event, *details):
        # Send an event message tagged with this channel's id.
        self.unix.sendMessage(event, self.id, *details)

    def channelOpen(self, specificData):
        self._relay('channelOpen', self.remoteWindowLeft,
                    self.remoteMaxPacket, specificData)

    def openFailed(self, reason):
        # The failure reason is sent as a pickle.
        self._relay('openFailed', pickle.dumps(reason))

    def addWindowBytes(self, bytes):
        self._relay('addWindowBytes', bytes)

    def dataReceived(self, data):
        self._relay('dataReceived', data)

    def requestReceived(self, reqType, data):
        """
        Pass the request on to the peer and return a Deferred that stands
        for the peer's answer.
        """
        self._relay('requestReceived', reqType, data)
        return self.unix.returnDeferredLocal()

    def extReceived(self, dataType, data):
        self._relay('extReceived', dataType, data)

    def eofReceived(self):
        self._relay('eofReceived')

    def closed(self):
        self._relay('closed')
def connect(host, port, options, verifyHostKey, userAuthObject):
    """
    Try to attach to a cached connection through its socket file.

    The socket path is ~/.conch-<user>-<host>-<port>. Returns a Deferred
    that is handed to the client factory; when options['nocache'] is set,
    an already-failed Deferred carrying a ConchError is returned instead
    and no connection is attempted. verifyHostKey is not used.
    """
    if options['nocache']:
        return defer.fail(ConchError('not using connection caching'))
    socketName = "~/.conch-%s-%s-%i" % (userAuthObject.user, host, port)
    filename = os.path.expanduser(socketName)
    d = defer.Deferred()
    reactor.connectUNIX(filename,
                        SSHUnixClientFactory(d, options, userAuthObject),
                        timeout=2, checkPID=1)
    return d
# syntax highlighted by Code2HTML, v. 0.9.1