/*
* psasl.cxx
*
* Simple Authentication and Security Layer (SASL) interface classes
*
* Portable Windows Library
*
* Copyright (c) 2004 Reitek S.p.A.
*
* The contents of this file are subject to the Mozilla Public License
* Version 1.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
* http://www.mozilla.org/MPL/
*
* Software distributed under the License is distributed on an "AS IS"
* basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
* the License for the specific language governing rights and limitations
* under the License.
*
* The Original Code is Portable Windows Library.
*
* The Initial Developer of the Original Code is Post Increment
*
* Contributor(s): ______________________________________.
*
* $Log: psasl.cxx,v $
* Revision 1.5 2004/05/09 07:23:50 rjongbloed
* More work on XMPP, thanks Federico Pinna and Reitek S.p.A.
*
* Revision 1.4 2004/04/28 11:26:43 csoutheren
* Hopefully fixed SASL and SASL2 problems
*
* Revision 1.3 2004/04/24 06:27:56 rjongbloed
* Fixed GCC 3.4.0 warnings about PAssertNULL and improved recoverability on
* NULL pointer usage in various bits of code.
*
* Revision 1.2 2004/04/18 12:34:22 csoutheren
* Modified to compile under Unix
*
* Revision 1.1 2004/04/18 12:02:31 csoutheren
* Added classes for SASL authentication
* Thanks to Federico Pinna and Reitek S.p.A.
*
*
*/
#ifdef __GNUC__
#pragma implementation "psasl.h"
#endif
#include <ptlib.h>
#include <ptclib/psasl.h>
#include <ptclib/cypher.h>
#if P_SASL2
extern "C" {
#if P_HAS_SASL_SASL_H
#include <sasl/sasl.h>
#else
#include <sasl.h>
#endif
};
#ifdef _MSC_VER
#pragma comment(lib, P_SASL_LIBRARY)
#endif
///////////////////////////////////////////////////////
/* SASL_CB_GETREALM callback, registered globally by psasl_Initialise().
 * Answers with the process-wide realm held by PSASLClient; the list of
 * realms offered by the server (third argument) is ignored.
 * Returns SASL_FAIL if invoked for any other callback id.
 */
static int PSASL_ClientRealm(void *, int id, const char **, const char **result)
{
  if (id == SASL_CB_GETREALM) {
    *result = (const char *)PSASLClient::GetRealm();
    return SASL_OK;
  }

  return SASL_FAIL;
}
static int PSASL_ClientAuthID(void *context, int id, const char **result, unsigned *len)
{
if (id != SASL_CB_AUTHNAME)
return SASL_FAIL;
if (PAssertNULL(context) == NULL)
return SASL_FAIL;
const PSASLClient * c = (const PSASLClient *)context;
*result = (const char *)c->GetAuthID();
if (len)
*len = *result ? strlen(*result) : 0;
return SASL_OK;
}
static int PSASL_ClientUserID(void *context, int id, const char **result, unsigned *len)
{
if (id != SASL_CB_USER)
return SASL_FAIL;
if (PAssertNULL(context) == NULL)
return SASL_FAIL;
const PSASLClient * c = (const PSASLClient *)context;
*result = (const char *)c->GetUserID();
if (len)
*len = *result ? strlen(*result) : 0;
return SASL_OK;
}
/* SASL_CB_PASS callback, registered per connection by PSASLClient::Init().
 * 'context' is the owning PSASLClient. Copies its password into a freshly
 * malloc'd sasl_secret_t: 'data' holds the NUL-terminated password and
 * 'len' excludes the terminator.
 *
 * Returns:
 *   SASL_OK       the secret was stored in *psecret
 *   SASL_FAIL     wrong callback id, NULL context, or no password
 *   SASL_BADPARAM psecret is NULL
 *   SASL_NOMEM    the allocation failed
 *
 * NOTE(review): nothing in this file frees the allocated secret. The Cyrus
 * SASL sasl_getsecret_t contract appears to leave ownership with the
 * application, so each call presumably leaks one secret. Fixing it needs
 * per-connection storage in PSASLClient (a psasl.h change). TODO confirm.
 */
static int PSASL_ClientPassword(sasl_conn_t *, void *context, int id, sasl_secret_t **psecret)
{
  if (id != SASL_CB_PASS)
    return SASL_FAIL;
  if (PAssertNULL(context) == NULL)
    return SASL_FAIL;
  if (psecret == NULL)
    return SASL_BADPARAM;

  const PSASLClient * c = (const PSASLClient *)context;
  const char * pwd = c->GetPassword();
  if (!pwd)
    return SASL_FAIL;

  size_t len = strlen(pwd);
  // Assumes sasl_secret_t ends in 'unsigned char data[1]' (as in Cyrus sasl.h),
  // so sizeof() already reserves the byte needed for the terminating NUL.
  sasl_secret_t * secret = (sasl_secret_t *)malloc(sizeof(sasl_secret_t) + len);
  if (secret == NULL)
    return SASL_NOMEM;

  secret->len = len;
  memcpy(secret->data, pwd, len + 1);   // copy including the terminating NUL
  *psecret = secret;
  return SASL_OK;
}
/* SASL_CB_GETPATH callback. psasl_Initialise() registers it only when
 * PSASLClient::GetPath() is non-empty. Reports that configured path and
 * always succeeds.
 */
static int PSASL_ClientGetPath(void *, const char ** path)
{
  const char * configured = (const char *)PSASLClient::GetPath();
  *path = configured;
  return SASL_OK;
}
static int PSASL_ClientLog(void *, int priority, const char *message)
{
#if PTRACING
static const char * labels[7] = { "Error", "Fail", "Warning", "Note", "Debug", "Trace", "Pass" };
#endif
if (!message || priority > SASL_LOG_PASS)
return SASL_BADPARAM;
if (priority < SASL_LOG_ERR)
return SASL_OK;
PTRACE(priority, "SASL\t" << labels[priority - 1] << ": " << message);
return SASL_OK;
}
static void psasl_Initialise()
{
PINDEX max = PSASLClient::GetPath().IsEmpty() ? 3 : 4;
sasl_callback_t * cbs = new sasl_callback_t[max];
cbs[0].id = SASL_CB_GETREALM;
cbs[0].proc = (int (*)())&PSASL_ClientRealm;
cbs[0].context = 0;
cbs[1].id = SASL_CB_LOG;
cbs[1].proc = (int (*)())&PSASL_ClientLog;
cbs[1].context = 0;
if (max == 4) {
cbs[2].id = SASL_CB_GETPATH;
cbs[2].proc = (int (*)())&PSASL_ClientGetPath;
cbs[2].context = 0;
}
cbs[max - 1].id = SASL_CB_LIST_END;
cbs[max - 1].proc = 0;
cbs[max - 1].context = 0;
sasl_client_init(cbs);
}
// Incremented by every PSASLClient constructor; psasl_Initialise() runs only
// when it first reaches 1. Nothing in this file decrements it, so the global
// SASL client initialisation is done at most once.
static PAtomicInteger psasl_UsageCount(0);
// Process-wide realm. Presumably exposed through GetRealm()/SetRealm() in
// psasl.h (confirm); PSASL_ClientRealm hands GetRealm() to the SASL library.
PString PSASLClient::s_Realm;
// Process-wide SASL plugin path. Presumably exposed through GetPath() (confirm);
// when GetPath() is non-empty at first use, psasl_Initialise() registers SASL_CB_GETPATH.
PString PSASLClient::s_Path;
/* Records the SASL service name and credentials for later negotiation.
 * An empty user id falls back to the authentication id, and vice versa.
 * No connection is created here; see Init(). The first PSASLClient
 * constructed also triggers the process-wide psasl_Initialise().
 */
PSASLClient::PSASLClient(const PString& service,
                         const PString& uid,
                         const PString& auth,
                         const PString& pwd)
  : m_CallBacks(NULL)
  , m_ConnState(NULL)
  , m_Service(service)
  , m_UserID(!uid.IsEmpty() ? uid : auth)
  , m_AuthID(!auth.IsEmpty() ? auth : uid)
  , m_Password(pwd)
{
  const BOOL firstInstance = (++psasl_UsageCount == 1);
  if (firstInstance)
    psasl_Initialise();
}
/* Ends any open SASL connection and releases the per-connection callback
 * table built by Init().
 */
PSASLClient::~PSASLClient()
{
  if (m_ConnState)
    End();
  // Init() allocates the callback table with new[], so it must be freed with
  // delete[]; scalar delete on an array is undefined behaviour.
  delete [] (sasl_callback_t *)m_CallBacks;
}
/* Opens a fresh SASL client connection to 'fqdn' for this object's service
 * and adds the mechanisms the library can offer to 'supportedMechanisms'.
 * Existing entries in the set are kept.
 *
 * On first use this builds the per-connection callback table (auth id,
 * user id and password, each with context = this) and reuses it afterwards.
 * Any connection left over from a previous Init() is ended first.
 *
 * Returns FALSE if sasl_client_new() or sasl_listmech() fails, TRUE otherwise.
 */
BOOL PSASLClient::Init(const PString& fqdn, PStringSet& supportedMechanisms)
{
  if (!m_CallBacks)
  {
    sasl_callback_t * cbs = new sasl_callback_t[4];
    cbs[0].id = SASL_CB_AUTHNAME;
    cbs[0].proc = (int (*)())&PSASL_ClientAuthID;
    cbs[0].context = this;
    cbs[1].id = SASL_CB_USER;
    cbs[1].proc = (int (*)())&PSASL_ClientUserID;
    cbs[1].context = this;
    cbs[2].id = SASL_CB_PASS;
    cbs[2].proc = (int (*)())&PSASL_ClientPassword;
    cbs[2].context = this;
    cbs[3].id = SASL_CB_LIST_END;
    cbs[3].proc = 0;
    cbs[3].context = 0;
    m_CallBacks = cbs;
  }

  if (m_ConnState)
    End();

  int result = sasl_client_new(m_Service, fqdn, 0, 0, (const sasl_callback_t *)m_CallBacks, 0, (sasl_conn_t **)&m_ConnState);
  if (result != SASL_OK)
    return FALSE;

  // Initialised up front: the out-parameters are not guaranteed to be written
  // when sasl_listmech() fails, and 'list' must never be read uninitialised.
  const char * list = NULL;
  unsigned plen = 0;
  int pcount = 0;
  result = sasl_listmech((sasl_conn_t *)m_ConnState, 0, 0, " ", 0, &list, &plen, &pcount);
  if (result != SASL_OK)
    return FALSE;

  if (list != NULL) {
    PStringArray a = PString(list).Tokenise(" ");
    for (PINDEX i = 0, max = a.GetSize() ; i < max ; i++)
      supportedMechanisms.Include(a[i]);
  }

  return TRUE;
}
/* Starts authentication with 'mechanism' on the connection opened by Init().
 * On success, any initial client response is base64-encoded into 'output'
 * with CR/LF line breaks removed. If the mechanism yields no initial
 * response, 'output' is left untouched.
 * Returns TRUE if the low-level Start() succeeded, FALSE otherwise.
 */
BOOL PSASLClient::Start(const PString& mechanism, PString& output)
{
  const char * rawOutput = 0;
  unsigned rawLen = 0;

  if (!Start(mechanism, &rawOutput, rawLen))
    return FALSE;

  if (rawOutput != 0) {
    PBase64 encoder;
    encoder.StartEncoding();
    encoder.ProcessEncoding(rawOutput, rawLen);
    output = encoder.CompleteEncoding();
    output.Replace("\r\n", PString::Empty(), TRUE);
  }

  return TRUE;
}
/* Low-level start: calls sasl_client_start() for 'mechanism' on the current
 * connection, with no interaction prompts. The raw initial response and its
 * length are returned through 'output' and 'len' exactly as the library
 * provides them.
 * Returns FALSE if no connection exists (Init() not called, or End() already
 * called) or the library reports anything other than SASL_OK / SASL_CONTINUE.
 */
BOOL PSASLClient::Start(const PString& mechanism, const char ** output, unsigned& len)
{
  if (m_ConnState == NULL)
    return FALSE;

  switch (sasl_client_start((sasl_conn_t *)m_ConnState, mechanism, 0, output, &len, 0)) {
    case SASL_OK:
    case SASL_CONTINUE:
      return TRUE;
    default:
      return FALSE;
  }
}
/* Performs one negotiation step from base64 text.
 * 'input' is the base64-encoded server challenge. It is decoded and passed
 * to the low-level Negotiate(). Any client response is base64-encoded into
 * 'output' with CR/LF line breaks removed. If the step produces no response,
 * 'output' is left untouched.
 * Returns the low-level result (OK, Continue or Fail).
 *
 * NOTE(review): the low-level Negotiate() takes NUL-terminated strings and
 * does not report the response length. Challenges or responses containing
 * NUL bytes (binary mechanisms) are therefore truncated. Fixing this needs
 * a length-aware overload in psasl.h. TODO.
 */
PSASLClient::PSASLResult PSASLClient::Negotiate(const PString& input, PString& output)
{
  PBase64 b64;
  b64.StartDecoding();
  b64.ProcessDecoding(input);
  PBYTEArray _bin_input = b64.GetDecodedData();
  PString _input((const char *)(const BYTE *)_bin_input, _bin_input.GetSize());

  // Must start out NULL: the step may fail without setting it, and it is
  // tested below regardless of the result.
  const char * _output = NULL;
  PSASLClient::PSASLResult result = Negotiate(_input, &_output);
  if (_output)
  {
    b64.StartEncoding();
    b64.ProcessEncoding(_output);
    output = b64.CompleteEncoding();
    output.Replace("\r\n", PString::Empty(), TRUE);
  }
  return result;
}
/* Low-level negotiation step.
 * Feeds 'input' (a NUL-terminated challenge; NULL is treated as an empty
 * challenge) to sasl_client_step() on the current connection. *output is
 * set to the client response, or to NULL when there is none or the step fails.
 *
 * Returns OK when authentication completed, Continue when more steps are
 * needed, and Fail otherwise. Fail also covers a NULL 'output' and a missing
 * connection (Init() not called, or End() already called).
 */
PSASLClient::PSASLResult PSASLClient::Negotiate(const char * input, const char ** output)
{
  if (output == NULL)
    return PSASLClient::Fail;
  *output = NULL;   // callers test *output even after a failed step

  if (m_ConnState == NULL)
    return PSASLClient::Fail;

  unsigned len = 0;
  const unsigned inputLen = (input != NULL) ? (unsigned)strlen(input) : 0;
  int result = sasl_client_step((sasl_conn_t *)m_ConnState, input, inputLen, 0, output, &len);

  if (result == SASL_OK)
    return PSASLClient::OK;
  if (result == SASL_CONTINUE)
    return PSASLClient::Continue;
  return PSASLClient::Fail;
}
/* Disposes of the current SASL connection, if one is open.
 * Returns TRUE if a connection was released, FALSE if there was none.
 */
BOOL PSASLClient::End()
{
  if (m_ConnState == NULL)
    return FALSE;

  sasl_dispose((sasl_conn_t **)&m_ConnState);
  m_ConnState = 0;   // cleared explicitly in case sasl_dispose() leaves it set
  return TRUE;
}
#endif // P_SASL2
// End of File ///////////////////////////////////////////////////////////////
syntax highlighted by Code2HTML, v. 0.9.1