connection.cpp

00001 /*
00002   Copyright (c) 2004-2006 by Jakob Schroeter <js@camaya.net>
00003   This file is part of the gloox library. http://camaya.net/gloox
00004 
00005   This software is distributed under a license. The full license
00006   agreement can be found in the file LICENSE in this distribution.
00007   This software may not be copied, modified, sold or distributed
00008   other than expressed in the named license agreement.
00009 
00010   This software is distributed without any warranty.
00011 */
00012 
00013 
00014 
00015 #include "gloox.h"
00016 
00017 #include "compression.h"
00018 #include "connection.h"
00019 #include "dns.h"
00020 #include "logsink.h"
00021 #include "prep.h"
00022 #include "parser.h"
00023 
00024 #ifdef __MINGW32__
00025 #include <winsock.h>
00026 #endif
00027 
00028 #ifndef WIN32
00029 #include <sys/types.h>
00030 #include <sys/socket.h>
00031 #include <sys/select.h>
00032 #include <unistd.h>
00033 #else
00034 #include <winsock.h>
00035 #endif
00036 
00037 #ifdef USE_WINTLS
00038 # include <schannel.h>
00039 #endif
00040 
00041 #include <time.h>
00042 
00043 #include <cstdlib>
00044 #include <string>
00045 #include <sstream>
00046 #include <algorithm>
00047 
00048 namespace gloox
00049 {
00050 
00051   Connection::Connection( Parser *parser, const LogSink& logInstance, const std::string& server,
00052                           unsigned short port )
00053     : m_parser( parser ), m_state ( StateDisconnected ), m_disconnect ( ConnNoError ),
00054       m_logInstance( logInstance ), m_compression( 0 ), m_buf( 0 ),
00055       m_server( Prep::idna( server ) ), m_port( port ), m_socket( -1 ), m_bufsize( 17000 ),
00056       m_cancel( true ), m_secure( false ), m_fdRequested( false ), m_enableCompression( false )
00057   {
00058     m_buf = (char*)calloc( m_bufsize + 1, sizeof( char ) );
00059 #ifdef USE_OPENSSL
00060     m_ssl = 0;
00061 #endif
00062   }
00063 
00064   Connection::~Connection()
00065   {
00066     cleanup();
00067     free( m_buf );
00068     m_buf = 0;
00069     m_parser = 0;
00070   }
00071 
00072 #ifdef HAVE_TLS
00073   void Connection::setClientCert( const std::string& clientKey, const std::string& clientCerts )
00074   {
00075     m_clientKey = clientKey;
00076     m_clientCerts = clientCerts;
00077   }
00078 #endif
00079 
00080 #if defined( USE_OPENSSL )
00081   bool Connection::tlsHandshake()
00082   {
00083     SSL_library_init();
00084     SSL_CTX *sslCTX = SSL_CTX_new( TLSv1_client_method() );
00085     if( !sslCTX )
00086       return false;
00087 
00088     if( !SSL_CTX_set_cipher_list( sslCTX, "HIGH:MEDIUM:AES:@STRENGTH" ) )
00089       return false;
00090 
00091     StringList::const_iterator it = m_cacerts.begin();
00092     for( ; it != m_cacerts.end(); ++it )
00093       SSL_CTX_load_verify_locations( sslCTX, (*it).c_str(), NULL );
00094 
00095     if( !m_clientKey.empty() && !m_clientCerts.empty() )
00096     {
00097       SSL_CTX_use_certificate_chain_file( sslCTX, m_clientCerts.c_str() );
00098       SSL_CTX_use_PrivateKey_file( sslCTX, m_clientKey.c_str(), SSL_FILETYPE_PEM );
00099     }
00100 
00101     m_ssl = SSL_new( sslCTX );
00102     SSL_set_connect_state( m_ssl );
00103 
00104     BIO *socketBio = BIO_new_socket( m_socket, BIO_NOCLOSE );
00105     if( !socketBio )
00106       return false;
00107 
00108     SSL_set_bio( m_ssl, socketBio, socketBio );
00109     SSL_set_mode( m_ssl, SSL_MODE_AUTO_RETRY );
00110 
00111     if( !SSL_connect( m_ssl ) )
00112       return false;
00113 
00114     m_secure = true;
00115 
00116     int res = SSL_get_verify_result( m_ssl );
00117     if( res != X509_V_OK )
00118       m_certInfo.status = CertInvalid;
00119     else
00120       m_certInfo.status = CertOk;
00121 
00122     X509 *peer;
00123     peer = SSL_get_peer_certificate( m_ssl );
00124     if( peer )
00125     {
00126       char peer_CN[256];
00127       X509_NAME_get_text_by_NID( X509_get_issuer_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00128       m_certInfo.issuer = peer_CN;
00129       X509_NAME_get_text_by_NID( X509_get_subject_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00130       m_certInfo.server = peer_CN;
00131       std::string p;
00132       p.assign( peer_CN );
00133       int (*pf)( int ) = tolower;
00134       transform( p.begin(), p.end(), p.begin(), pf );
00135       if( p != m_server )
00136         m_certInfo.status |= CertWrongPeer;
00137     }
00138     else
00139     {
00140       m_certInfo.status = CertInvalid;
00141     }
00142 
00143     const char *tmp;
00144     tmp = SSL_get_cipher_name( m_ssl );
00145     if( tmp )
00146       m_certInfo.cipher = tmp;
00147 
00148     tmp = SSL_get_cipher_version( m_ssl );
00149     if( tmp )
00150       m_certInfo.protocol = tmp;
00151 
00152     return true;
00153   }
00154 
00155   inline bool Connection::tls_send( const void *data, size_t len )
00156   {
00157     int ret;
00158     ret = SSL_write( m_ssl, data, len );
00159     return true;
00160   }
00161 
00162   inline int Connection::tls_recv( void *data, size_t len )
00163   {
00164     return SSL_read( m_ssl, data, len );
00165   }
00166 
00167   inline bool Connection::tls_dataAvailable()
00168   {
00169     return false; // SSL_pending( m_ssl ); // FIXME: crashes
00170   }
00171 
00172   inline void Connection::tls_cleanup()
00173   {
00174     SSL_shutdown( m_ssl );
00175     SSL_free( m_ssl );
00176   }
00177 
00178 #elif defined( USE_GNUTLS )
00179   bool Connection::tlsHandshake()
00180   {
00181     const int protocolPriority[] = { GNUTLS_TLS1, GNUTLS_SSL3, 0 };
00182     const int kxPriority[]       = { GNUTLS_KX_RSA, 0 };
00183     const int cipherPriority[]   = { GNUTLS_CIPHER_AES_256_CBC, GNUTLS_CIPHER_AES_128_CBC,
00184                                              GNUTLS_CIPHER_3DES_CBC, GNUTLS_CIPHER_ARCFOUR, 0 };
00185     const int compPriority[]     = { GNUTLS_COMP_ZLIB, GNUTLS_COMP_NULL, 0 };
00186     const int macPriority[]      = { GNUTLS_MAC_SHA, GNUTLS_MAC_MD5, 0 };
00187 
00188     if( gnutls_global_init() != 0 )
00189       return false;
00190 
00191     if( gnutls_certificate_allocate_credentials( &m_credentials ) < 0 )
00192       return false;
00193 
00194     StringList::const_iterator it = m_cacerts.begin();
00195     for( ; it != m_cacerts.end(); ++it )
00196       gnutls_certificate_set_x509_trust_file( m_credentials, (*it).c_str(), GNUTLS_X509_FMT_PEM );
00197 
00198     if( !m_clientKey.empty() && !m_clientCerts.empty() )
00199     {
00200       gnutls_certificate_set_x509_key_file( m_credentials, m_clientKey.c_str(),
00201                                             m_clientCerts.c_str(), GNUTLS_X509_FMT_PEM );
00202     }
00203 
00204     if( gnutls_init( &m_session, GNUTLS_CLIENT ) != 0 )
00205     {
00206       gnutls_certificate_free_credentials( m_credentials );
00207       return false;
00208     }
00209 
00210     gnutls_protocol_set_priority( m_session, protocolPriority );
00211     gnutls_cipher_set_priority( m_session, cipherPriority );
00212     gnutls_compression_set_priority( m_session, compPriority );
00213     gnutls_kx_set_priority( m_session, kxPriority );
00214     gnutls_mac_set_priority( m_session, macPriority );
00215     gnutls_credentials_set( m_session, GNUTLS_CRD_CERTIFICATE, m_credentials );
00216 
00217     gnutls_transport_set_ptr( m_session, (gnutls_transport_ptr_t)m_socket );
00218     if( gnutls_handshake( m_session ) != 0 )
00219     {
00220       gnutls_deinit( m_session );
00221       gnutls_certificate_free_credentials( m_credentials );
00222       return false;
00223     }
00224     gnutls_certificate_free_ca_names( m_credentials );
00225 
00226     m_secure = true;
00227 
00228     unsigned int status;
00229     bool error = false;
00230 
00231     if( gnutls_certificate_verify_peers2( m_session, &status ) < 0 )
00232       error = true;
00233 
00234     m_certInfo.status = 0;
00235     if( status & GNUTLS_CERT_INVALID )
00236       m_certInfo.status |= CertInvalid;
00237     if( status & GNUTLS_CERT_SIGNER_NOT_FOUND )
00238       m_certInfo.status |= CertSignerUnknown;
00239     if( status & GNUTLS_CERT_REVOKED )
00240       m_certInfo.status |= CertRevoked;
00241     if( status & GNUTLS_CERT_SIGNER_NOT_CA )
00242       m_certInfo.status |= CertSignerNotCa;
00243     const gnutls_datum_t* certList = 0;
00244     unsigned int certListSize;
00245     if( !error && ( ( certList = gnutls_certificate_get_peers( m_session, &certListSize ) ) == 0 ) )
00246       error = true;
00247 
00248     gnutls_x509_crt_t *cert = new gnutls_x509_crt_t[certListSize+1];
00249     for( unsigned int i=0; !error && ( i<certListSize ); ++i )
00250     {
00251       if( !error && ( gnutls_x509_crt_init( &cert[i] ) < 0 ) )
00252         error = true;
00253       if( !error && ( gnutls_x509_crt_import( cert[i], &certList[i], GNUTLS_X509_FMT_DER ) < 0 ) )
00254         error = true;
00255     }
00256 
00257     if( ( gnutls_x509_crt_check_issuer( cert[certListSize-1], cert[certListSize-1] ) > 0 )
00258          && certListSize > 0 )
00259       certListSize--;
00260 
00261     bool chain = true;
00262     for( unsigned int i=1; !error && ( i<certListSize ); ++i )
00263     {
00264       chain = error = !verifyAgainst( cert[i-1], cert[i] );
00265     }
00266     if( !chain )
00267       m_certInfo.status |= CertInvalid;
00268     m_certInfo.chain = chain;
00269 
00270     m_certInfo.chain = verifyAgainstCAs( cert[certListSize], 0 /*CAList*/, 0 /*CAListSize*/ );
00271 
00272     int t = (int)gnutls_x509_crt_get_expiration_time( cert[0] );
00273     if( t == -1 )
00274       error = true;
00275     else if( t < time( 0 ) )
00276       m_certInfo.status |= CertExpired;
00277     m_certInfo.date_to = t;
00278 
00279     t = (int)gnutls_x509_crt_get_activation_time( cert[0] );
00280     if( t == -1 )
00281       error = true;
00282     else if( t > time( 0 ) )
00283       m_certInfo.status |= CertNotActive;
00284     m_certInfo.date_from = t;
00285 
00286     char name[64];
00287     size_t nameSize = sizeof( name );
00288     gnutls_x509_crt_get_issuer_dn( cert[0], name, &nameSize );
00289     m_certInfo.issuer = name;
00290 
00291     nameSize = sizeof( name );
00292     gnutls_x509_crt_get_dn( cert[0], name, &nameSize );
00293     m_certInfo.server = name;
00294 
00295     const char* info;
00296     info = gnutls_compression_get_name( gnutls_compression_get( m_session ) );
00297     if( info )
00298       m_certInfo.compression = info;
00299 
00300     info = gnutls_mac_get_name( gnutls_mac_get( m_session ) );
00301     if( info )
00302       m_certInfo.mac = info;
00303 
00304     info = gnutls_cipher_get_name( gnutls_cipher_get( m_session ) );
00305     if( info )
00306       m_certInfo.cipher = info;
00307 
00308     info = gnutls_protocol_get_name( gnutls_protocol_get_version( m_session ) );
00309     if( info )
00310       m_certInfo.protocol = info;
00311 
00312     if( !gnutls_x509_crt_check_hostname( cert[0], m_server.c_str() ) )
00313       m_certInfo.status |= CertWrongPeer;
00314 
00315     for( unsigned int i=0; i<certListSize; ++i )
00316       gnutls_x509_crt_deinit( cert[i] );
00317 
00318     delete[] cert;
00319 
00320     return true;
00321   }
00322 
00323   bool Connection::verifyAgainst( gnutls_x509_crt_t cert, gnutls_x509_crt_t issuer )
00324   {
00325     unsigned int result;
00326     gnutls_x509_crt_verify( cert, &issuer, 1, 0, &result );
00327     if( result & GNUTLS_CERT_INVALID )
00328       return false;
00329 
00330     if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00331       return false;
00332 
00333     if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00334       return false;
00335 
00336     return true;
00337   }
00338 
00339   bool Connection::verifyAgainstCAs( gnutls_x509_crt_t cert, gnutls_x509_crt_t *CAList, int CAListSize )
00340   {
00341     unsigned int result;
00342     gnutls_x509_crt_verify( cert, CAList, CAListSize, GNUTLS_VERIFY_ALLOW_X509_V1_CA_CRT, &result );
00343     if( result & GNUTLS_CERT_INVALID )
00344       return false;
00345 
00346     if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00347       return false;
00348 
00349     if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00350       return false;
00351 
00352     return true;
00353   }
00354 
00355   inline bool Connection::tls_send( const void *data, size_t len )
00356   {
00357     int ret;
00358     do
00359     {
00360       ret = gnutls_record_send( m_session, data, len );
00361     }
00362     while( ( ret == GNUTLS_E_AGAIN ) || ( ret == GNUTLS_E_INTERRUPTED ) );
00363     return true;
00364   }
00365 
00366   inline int Connection::tls_recv( void *data, size_t len )
00367   {
00368     return gnutls_record_recv( m_session, data, len );
00369   }
00370 
00371   inline bool Connection::tls_dataAvailable()
00372   {
00373     return false; // gnutls_check_pending( m_session ); // FIXME: crashes
00374   }
00375 
00376   inline void Connection::tls_cleanup()
00377   {
00378     gnutls_bye( m_session, GNUTLS_SHUT_RDWR );
00379     gnutls_deinit( m_session );
00380     gnutls_certificate_free_credentials( m_credentials );
00381     gnutls_global_deinit();
00382   }
00383 
00384 #elif defined( USE_WINTLS )
00385   bool Connection::tlsHandshake()
00386   {
00387     INIT_SECURITY_INTERFACE pInitSecurityInterface;
00388 
00389     m_lib = LoadLibrary( "secur32.dll" );
00390     if( m_lib == NULL )
00391       return false;
00392 
00393     pInitSecurityInterface = (INIT_SECURITY_INTERFACE)GetProcAddress( m_lib, "InitSecurityInterfaceA" );
00394     if( pInitSecurityInterface == NULL )
00395     {
00396       FreeLibrary( m_lib );
00397       m_lib = 0;
00398       return false;
00399     }
00400 
00401     m_securityFunc = pInitSecurityInterface();
00402     if( !m_securityFunc )
00403     {
00404       FreeLibrary( m_lib );
00405       m_lib = 0;
00406       return false;
00407     }
00408 
00409     SCHANNEL_CRED schannelCred;
00410     memset( &schannelCred, 0, sizeof( schannelCred ) );
00411     memset( &m_credentials, 0, sizeof( m_credentials ) );
00412     memset( &m_context, 0, sizeof( m_context ) );
00413 
00414     schannelCred.dwVersion = SCHANNEL_CRED_VERSION;
00415     schannelCred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT;
00416     schannelCred.cSupportedAlgs = 0; // FIXME
00417 #ifdef MSVC
00418     schannelCred.dwMinimumCipherStrength = 0; // FIXME
00419     schannelCred.dwMaximumCipherStrength = 0; // FIXME
00420 #else
00421     schannelCred.dwMinimumCypherStrength = 0; // FIXME
00422     schannelCred.dwMaximumCypherStrength = 0; // FIXME
00423 #endif
00424     schannelCred.dwSessionLifespan = 0;
00425     schannelCred.dwFlags = SCH_CRED_NO_SERVERNAME_CHECK | SCH_CRED_NO_DEFAULT_CREDS |
00426                            SCH_CRED_MANUAL_CRED_VALIDATION; // FIXME check
00427 
00428     TimeStamp timeStamp;
00429     SECURITY_STATUS ret;
00430     ret = m_securityFunc->AcquireCredentialsHandleA( NULL, UNISP_NAME_A, SECPKG_CRED_OUTBOUND,
00431                                      NULL, &schannelCred, NULL,
00432                                      NULL, &m_credentials, &timeStamp );
00433     if( ret != SEC_E_OK )
00434     {
00435       printf( "AcquireCredentialsHandleA failed\n" );
00436       return false;
00437     }
00438 
00439     m_sspiFlags = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR
00440                       | ISC_REQ_MUTUAL_AUTH | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT
00441                       | ISC_REQ_STREAM;
00442 
00443     SecBufferDesc outBufferDesc;
00444     SecBuffer outBuffers[1];
00445 
00446     outBuffers[0].BufferType = SECBUFFER_TOKEN;
00447     outBuffers[0].pvBuffer = NULL;
00448     outBuffers[0].cbBuffer = 0;
00449 
00450     outBufferDesc.ulVersion = SECBUFFER_VERSION;
00451     outBufferDesc.cBuffers = 1;
00452     outBufferDesc.pBuffers = outBuffers;
00453 
00454     long unsigned int sspiFlagsOut;
00455     ret = m_securityFunc->InitializeSecurityContextA( &m_credentials, NULL, NULL, m_sspiFlags, 0,
00456         SECURITY_NATIVE_DREP, NULL, 0, &m_context,
00457         &outBufferDesc, &sspiFlagsOut, &timeStamp );
00458     if( ret == SEC_I_CONTINUE_NEEDED && outBuffers[0].cbBuffer != 0 && outBuffers[0].pvBuffer != NULL )
00459     {
00460       printf( "OK: Continue needed: " );
00461 
00462       int ret = ::send( m_socket, (const char*)outBuffers[0].pvBuffer, outBuffers[0].cbBuffer, 0 );
00463       if( ret == SOCKET_ERROR || ret == 0 )
00464       {
00465         m_securityFunc->FreeContextBuffer( outBuffers[0].pvBuffer );
00466         m_securityFunc->DeleteSecurityContext( &m_context );
00467         return false;
00468       }
00469 
00470       m_securityFunc->FreeContextBuffer( outBuffers[0].pvBuffer );
00471       outBuffers[0].pvBuffer = NULL;
00472     }
00473 
00474     if( !handshakeLoop() )
00475     {
00476       printf( "handshakeLoop failed\n" );
00477       return false;
00478     }
00479 
00480     ret = m_securityFunc->QueryContextAttributes( &m_context, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes );
00481     if( ret != SEC_E_OK )
00482     {
00483       printf( "could not read stream attribs (sizes)\n" );
00484       return false;
00485     }
00486 printf( "maximumMessage: %ld\n", m_streamSizes.cbMaximumMessage );
00487     int maxSize = m_streamSizes.cbHeader + m_streamSizes.cbMaximumMessage + m_streamSizes.cbTrailer;
00488     m_iBuffer = (char*)malloc( maxSize );
00489     if( !m_iBuffer )
00490       return false;
00491 
00492     m_oBuffer = (char*)malloc( maxSize );
00493     if( !m_oBuffer )
00494       return false;
00495 
00496     m_bufferOffset = 0;
00497     m_messageOffset = m_oBuffer + m_streamSizes.cbHeader;
00498 
00499     SecPkgContext_Authority streamAuthority;
00500     ret = m_securityFunc->QueryContextAttributes( &m_context, SECPKG_ATTR_AUTHORITY, &streamAuthority );
00501     if( ret != SEC_E_OK )
00502     {
00503       printf( "could not read stream attribs (sizes)\n" );
00504       return false;
00505     }
00506     else
00507     {
00508       m_certInfo.issuer.assign( streamAuthority.sAuthorityName );
00509     }
00510 
00511     SecPkgContext_ConnectionInfo streamInfo;
00512     ret = m_securityFunc->QueryContextAttributes( &m_context, SECPKG_ATTR_CONNECTION_INFO, &streamInfo );
00513     if( ret != SEC_E_OK )
00514     {
00515       printf( "could not read stream attribs (sizes)\n" );
00516       return false;
00517     }
00518     else
00519     {
00520       if( streamInfo.dwProtocol == SP_PROT_TLS1_CLIENT )
00521         m_certInfo.protocol = "TLS 1.0";
00522       else
00523         m_certInfo.protocol = "unknown";
00524 
00525       std::ostringstream oss;
00526       switch( streamInfo.aiCipher )
00527       {
00528         case CALG_3DES:
00529           oss << "3DES";
00530           break;
00531         case CALG_AES_128:
00532           oss << "AES";
00533           break;
00534         case CALG_AES_256:
00535           oss << "AES";
00536           break;
00537         case CALG_DES:
00538           oss << "DES";
00539           break;
00540         case CALG_RC2:
00541           oss << "RC2";
00542           break;
00543         case CALG_RC4:
00544           oss << "RC4";
00545           break;
00546         default:
00547           oss << "unknown";
00548       }
00549 
00550       oss << " " << streamInfo.dwCipherStrength;
00551       m_certInfo.cipher = oss.str();
00552       oss.str( "" );
00553 
00554       switch( streamInfo.aiHash  )
00555       {
00556         case CALG_MD5:
00557           oss << "MD5";
00558           break;
00559         case CALG_SHA:
00560           oss << "SHA";
00561           break;
00562         default:
00563           oss << "unknown";
00564       }
00565 
00566       oss << " " << streamInfo.dwHashStrength;
00567       m_certInfo.mac = oss.str();
00568 
00569       m_certInfo.compression = "unknown";
00570     }
00571 
00572     m_secure = true;
00573 
00574     return true;
00575   }
00576 
00577   bool Connection::handshakeLoop()
00578   {
00579     const int bufsize = 65536;
00580     char *buf = (char*)malloc( bufsize );
00581     if( !buf )
00582       return false;
00583 
00584     int bufFilled = 0;
00585     int dataRecv = 0;
00586     bool doRead = true;
00587 
00588     SecBufferDesc outBufferDesc, inBufferDesc;
00589     SecBuffer outBuffers[1], inBuffers[2];
00590 
00591     SECURITY_STATUS ret = SEC_I_CONTINUE_NEEDED;
00592 
00593     while( ret == SEC_I_CONTINUE_NEEDED ||
00594            ret == SEC_E_INCOMPLETE_MESSAGE ||
00595            ret == SEC_I_INCOMPLETE_CREDENTIALS )
00596     {
00597 
00598       if( doRead )
00599       {
00600         dataRecv = ::recv( m_socket, buf + bufFilled, bufsize - bufFilled, 0 );
00601 
00602         if( dataRecv == SOCKET_ERROR || dataRecv == 0 )
00603         {
00604           break;
00605         }
00606 
00607         printf( "%d bytes handshake data received\n", dataRecv );
00608 
00609         bufFilled += dataRecv;
00610       }
00611       else
00612       {
00613         doRead = true;
00614       }
00615 
00616       outBuffers[0].BufferType = SECBUFFER_TOKEN;
00617       outBuffers[0].pvBuffer = NULL;
00618       outBuffers[0].cbBuffer = 0;
00619 
00620       outBufferDesc.ulVersion = SECBUFFER_VERSION;
00621       outBufferDesc.cBuffers = 1;
00622       outBufferDesc.pBuffers = outBuffers;
00623 
00624       inBuffers[0].BufferType = SECBUFFER_TOKEN;
00625       inBuffers[0].pvBuffer = buf;
00626       inBuffers[0].cbBuffer = bufFilled;
00627 
00628       inBuffers[1].BufferType = SECBUFFER_EMPTY;
00629       inBuffers[1].pvBuffer = NULL;
00630       inBuffers[1].cbBuffer = 0;
00631 
00632       inBufferDesc.ulVersion = SECBUFFER_VERSION;
00633       inBufferDesc.cBuffers = 2;
00634       inBufferDesc.pBuffers = inBuffers;
00635 
00636       printf( "buffers inited, calling InitializeSecurityContextA\n" );
00637       long unsigned int sspiFlagsOut;
00638       TimeStamp timeStamp;
00639       ret = m_securityFunc->InitializeSecurityContextA( &m_credentials, &m_context, NULL,
00640                                                         m_sspiFlags, 0,
00641                                                         SECURITY_NATIVE_DREP, &inBufferDesc, 0, NULL,
00642                                                         &outBufferDesc, &sspiFlagsOut, &timeStamp );
00643       if( ret == SEC_E_OK || ret == SEC_I_CONTINUE_NEEDED ||
00644           ( FAILED( ret ) && sspiFlagsOut & ISC_RET_EXTENDED_ERROR ) )
00645       {
00646         if( outBuffers[0].cbBuffer != 0 && outBuffers[0].pvBuffer != NULL )
00647         {
00648           printf( "ISCA returned, buffers not empty\n" );
00649           dataRecv = ::send( m_socket, (const char*)outBuffers[0].pvBuffer, outBuffers[0].cbBuffer, 0  );
00650           if( dataRecv == SOCKET_ERROR || dataRecv == 0 )
00651           {
00652             m_securityFunc->FreeContextBuffer( &outBuffers[0].pvBuffer );
00653             m_securityFunc->DeleteSecurityContext( &m_context );
00654             free( buf );
00655             printf( "coudl not send bufer to server, exiting\n" );
00656             return false;
00657           }
00658 
00659           m_securityFunc->FreeContextBuffer( outBuffers[0].pvBuffer );
00660           outBuffers[0].pvBuffer = NULL;
00661         }
00662       }
00663 
00664       if( ret == SEC_E_INCOMPLETE_MESSAGE )
00665         continue;
00666 
00667       if( ret == SEC_E_OK )
00668       {
00669         printf( "handshake successful\n" );
00670         break;
00671       }
00672 
00673       if( FAILED( ret ) )
00674       {
00675         printf( "ISC failed: %ld\n", ret );
00676         break;
00677       }
00678 
00679       if( ret == SEC_I_INCOMPLETE_CREDENTIALS )
00680       {
00681         printf( "server requested client credentials\n" );
00682         ret = SEC_I_CONTINUE_NEEDED;
00683         continue;
00684       }
00685 
00686       if( inBuffers[1].BufferType == SECBUFFER_EXTRA )
00687       {
00688         printf("some xtra mem in inbuf\n" );
00689         MoveMemory( buf, buf + ( bufFilled - inBuffers[1].cbBuffer ),
00690                    inBuffers[1].cbBuffer );
00691 
00692         bufFilled = inBuffers[1].cbBuffer;
00693       }
00694       else
00695       {
00696         bufFilled = 0;
00697       }
00698     }
00699 
00700     if( FAILED( ret ) )
00701       m_securityFunc->DeleteSecurityContext( &m_context );
00702 
00703     free( buf );
00704 
00705     if( ret == SEC_E_OK )
00706       return true;
00707 
00708     return false;
00709   }
00710 
00711   inline bool Connection::tls_send( const void *data, size_t len )
00712   {
00713     if( len <= 0 )
00714       return true;
00715 
00716     SECURITY_STATUS ret;
00717 
00718     m_obuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
00719     m_obuffers[0].pvBuffer = m_oBuffer;
00720     m_obuffers[0].cbBuffer = m_streamSizes.cbHeader;
00721 
00722     m_obuffers[1].BufferType = SECBUFFER_DATA;
00723     m_obuffers[1].pvBuffer = m_messageOffset;
00724 
00725     m_obuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
00726     m_obuffers[2].cbBuffer = m_streamSizes.cbTrailer;
00727 
00728     m_obuffers[3].BufferType = SECBUFFER_EMPTY;
00729     m_obuffers[3].pvBuffer = NULL;
00730     m_obuffers[3].cbBuffer = 0;
00731 
00732     m_omessage.ulVersion = SECBUFFER_VERSION;
00733     m_omessage.cBuffers = 4;
00734     m_omessage.pBuffers = m_obuffers;
00735 
00736     while( len > 0 )
00737     {
00738       if( m_streamSizes.cbMaximumMessage < len )
00739       {
00740         memcpy( m_messageOffset, data, m_streamSizes.cbMaximumMessage );
00741         len -= m_streamSizes.cbMaximumMessage;
00742         m_obuffers[1].cbBuffer = m_streamSizes.cbMaximumMessage;
00743         m_obuffers[2].pvBuffer = m_messageOffset + m_streamSizes.cbMaximumMessage;
00744       }
00745       else
00746       {
00747         memcpy( m_messageOffset, data, len );
00748         m_obuffers[1].cbBuffer = len;
00749         m_obuffers[2].pvBuffer = m_messageOffset + len;
00750         len = 0;
00751       }
00752 
00753       ret = m_securityFunc->EncryptMessage( &m_context, 0, &m_omessage, 0 );
00754       if( ret != SEC_E_OK )
00755       {
00756         printf( "encryptmessage failed %ld\n", ret );
00757         return false;
00758       }
00759 
00760       int t = ::send( m_socket, m_oBuffer,
00761                       m_obuffers[0].cbBuffer + m_obuffers[1].cbBuffer + m_obuffers[2].cbBuffer, 0 );
00762       if( t == SOCKET_ERROR || t == 0 )
00763       {
00764         printf( "could not send: %d\n", WSAGetLastError() );
00765         return false;
00766       }
00767     }
00768 
00769     return true;
00770   }
00771 
// Reads encrypted bytes from the socket, decrypts them with the SChannel
// context and copies the plaintext into 'data'.
//
// Raw bytes are appended to m_iBuffer at m_bufferOffset, so an incomplete
// record left over from an earlier call is completed by the next one.
//
// Returns the number of plaintext bytes copied into 'data'. That is 0 when
// the socket reports an error or was closed, when the record is still
// incomplete, or when DecryptMessage() fails ('return false' in an int
// function). It is 'len' when the decrypted record is larger than 'len'; in
// that case only the first 'len' bytes are copied.
00772   inline int Connection::tls_recv( void *data, size_t len )
00773   {
00774     SECURITY_STATUS ret;
00775     SecBuffer *dataBuffer = 0;
00776     int readable = 0;
00777 
      // Size of one full record; the same formula is used where m_iBuffer is allocated.
00778     int maxLength = m_streamSizes.cbHeader + m_streamSizes.cbMaximumMessage + m_streamSizes.cbTrailer;
00779 
00780     printf( "bufferOffset is %d\n", m_bufferOffset );
00781 
      // Append to whatever is already buffered.
00782     int t = ::recv( m_socket, m_iBuffer + m_bufferOffset, maxLength - m_bufferOffset, 0 );
00783     if( t == SOCKET_ERROR )
00784     {
00785       printf( "got SocketError\n" );
00786       return 0;
00787     }
00788     else if( t == 0 )
00789     {
00790       printf( "got connection close\n" );
00791       return 0;
00792     }
00793     else
00794       m_bufferOffset += t;
00795 
      // NOTE(review): m_bufferOffset is reset to 0 after a successful decrypt and the
      // code that would refill it from a SECBUFFER_EXTRA buffer is commented out, so
      // this loop currently runs at most once.
00796     while( m_bufferOffset )
00797     {
00798       printf( "continuing with bufferOffset: %d\n", m_bufferOffset );
00799 
      // Buffer 0 carries the encrypted input; DecryptMessage() reports its results
      // through the types and pointers of the remaining buffers.
00800       m_ibuffers[0].pvBuffer = m_iBuffer;
00801       m_ibuffers[0].cbBuffer = m_bufferOffset;
00802       m_ibuffers[0].BufferType = SECBUFFER_DATA;
00803 
00804       m_ibuffers[1].BufferType = SECBUFFER_EMPTY;
00805       m_ibuffers[2].BufferType = SECBUFFER_EMPTY;
00806       m_ibuffers[3].BufferType = SECBUFFER_EMPTY;
00807 
00808       m_imessage.ulVersion = SECBUFFER_VERSION;
00809       m_imessage.cBuffers = 4;
00810       m_imessage.pBuffers = m_ibuffers;
00811 
00812       ret = m_securityFunc->DecryptMessage( &m_context, &m_imessage, 0, NULL );
00813 
      // Record not complete yet: m_bufferOffset is left untouched so the next call
      // appends to the bytes already received.
00814       if( ret == SEC_E_INCOMPLETE_MESSAGE )
00815       {
00816         printf( "recv'ed incomplete message\n" );
00817         return readable;
00818       }
00819 
00820 
00821   //    if( ret == SEC_I_CONTEXT_EXPIRED )
00822   //      return 0;
00823 
00824       if( ret != SEC_E_OK && ret != SEC_I_RENEGOTIATE )
00825       {
00826         printf( "DecryptMessage returned %ld\n", ret );
00827         printf( "GetLastError(): %ld\n", GetLastError() );
00828         printf( "input buffer length: %d, read in this run: %d\n", m_bufferOffset, t );
00829         return false;
00830       }
00831 
00832       m_bufferOffset = 0;
00833 
      // Find the first buffer holding plaintext. A SECBUFFER_EXTRA buffer is only
      // logged; its contents are not kept (see the commented-out lines).
00834       for( int i = 1; i < 4; ++i )
00835       {
00836         if( dataBuffer == 0 && m_ibuffers[i].BufferType == SECBUFFER_DATA )
00837         {
00838           dataBuffer = &m_ibuffers[i];
00839         }
00840         if( m_bufferOffset == 0 && m_ibuffers[i].BufferType == SECBUFFER_EXTRA )
00841         {
00842   //         m_extraBuffer = &m_ibuffers[i];
00843   printf( "git exetra buffer, size %ld\n", m_ibuffers[i].cbBuffer );
00844 //          memcpy( m_iBuffer, m_ibuffers[i].pvBuffer, m_ibuffers[i].cbBuffer );
00845 //          m_bufferOffset = m_ibuffers[i].cbBuffer;
00846         }
00847       }
00848 
00849       if( dataBuffer )
00850       {
        // More plaintext than the caller has room for: copy what fits and return.
00851         if( dataBuffer->cbBuffer > len )
00852         {
00853           memcpy( data, dataBuffer->pvBuffer, len );
00854           return len;
00855         }
00856         else
00857         {
00858           memcpy( data, dataBuffer->pvBuffer, dataBuffer->cbBuffer );
00859           readable += dataBuffer->cbBuffer;
00860           printf( "recvbuffer (%d): %s\n", readable, data );
00861         }
00862       }
00863 
      // NOTE(review): handshakeLoop() returns bool, so 'ret' holds 0 or 1 afterwards
      // and is not looked at again.
00864       if( ret == SEC_I_RENEGOTIATE )
00865       {
00866         printf( "server requested reneg\n" );
00867         ret = handshakeLoop();
00868       }
00869     }
00870 
00871     return readable;
00872   }
00873 
00874   inline bool Connection::tls_dataAvailable()
00875   {
00876     return false;
00877   }
00878 
00879   inline void Connection::tls_cleanup()
00880   {
00881     m_securityFunc->DeleteSecurityContext( &m_context );
00882   }
00883 #endif
00884 
00885 #ifdef HAVE_ZLIB
00886   bool Connection::initCompression( StreamFeature method )
00887   {
00888     delete m_compression;
00889     m_compression = 0;
00890     m_compression = new Compression( method );
00891     return true;
00892   }
00893 
00894   void Connection::enableCompression()
00895   {
00896     if( !m_compression )
00897       return;
00898 
00899     m_enableCompression = true;
00900   }
00901 #endif
00902 
00903   ConnectionState Connection::connect()
00904   {
00905     if( m_socket != -1 && m_state >= StateConnecting )
00906     {
00907       return m_state;
00908     }
00909 
00910     m_state = StateConnecting;
00911 
00912     if( m_port == ( unsigned short ) -1 )
00913       m_socket = DNS::connect( m_server, m_logInstance );
00914     else
00915       m_socket = DNS::connect( m_server, m_port, m_logInstance );
00916 
00917     if( m_socket < 0 )
00918     {
00919       switch( m_socket )
00920       {
00921         case -DNS::DNS_COULD_NOT_CONNECT:
00922           m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: could not connect" );
00923           break;
00924         case -DNS::DNS_NO_HOSTS_FOUND:
00925           m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: no hosts found" );
00926           break;
00927         case -DNS::DNS_COULD_NOT_RESOLVE:
00928           m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: could not resolve" );
00929           break;
00930       }
00931       cleanup();
00932     }
00933     else
00934       m_state = StateConnected;
00935 
00936     m_cancel = false;
00937     return m_state;
00938   }
00939 
00940   void Connection::disconnect( ConnectionError e )
00941   {
00942     m_disconnect = e;
00943     m_cancel = true;
00944 
00945     if( m_fdRequested )
00946       cleanup();
00947   }
00948 
00949   int Connection::fileDescriptor()
00950   {
00951     m_fdRequested = true;
00952     return m_socket;
00953   }
00954 
00955   bool Connection::dataAvailable( int timeout )
00956   {
00957     if( m_socket < 0 )
00958       return true;
00959 
00960 #ifdef HAVE_TLS
00961     if( tls_dataAvailable() )
00962     {
00963         return true;
00964     }
00965 #endif
00966 
00967     fd_set fds;
00968     struct timeval tv;
00969 
00970     FD_ZERO( &fds );
00971     FD_SET( m_socket, &fds );
00972 
00973     tv.tv_sec = timeout / 1000000;
00974     tv.tv_usec = timeout % 1000000;
00975 
00976     if( select( m_socket + 1, &fds, 0, 0, timeout == -1 ? 0 : &tv ) > 0 )
00977     {
00978       return FD_ISSET( m_socket, &fds ) ? true : false;
00979     }
00980     return false;
00981   }
00982 
00983   ConnectionError Connection::recv( int timeout )
00984   {
00985     if( m_cancel )
00986     {
00987       ConnectionError e = m_disconnect;
00988       cleanup();
00989       return e;
00990     }
00991 
00992     if( m_socket == -1 )
00993       return ConnNotConnected;
00994 
00995     if( !m_fdRequested && !dataAvailable( timeout ) )
00996     {
00997         return ConnNoError;
00998     }
00999 
01000     // optimize(?): recv returns the size. set size+1 = \0
01001     memset( m_buf, '\0', m_bufsize + 1 );
01002     int size = 0;
01003 #ifdef HAVE_TLS
01004     if( m_secure )
01005     {
01006       size = tls_recv( m_buf, m_bufsize );
01007     }
01008     else
01009 #endif
01010     {
01011 #ifdef SKYOS
01012       size = ::recv( m_socket, (unsigned char*)m_buf, m_bufsize, 0 );
01013 #else
01014       size = ::recv( m_socket, m_buf, m_bufsize, 0 );
01015 #endif
01016     }
01017 
01018     if( size < 0 )
01019     {
01020       // error
01021       m_secure = false;
01022       cleanup();
01023       return ConnIoError;
01024     }
01025     else if( size == 0 )
01026     {
01027       // connection closed
01028       m_secure = false;
01029       cleanup();
01030       return ConnUserDisconnected;
01031     }
01032 
01033     std::string buf;
01034     buf.assign( m_buf, size );
01035     if( m_compression && m_enableCompression )
01036       buf = m_compression->decompress( buf );
01037 
01038     Parser::ParserState ret = m_parser->feed( buf );
01039     if( ret != Parser::PARSER_OK )
01040     {
01041       cleanup();
01042       switch( ret )
01043       {
01044         case Parser::PARSER_BADXML:
01045           m_logInstance.log( LogLevelError, LogAreaClassConnection, "XML parse error" );
01046           break;
01047         case Parser::PARSER_NOMEM:
01048           m_logInstance.log( LogLevelError, LogAreaClassConnection, "memory allocation error" );
01049           break;
01050         default:
01051           m_logInstance.log( LogLevelError, LogAreaClassConnection, "unexpected error" );
01052           break;
01053       }
01054       //printf( "buffer data: %s\n", buf.c_str() );
01055       return ConnIoError;
01056     }
01057 
01058     return ConnNoError;
01059   }
01060 
01061   ConnectionError Connection::receive()
01062   {
01063     if( m_socket == -1 || !m_parser )
01064       return ConnNotConnected;
01065 
01066     while( !m_cancel )
01067     {
01068       ConnectionError r = recv( 1 );
01069       if( r != ConnNoError )
01070       {
01071         return r;
01072       }
01073     }
01074     ConnectionError e = m_disconnect;
01075     cleanup();
01076 
01077     return e;
01078   }
01079 
01080   bool Connection::send( const std::string& data )
01081   {
01082     if( data.empty() || ( m_socket == -1 ) )
01083       return false;
01084 
01085     std::string xml;
01086     if( m_compression && m_enableCompression )
01087       xml = m_compression->compress( data );
01088     else
01089       xml = data;
01090 
01091 #ifdef HAVE_TLS
01092     if( m_secure )
01093     {
01094       size_t len = xml.length();
01095       if( tls_send( xml.c_str (), len ) == false )
01096         return false;
01097     }
01098     else
01099 #endif
01100     {
01101       size_t num = 0;
01102       size_t len = xml.length();
01103       while( num < len )
01104       {
01105 #ifdef SKYOS
01106         int sent = ::send( m_socket, (unsigned char*)(xml.c_str()+num), len - num, 0 );
01107 #else
01108         int sent = ::send( m_socket, (xml.c_str()+num), len - num, 0 );
01109 #endif
01110         if( sent == -1 )
01111           return false;
01112 
01113         num += sent;
01114       }
01115     }
01116 
01117     return true;
01118   }
01119 
01120   void Connection::cleanup()
01121   {
01122 #ifdef HAVE_TLS
01123     if( m_secure )
01124     {
01125       tls_cleanup();
01126     }
01127 #endif
01128 
01129     if( m_socket != -1 )
01130     {
01131 #ifdef WIN32
01132       closesocket( m_socket );
01133 #else
01134       close( m_socket );
01135 #endif
01136       m_socket = -1;
01137     }
01138     m_state = StateDisconnected;
01139     m_disconnect = ConnNoError;
01140     m_enableCompression = false;
01141     m_secure = false;
01142     m_cancel = true;
01143     m_fdRequested = false;
01144   }
01145 
01146 }

Generated on Tue May 1 14:20:20 2007 for gloox by doxygen 1.5.1