/*
 * socks.cxx
 *
 * SOCKS protocol
 *
 * Portable Windows Library
 *
 * Copyright (c) 1993-2002 Equivalence Pty. Ltd.
 *
 * The contents of this file are subject to the Mozilla Public License
 * Version 1.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 * http://www.mozilla.org/MPL/
 *
 * Software distributed under the License is distributed on an "AS IS"
 * basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
 * the License for the specific language governing rights and limitations
 * under the License.
 *
 * The Original Code is Portable Windows Library.
 *
 * The Initial Developer of the Original Code is Equivalence Pty. Ltd.
 *
 * $Log: socks.cxx,v $
 * Revision 1.9  2004/04/03 08:22:21  csoutheren
 * Remove pseudo-RTTI and replaced with real RTTI
 *
 * Revision 1.8  2003/09/08 01:42:48  dereksmithies
 * Add patch from Diego Tartara <dtartara@mens2.hq.novamens.com>. Many Thanks!
 *
 * Revision 1.7  2002/11/06 22:47:25  robertj
 * Fixed header comment (copyright etc)
 *
 * Revision 1.6  2002/08/05 05:40:45  robertj
 * Fixed missing pragma interface/implementation
 *
 * Revision 1.5  2001/09/10 02:51:23  robertj
 * Major change to fix problem with error codes being corrupted in a
 *   PChannel when have simultaneous reads and writes in threads.
 *
 * Revision 1.4  1999/11/23 08:45:10  robertj
 * Fixed bug in user/pass authentication version, thanks Dmitry <dipa@linkline.com>
 *
 * Revision 1.3  1999/02/16 08:08:06  robertj
 * MSVC 6.0 compatibility changes.
 *
 * Revision 1.2  1998/12/23 00:35:28  robertj
 * UDP support.
 *
 * Revision 1.1  1998/12/22 10:30:24  robertj
 * Initial revision
 *
 */

#include <ptlib.h>

#ifdef __GNUC__
#pragma implementation "socks.h"
#endif

#include <ptclib/socks.h>

#define new PNEW

#define SOCKS_VERSION_4 ((BYTE)4)
#define SOCKS_VERSION_5 ((BYTE)5)

#define SOCKS_CMD_CONNECT       ((BYTE)1)
#define SOCKS_CMD_BIND          ((BYTE)2)
#define SOCKS_CMD_UDP_ASSOCIATE ((BYTE)3)

#define SOCKS_AUTH_NONE      ((BYTE)0)
#define SOCKS_AUTH_USER_PASS ((BYTE)2)
#define SOCKS_AUTH_FAILED    ((BYTE)0xff)

#define SOCKS_ADDR_IPV4       ((BYTE)1)
#define SOCKS_ADDR_DOMAINNAME ((BYTE)3)
#define SOCKS_ADDR_IPV6       ((BYTE)4)


///////////////////////////////////////////////////////////////////////////////

PSocksProtocol::PSocksProtocol(WORD port)
  : serverHost("proxy")
{
  serverPort = DefaultServerPort;
  remotePort = port;
  localPort = 0;

  // get proxy information
  PConfig config(PConfig::System, "HKEY_CURRENT_USER\\Software\\Microsoft\\Windows\\CurrentVersion\\");

  // get the proxy configuration string
  PString str = config.GetString("Internet Settings", "ProxyServer", "");
  if (str.Find('=') == P_MAX_INDEX)
    SetServer("socks");
  else {
    PStringArray tokens = str.Tokenise(";");
    PINDEX i;
    for (i = 0; i < tokens.GetSize(); i++) {
      str = tokens[i];
      PINDEX equalPos = str.Find('=');
      if (equalPos != P_MAX_INDEX && (str.Left(equalPos) *= "socks")) {
        SetServer(str.Mid(equalPos+1));
        break;
      }
    }
  }
}


BOOL PSocksProtocol::SetServer(const PString & hostname, const char * service)
{
  return SetServer(hostname, PIPSocket::GetPortByService("tcp", service));
}


BOOL PSocksProtocol::SetServer(const PString & hostname, WORD port)
{
  PINDEX colon = hostname.Find(':');
  if (colon == P_MAX_INDEX)
    serverHost = hostname;
  else {
    unsigned portnum = hostname.Mid(colon+1).AsUnsigned();
    if (portnum == 0)
      serverHost = hostname;
    else {
      serverHost = hostname.Left(colon);
      port = (WORD)portnum;
    }
  }

  if (port == 0)
    port = DefaultServerPort;

  serverPort = port;

  return TRUE;
}


void PSocksProtocol::SetAuthentication(const PString & username, const PString & password)
{
  PAssert(authenticationUsername.GetLength() < 256, PInvalidParameter);
  authenticationUsername = username;
  PAssert(authenticationPassword.GetLength() < 256, PInvalidParameter);
  authenticationPassword = password;
}


BOOL PSocksProtocol::ConnectSocksServer(PTCPSocket & socket)
{
  PIPSocket::Address ipnum;
  if (!PIPSocket::GetHostAddress(serverHost, ipnum))
    return FALSE;

  remotePort = socket.GetPort();
  socket.SetPort(serverPort);
  return socket.PTCPSocket::Connect(0, ipnum);
}


BOOL PSocksProtocol::SendSocksCommand(PTCPSocket & socket,
                                      BYTE command,
                                      const char * hostname,
                                      PIPSocket::Address addr)
{
  if (!socket.IsOpen()) {
    if (!ConnectSocksServer(socket))
      return FALSE;

    socket << SOCKS_VERSION_5
           << (authenticationUsername.IsEmpty() ? '\001' : '\002') // length
           << SOCKS_AUTH_NONE;
    if (!authenticationUsername)
      socket << SOCKS_AUTH_USER_PASS;  // Simple cleartext username/password
    socket.flush();

    BYTE auth_pdu[2];
    if (!socket.ReadBlock(auth_pdu, sizeof(auth_pdu)))  // Should get 2 byte reply
      return FALSE;

    if (auth_pdu[0] != SOCKS_VERSION_5 || auth_pdu[1] == SOCKS_AUTH_FAILED) {
      socket.Close();
      SetErrorCodes(PChannel::AccessDenied, EACCES);
      return FALSE;
    }

    if (auth_pdu[1] == SOCKS_AUTH_USER_PASS) {
      // Send username and pasword
      socket << '\x01'
             << (BYTE)authenticationUsername.GetLength()  // Username length as single byte
             << authenticationUsername
             << (BYTE)authenticationPassword.GetLength()  // Password length as single byte
             << authenticationPassword
             << ::flush;

      if (!socket.ReadBlock(auth_pdu, sizeof(auth_pdu)))  // Should get 2 byte reply
        return FALSE;

      if (/*auth_pdu[0] != SOCKS_VERSION_5 ||*/ auth_pdu[1] != 0) {
        socket.Close();
        SetErrorCodes(PChannel::AccessDenied, EACCES);
        return FALSE;
      }
    }
  }

  socket << SOCKS_VERSION_5
         << command
         << '\000'; // Reserved
  if (hostname != NULL)
    socket << SOCKS_ADDR_DOMAINNAME << (BYTE)strlen(hostname) << hostname;
#if P_HAS_IPV6
  else if ( addr.GetVersion() == 6 )
  {
    socket << SOCKS_ADDR_IPV6;
    /* Should be 16 bytes */
    for ( PINDEX i = 0; i < addr.GetSize(); i++ )    
    {
        socket << addr[i];
    }
  }
#endif 
  else
    socket << SOCKS_ADDR_IPV4
           << addr.Byte1() << addr.Byte2() << addr.Byte3() << addr.Byte4();

  socket << (BYTE)(remotePort >> 8) << (BYTE)remotePort
         << ::flush;

  return ReceiveSocksResponse(socket, localAddress, localPort);
}


BOOL PSocksProtocol::ReceiveSocksResponse(PTCPSocket & socket,
                                          PIPSocket::Address & addr,
                                          WORD & port)
{
  int reply;
  if ((reply = socket.ReadChar()) < 0)
    return FALSE;

  if (reply != SOCKS_VERSION_5) {
    SetErrorCodes(PChannel::Miscellaneous, EINVAL);
    return FALSE;
  }

  if ((reply = socket.ReadChar()) < 0)
    return FALSE;

  switch (reply) {
    case 0 :  // No error
      break;

    case 2 :  // Refused permission
      SetErrorCodes(PChannel::AccessDenied, EACCES);
      return FALSE;

    case 3 : // Network unreachable
      SetErrorCodes(PChannel::NotFound, ENETUNREACH);
      return FALSE;

    case 4 : // Host unreachable
      SetErrorCodes(PChannel::NotFound, EHOSTUNREACH);
      return FALSE;

    case 5 : // Connection refused
      SetErrorCodes(PChannel::NotFound, EHOSTUNREACH);
      return FALSE;

    default :
      SetErrorCodes(PChannel::Miscellaneous, EINVAL);
      return FALSE;
  }

  // Ignore next byte (reserved)
  if ((reply = socket.ReadChar()) < 0)
    return FALSE;

  // Get type byte for bound address
  if ((reply = socket.ReadChar()) < 0)
    return FALSE;

  switch (reply) {
    case SOCKS_ADDR_DOMAINNAME :
      // Get length
      if ((reply = socket.ReadChar()) < 0)
        return FALSE;

      if (!PIPSocket::GetHostAddress(socket.ReadString(reply), addr))
        return FALSE;
      break;

    case SOCKS_ADDR_IPV4 :
      {
          in_addr add;
          if (!socket.ReadBlock(&add, sizeof(add)))
              return FALSE;
          addr = add;
      }
      break;

#if P_HAS_IPV6    
    case SOCKS_ADDR_IPV6 :
      {
        in6_addr add;
        if (!socket.ReadBlock(&add, sizeof(add)))
            return FALSE;
        addr = add;
      }
      break;
#endif

    default :
      SetErrorCodes(PChannel::Miscellaneous, EINVAL);
      return FALSE;
  }

  WORD rxPort;
  if (!socket.ReadBlock(&rxPort, sizeof(rxPort)))
    return FALSE;

  port = PSocket::Net2Host(rxPort);
  return TRUE;
}


///////////////////////////////////////////////////////////////////////////////

PSocksSocket::PSocksSocket(WORD port)
  : PSocksProtocol(port)
{
}


BOOL PSocksSocket::Connect(const PString & address)
{
  if (!SendSocksCommand(*this, SOCKS_CMD_CONNECT, address, 0))
    return FALSE;

  port = remotePort;
  return TRUE;
}


BOOL PSocksSocket::Connect(const Address & addr)
{
  if (!SendSocksCommand(*this, SOCKS_CMD_CONNECT, NULL, addr))
    return FALSE;

  port = remotePort;
  return TRUE;
}


BOOL PSocksSocket::Connect(WORD, const Address &)
{
  PAssertAlways(PUnsupportedFeature);
  return FALSE;
}


BOOL PSocksSocket::Listen(unsigned, WORD newPort, Reusability reuse)
{
  PAssert(newPort == 0 && port == 0, PUnsupportedFeature);
  PAssert(reuse, PUnsupportedFeature);

  if (!SendSocksCommand(*this, SOCKS_CMD_BIND, NULL, 0))
    return FALSE;

  port = localPort;
  return TRUE;
}


BOOL PSocksSocket::Accept()
{
  if (!IsOpen())
    return FALSE;

  return ReceiveSocksResponse(*this, remoteAddress, remotePort);
}


BOOL PSocksSocket::Accept(PSocket & socket)
{
  // If is right class, transfer the SOCKS socket to class to receive the accept
  // The "listener" socket is implicitly closed as there is really only one
  // handle in a SOCKS BIND operation.
  PAssert(PIsDescendant(&socket, PSocksSocket), PUnsupportedFeature);
  os_handle = ((PSocksSocket &)socket).TransferHandle(*this);
  return Accept();
}


int PSocksSocket::TransferHandle(PSocksSocket & destination)
{
  // This "transfers" the socket from one onstance to another.

  int the_handle = os_handle;
  destination.SetReadTimeout(readTimeout);
  destination.SetWriteTimeout(writeTimeout);

  // Close the instance of the socket but don't actually close handle.
  os_handle = -1;

  return the_handle;
}


BOOL PSocksSocket::GetLocalAddress(Address & addr)
{
  if (!IsOpen())
    return FALSE;

  addr = localAddress;
  return TRUE;
}


BOOL PSocksSocket::GetLocalAddress(Address & addr, WORD & port)
{
  if (!IsOpen())
    return FALSE;

  addr = localAddress;
  port = localPort;
  return TRUE;
}


BOOL PSocksSocket::GetPeerAddress(Address & addr)
{
  if (!IsOpen())
    return FALSE;

  addr = remoteAddress;
  return TRUE;
}


BOOL PSocksSocket::GetPeerAddress(Address & addr, WORD & port)
{
  if (!IsOpen())
    return FALSE;

  addr = remoteAddress;
  port = remotePort;
  return TRUE;
}


void PSocksSocket::SetErrorCodes(PChannel::Errors errCode, int osErr)
{
  SetErrorValues(errCode, osErr);
}


///////////////////////////////////////////////////////////////////////////////

PSocks4Socket::PSocks4Socket(WORD port)
  : PSocksSocket(port)
{
}


PSocks4Socket::PSocks4Socket(const PString & host, WORD port)
  : PSocksSocket(port)
{
  Connect(host);
}


PObject * PSocks4Socket::Clone() const
{
  return new PSocks4Socket(remotePort);
}


BOOL PSocks4Socket::SendSocksCommand(PTCPSocket & socket,
                                     BYTE command,
                                     const char * hostname,
                                     Address addr)
{
  if (hostname != NULL) {
    if (!GetHostAddress(hostname, addr))
      return FALSE;
  }

  if (!IsOpen()) {
    if (!ConnectSocksServer(*this))
      return FALSE;
  }

  PString user = PProcess::Current().GetUserName();
  socket << SOCKS_VERSION_4
         << command
         << (BYTE)(remotePort >> 8) << (BYTE)remotePort
         << addr.Byte1() << addr.Byte2() << addr.Byte3() << addr.Byte4()
         << user << ((BYTE)0)
         << ::flush;

  return ReceiveSocksResponse(socket, localAddress, localPort);
}


BOOL PSocks4Socket::ReceiveSocksResponse(PTCPSocket & socket,
                                         Address & addr,
                                         WORD & port)
{
  int reply;
  if ((reply = socket.ReadChar()) < 0)
    return FALSE;

  if (reply != 0 /*!= SOCKS_VERSION_4*/) {
    SetErrorCodes(PChannel::Miscellaneous, EINVAL);
    return FALSE;
  }

  if ((reply = socket.ReadChar()) < 0)
    return FALSE;

  switch (reply) {
    case 90 :  // No error
      break;

    case 91 : // Connection refused
      SetErrorCodes(PChannel::NotFound, EHOSTUNREACH);
      return FALSE;

    case 92 :  // Refused permission
      SetErrorCodes(PChannel::AccessDenied, EACCES);
      return FALSE;

    default :
      SetErrorCodes(PChannel::Miscellaneous, EINVAL);
      return FALSE;
  }

  WORD rxPort;
  if (!socket.ReadBlock(&rxPort, sizeof(rxPort)))
    return FALSE;

  port = PSocket::Net2Host(rxPort);

  in_addr add;
  if ( socket.ReadBlock(&add, sizeof(add)) )
  {
    addr = add;
    return TRUE;
  }

  return FALSE;
}


///////////////////////////////////////////////////////////////////////////////

PSocks5Socket::PSocks5Socket(WORD port)
  : PSocksSocket(port)
{
}


PSocks5Socket::PSocks5Socket(const PString & host, WORD port)
  : PSocksSocket(port)
{
  Connect(host);
}


PObject * PSocks5Socket::Clone() const
{
  return new PSocks5Socket(remotePort);
}


///////////////////////////////////////////////////////////////////////////////

PSocksUDPSocket::PSocksUDPSocket(WORD port)
  : PSocksProtocol(port)
{
}


PSocksUDPSocket::PSocksUDPSocket(const PString & host, WORD port)
  : PSocksProtocol(port)
{
  Connect(host);
}


PObject * PSocksUDPSocket::Clone() const
{
  return new PSocksUDPSocket(port);
}


BOOL PSocksUDPSocket::Connect(const PString & address)
{
  if (!SendSocksCommand(socksControl, SOCKS_CMD_UDP_ASSOCIATE, address, 0))
    return FALSE;

  socksControl.GetPeerAddress(serverAddress);
  return TRUE;
}


BOOL PSocksUDPSocket::Connect(const Address & addr)
{
  if (!SendSocksCommand(socksControl, SOCKS_CMD_UDP_ASSOCIATE, NULL, addr))
    return FALSE;

  socksControl.GetPeerAddress(serverAddress);
  return TRUE;
}


BOOL PSocksUDPSocket::Connect(WORD, const Address &)
{
  PAssertAlways(PUnsupportedFeature);
  return FALSE;
}


BOOL PSocksUDPSocket::Listen(unsigned, WORD newPort, Reusability reuse)
{
  PAssert(newPort == 0 && port == 0, PUnsupportedFeature);
  PAssert(reuse, PUnsupportedFeature);

  if (!SendSocksCommand(socksControl, SOCKS_CMD_UDP_ASSOCIATE, NULL, 0))
    return FALSE;

  socksControl.GetPeerAddress(serverAddress);
  port = localPort;
  return TRUE;
}


BOOL PSocksUDPSocket::GetLocalAddress(Address & addr)
{
  if (!IsOpen())
    return FALSE;

  addr = localAddress;
  return TRUE;
}


BOOL PSocksUDPSocket::GetLocalAddress(Address & addr, WORD & port)
{
  if (!IsOpen())
    return FALSE;

  addr = localAddress;
  port = localPort;
  return TRUE;
}


BOOL PSocksUDPSocket::GetPeerAddress(Address & addr)
{
  if (!IsOpen())
    return FALSE;

  addr = remoteAddress;
  return TRUE;
}


BOOL PSocksUDPSocket::GetPeerAddress(Address & addr, WORD & port)
{
  if (!IsOpen())
    return FALSE;

  addr = remoteAddress;
  port = remotePort;
  return TRUE;
}


BOOL PSocksUDPSocket::ReadFrom(void * buf, PINDEX len, Address & addr, WORD & port)
{
  PBYTEArray newbuf(len+262);
  Address rx_addr;
  WORD rx_port;
  if (!PUDPSocket::ReadFrom(newbuf.GetPointer(), newbuf.GetSize(), rx_addr, rx_port))
    return FALSE;

  if (rx_addr != serverAddress || rx_port != serverPort)
    return FALSE;

  PINDEX port_pos;
  switch (newbuf[3]) {
    case SOCKS_ADDR_DOMAINNAME :
      if (!PIPSocket::GetHostAddress(PString((const char *)&newbuf[5], (PINDEX)newbuf[4]), addr))
        return FALSE;

      port_pos = newbuf[4]+5;
      break;

    case SOCKS_ADDR_IPV4 :
      memcpy(&addr, &newbuf[4], 4);
      port_pos = 4;
      break;

    default :
      SetErrorCodes(PChannel::Miscellaneous, EINVAL);
      return FALSE;
  }

  port = (WORD)((newbuf[port_pos] << 8)|newbuf[port_pos+1]);
  memcpy(buf, &newbuf[port_pos+2], len);

  return TRUE;
}


BOOL PSocksUDPSocket::WriteTo(const void * buf, PINDEX len, const Address & addr, WORD port)
{
  PBYTEArray newbuf(len+10);
  BYTE * bufptr = newbuf.GetPointer();

  // Build header, bytes 0, 1 & 2 are zero

  bufptr[3] = SOCKS_ADDR_IPV4;
  memcpy(bufptr+4, &addr, 4);
  bufptr[8] = (BYTE)(port >> 8);
  bufptr[9] = (BYTE)port;
  memcpy(bufptr+10, buf, len);

  return PUDPSocket::WriteTo(newbuf, newbuf.GetSize(), serverAddress, serverPort);
}


void PSocksUDPSocket::SetErrorCodes(PChannel::Errors errCode, int osErr)
{
  SetErrorValues(errCode, osErr);
}


// End of File ///////////////////////////////////////////////////////////////


syntax highlighted by Code2HTML, v. 0.9.1