00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #include "gloox.h"
00016
00017 #include "compression.h"
00018 #include "connection.h"
00019 #include "dns.h"
00020 #include "logsink.h"
00021 #include "prep.h"
00022 #include "parser.h"
00023
00024 #ifdef __MINGW32__
00025 #include <winsock.h>
00026 #endif
00027
00028 #ifndef WIN32
00029 #include <sys/types.h>
00030 #include <sys/socket.h>
00031 #include <sys/select.h>
00032 #include <unistd.h>
00033 #else
00034 #include <winsock.h>
00035 #endif
00036
00037 #ifdef USE_WINTLS
00038 # include <schannel.h>
00039 #endif
00040
00041 #include <time.h>
00042
00043 #include <cstdlib>
00044 #include <string>
00045 #include <sstream>
00046 #include <algorithm>
00047
00048 namespace gloox
00049 {
00050
00051 Connection::Connection( Parser *parser, const LogSink& logInstance, const std::string& server,
00052 unsigned short port )
00053 : m_parser( parser ), m_state ( StateDisconnected ), m_disconnect ( ConnNoError ),
00054 m_logInstance( logInstance ), m_compression( 0 ), m_buf( 0 ),
00055 m_server( Prep::idna( server ) ), m_port( port ), m_socket( -1 ), m_bufsize( 17000 ),
00056 m_cancel( true ), m_secure( false ), m_fdRequested( false ), m_enableCompression( false )
00057 {
00058 m_buf = (char*)calloc( m_bufsize + 1, sizeof( char ) );
00059 #ifdef USE_OPENSSL
00060 m_ssl = 0;
00061 #endif
00062 }
00063
00064 Connection::~Connection()
00065 {
00066 cleanup();
00067 free( m_buf );
00068 m_buf = 0;
00069 m_parser = 0;
00070 }
00071
00072 #ifdef HAVE_TLS
00073 void Connection::setClientCert( const std::string& clientKey, const std::string& clientCerts )
00074 {
00075 m_clientKey = clientKey;
00076 m_clientCerts = clientCerts;
00077 }
00078 #endif
00079
00080 #if defined( USE_OPENSSL )
00081 bool Connection::tlsHandshake()
00082 {
00083 SSL_library_init();
00084 SSL_CTX *sslCTX = SSL_CTX_new( TLSv1_client_method() );
00085 if( !sslCTX )
00086 return false;
00087
00088 if( !SSL_CTX_set_cipher_list( sslCTX, "HIGH:MEDIUM:AES:@STRENGTH" ) )
00089 return false;
00090
00091 StringList::const_iterator it = m_cacerts.begin();
00092 for( ; it != m_cacerts.end(); ++it )
00093 SSL_CTX_load_verify_locations( sslCTX, (*it).c_str(), NULL );
00094
00095 if( !m_clientKey.empty() && !m_clientCerts.empty() )
00096 {
00097 SSL_CTX_use_certificate_chain_file( sslCTX, m_clientCerts.c_str() );
00098 SSL_CTX_use_PrivateKey_file( sslCTX, m_clientKey.c_str(), SSL_FILETYPE_PEM );
00099 }
00100
00101 m_ssl = SSL_new( sslCTX );
00102 SSL_set_connect_state( m_ssl );
00103
00104 BIO *socketBio = BIO_new_socket( m_socket, BIO_NOCLOSE );
00105 if( !socketBio )
00106 return false;
00107
00108 SSL_set_bio( m_ssl, socketBio, socketBio );
00109 SSL_set_mode( m_ssl, SSL_MODE_AUTO_RETRY );
00110
00111 if( !SSL_connect( m_ssl ) )
00112 return false;
00113
00114 m_secure = true;
00115
00116 int res = SSL_get_verify_result( m_ssl );
00117 if( res != X509_V_OK )
00118 m_certInfo.status = CertInvalid;
00119 else
00120 m_certInfo.status = CertOk;
00121
00122 X509 *peer;
00123 peer = SSL_get_peer_certificate( m_ssl );
00124 if( peer )
00125 {
00126 char peer_CN[256];
00127 X509_NAME_get_text_by_NID( X509_get_issuer_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00128 m_certInfo.issuer = peer_CN;
00129 X509_NAME_get_text_by_NID( X509_get_subject_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00130 m_certInfo.server = peer_CN;
00131 std::string p;
00132 p.assign( peer_CN );
00133 int (*pf)( int ) = tolower;
00134 transform( p.begin(), p.end(), p.begin(), pf );
00135 if( p != m_server )
00136 m_certInfo.status |= CertWrongPeer;
00137 }
00138 else
00139 {
00140 m_certInfo.status = CertInvalid;
00141 }
00142
00143 const char *tmp;
00144 tmp = SSL_get_cipher_name( m_ssl );
00145 if( tmp )
00146 m_certInfo.cipher = tmp;
00147
00148 tmp = SSL_get_cipher_version( m_ssl );
00149 if( tmp )
00150 m_certInfo.protocol = tmp;
00151
00152 return true;
00153 }
00154
00155 inline bool Connection::tls_send( const void *data, size_t len )
00156 {
00157 int ret;
00158 ret = SSL_write( m_ssl, data, len );
00159 return true;
00160 }
00161
00162 inline int Connection::tls_recv( void *data, size_t len )
00163 {
00164 return SSL_read( m_ssl, data, len );
00165 }
00166
00167 inline bool Connection::tls_dataAvailable()
00168 {
00169 return false;
00170 }
00171
00172 inline void Connection::tls_cleanup()
00173 {
00174 SSL_shutdown( m_ssl );
00175 SSL_free( m_ssl );
00176 }
00177
00178 #elif defined( USE_GNUTLS )
00179 bool Connection::tlsHandshake()
00180 {
00181 const int protocolPriority[] = { GNUTLS_TLS1, GNUTLS_SSL3, 0 };
00182 const int kxPriority[] = { GNUTLS_KX_RSA, 0 };
00183 const int cipherPriority[] = { GNUTLS_CIPHER_AES_256_CBC, GNUTLS_CIPHER_AES_128_CBC,
00184 GNUTLS_CIPHER_3DES_CBC, GNUTLS_CIPHER_ARCFOUR, 0 };
00185 const int compPriority[] = { GNUTLS_COMP_ZLIB, GNUTLS_COMP_NULL, 0 };
00186 const int macPriority[] = { GNUTLS_MAC_SHA, GNUTLS_MAC_MD5, 0 };
00187
00188 if( gnutls_global_init() != 0 )
00189 return false;
00190
00191 if( gnutls_certificate_allocate_credentials( &m_credentials ) < 0 )
00192 return false;
00193
00194 StringList::const_iterator it = m_cacerts.begin();
00195 for( ; it != m_cacerts.end(); ++it )
00196 gnutls_certificate_set_x509_trust_file( m_credentials, (*it).c_str(), GNUTLS_X509_FMT_PEM );
00197
00198 if( !m_clientKey.empty() && !m_clientCerts.empty() )
00199 {
00200 gnutls_certificate_set_x509_key_file( m_credentials, m_clientKey.c_str(),
00201 m_clientCerts.c_str(), GNUTLS_X509_FMT_PEM );
00202 }
00203
00204 if( gnutls_init( &m_session, GNUTLS_CLIENT ) != 0 )
00205 {
00206 gnutls_certificate_free_credentials( m_credentials );
00207 return false;
00208 }
00209
00210 gnutls_protocol_set_priority( m_session, protocolPriority );
00211 gnutls_cipher_set_priority( m_session, cipherPriority );
00212 gnutls_compression_set_priority( m_session, compPriority );
00213 gnutls_kx_set_priority( m_session, kxPriority );
00214 gnutls_mac_set_priority( m_session, macPriority );
00215 gnutls_credentials_set( m_session, GNUTLS_CRD_CERTIFICATE, m_credentials );
00216
00217 gnutls_transport_set_ptr( m_session, (gnutls_transport_ptr_t)m_socket );
00218 if( gnutls_handshake( m_session ) != 0 )
00219 {
00220 gnutls_deinit( m_session );
00221 gnutls_certificate_free_credentials( m_credentials );
00222 return false;
00223 }
00224 gnutls_certificate_free_ca_names( m_credentials );
00225
00226 m_secure = true;
00227
00228 unsigned int status;
00229 bool error = false;
00230
00231 if( gnutls_certificate_verify_peers2( m_session, &status ) < 0 )
00232 error = true;
00233
00234 m_certInfo.status = 0;
00235 if( status & GNUTLS_CERT_INVALID )
00236 m_certInfo.status |= CertInvalid;
00237 if( status & GNUTLS_CERT_SIGNER_NOT_FOUND )
00238 m_certInfo.status |= CertSignerUnknown;
00239 if( status & GNUTLS_CERT_REVOKED )
00240 m_certInfo.status |= CertRevoked;
00241 if( status & GNUTLS_CERT_SIGNER_NOT_CA )
00242 m_certInfo.status |= CertSignerNotCa;
00243 const gnutls_datum_t* certList = 0;
00244 unsigned int certListSize;
00245 if( !error && ( ( certList = gnutls_certificate_get_peers( m_session, &certListSize ) ) == 0 ) )
00246 error = true;
00247
00248 gnutls_x509_crt_t *cert = new gnutls_x509_crt_t[certListSize+1];
00249 for( unsigned int i=0; !error && ( i<certListSize ); ++i )
00250 {
00251 if( !error && ( gnutls_x509_crt_init( &cert[i] ) < 0 ) )
00252 error = true;
00253 if( !error && ( gnutls_x509_crt_import( cert[i], &certList[i], GNUTLS_X509_FMT_DER ) < 0 ) )
00254 error = true;
00255 }
00256
00257 if( ( gnutls_x509_crt_check_issuer( cert[certListSize-1], cert[certListSize-1] ) > 0 )
00258 && certListSize > 0 )
00259 certListSize--;
00260
00261 bool chain = true;
00262 for( unsigned int i=1; !error && ( i<certListSize ); ++i )
00263 {
00264 chain = error = !verifyAgainst( cert[i-1], cert[i] );
00265 }
00266 if( !chain )
00267 m_certInfo.status |= CertInvalid;
00268 m_certInfo.chain = chain;
00269
00270 m_certInfo.chain = verifyAgainstCAs( cert[certListSize], 0 , 0 );
00271
00272 int t = (int)gnutls_x509_crt_get_expiration_time( cert[0] );
00273 if( t == -1 )
00274 error = true;
00275 else if( t < time( 0 ) )
00276 m_certInfo.status |= CertExpired;
00277 m_certInfo.date_to = t;
00278
00279 t = (int)gnutls_x509_crt_get_activation_time( cert[0] );
00280 if( t == -1 )
00281 error = true;
00282 else if( t > time( 0 ) )
00283 m_certInfo.status |= CertNotActive;
00284 m_certInfo.date_from = t;
00285
00286 char name[64];
00287 size_t nameSize = sizeof( name );
00288 gnutls_x509_crt_get_issuer_dn( cert[0], name, &nameSize );
00289 m_certInfo.issuer = name;
00290
00291 nameSize = sizeof( name );
00292 gnutls_x509_crt_get_dn( cert[0], name, &nameSize );
00293 m_certInfo.server = name;
00294
00295 const char* info;
00296 info = gnutls_compression_get_name( gnutls_compression_get( m_session ) );
00297 if( info )
00298 m_certInfo.compression = info;
00299
00300 info = gnutls_mac_get_name( gnutls_mac_get( m_session ) );
00301 if( info )
00302 m_certInfo.mac = info;
00303
00304 info = gnutls_cipher_get_name( gnutls_cipher_get( m_session ) );
00305 if( info )
00306 m_certInfo.cipher = info;
00307
00308 info = gnutls_protocol_get_name( gnutls_protocol_get_version( m_session ) );
00309 if( info )
00310 m_certInfo.protocol = info;
00311
00312 if( !gnutls_x509_crt_check_hostname( cert[0], m_server.c_str() ) )
00313 m_certInfo.status |= CertWrongPeer;
00314
00315 for( unsigned int i=0; i<certListSize; ++i )
00316 gnutls_x509_crt_deinit( cert[i] );
00317
00318 delete[] cert;
00319
00320 return true;
00321 }
00322
00323 bool Connection::verifyAgainst( gnutls_x509_crt_t cert, gnutls_x509_crt_t issuer )
00324 {
00325 unsigned int result;
00326 gnutls_x509_crt_verify( cert, &issuer, 1, 0, &result );
00327 if( result & GNUTLS_CERT_INVALID )
00328 return false;
00329
00330 if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00331 return false;
00332
00333 if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00334 return false;
00335
00336 return true;
00337 }
00338
00339 bool Connection::verifyAgainstCAs( gnutls_x509_crt_t cert, gnutls_x509_crt_t *CAList, int CAListSize )
00340 {
00341 unsigned int result;
00342 gnutls_x509_crt_verify( cert, CAList, CAListSize, GNUTLS_VERIFY_ALLOW_X509_V1_CA_CRT, &result );
00343 if( result & GNUTLS_CERT_INVALID )
00344 return false;
00345
00346 if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00347 return false;
00348
00349 if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00350 return false;
00351
00352 return true;
00353 }
00354
00355 inline bool Connection::tls_send( const void *data, size_t len )
00356 {
00357 int ret;
00358 do
00359 {
00360 ret = gnutls_record_send( m_session, data, len );
00361 }
00362 while( ( ret == GNUTLS_E_AGAIN ) || ( ret == GNUTLS_E_INTERRUPTED ) );
00363 return true;
00364 }
00365
00366 inline int Connection::tls_recv( void *data, size_t len )
00367 {
00368 return gnutls_record_recv( m_session, data, len );
00369 }
00370
00371 inline bool Connection::tls_dataAvailable()
00372 {
00373 return false;
00374 }
00375
00376 inline void Connection::tls_cleanup()
00377 {
00378 gnutls_bye( m_session, GNUTLS_SHUT_RDWR );
00379 gnutls_deinit( m_session );
00380 gnutls_certificate_free_credentials( m_credentials );
00381 gnutls_global_deinit();
00382 }
00383
00384 #elif defined( USE_WINTLS )
/*
 * Performs a client-side TLS handshake on m_socket using the Windows
 * SSPI/Schannel provider, loaded dynamically from secur32.dll.
 *
 * On success:
 *  - m_secure is set;
 *  - m_iBuffer/m_oBuffer are allocated with header + maximum message +
 *    trailer bytes each;
 *  - m_certInfo receives the issuer, protocol, cipher and MAC reported by
 *    Schannel.
 *
 * @return true if the handshake and all attribute queries succeeded.
 *
 * NOTE(review): several failure paths return without releasing what was
 * already acquired (m_lib and m_credentials after AcquireCredentialsHandleA,
 * m_context after the handshake, m_iBuffer when allocating m_oBuffer fails).
 * m_secure is still false on those paths, so tls_cleanup() will not release
 * them either.
 * NOTE(review): m_certInfo.status is never set here, and with
 * SCH_CRED_NO_SERVERNAME_CHECK | SCH_CRED_MANUAL_CRED_VALIDATION Schannel
 * appears to leave server certificate validation to the caller -- nothing in
 * this function performs it. Confirm before relying on this path.
 */
bool Connection::tlsHandshake()
{
  INIT_SECURITY_INTERFACE pInitSecurityInterface;

  // Load SSPI at runtime and obtain its (ANSI) function dispatch table.
  m_lib = LoadLibrary( "secur32.dll" );
  if( m_lib == NULL )
    return false;

  pInitSecurityInterface = (INIT_SECURITY_INTERFACE)GetProcAddress( m_lib, "InitSecurityInterfaceA" );
  if( pInitSecurityInterface == NULL )
  {
    FreeLibrary( m_lib );
    m_lib = 0;
    return false;
  }

  m_securityFunc = pInitSecurityInterface();
  if( !m_securityFunc )
  {
    FreeLibrary( m_lib );
    m_lib = 0;
    return false;
  }

  SCHANNEL_CRED schannelCred;
  memset( &schannelCred, 0, sizeof( schannelCred ) );
  memset( &m_credentials, 0, sizeof( m_credentials ) );
  memset( &m_context, 0, sizeof( m_context ) );

  // TLS 1.0 client only; zero cipher strengths / algorithm count leave the
  // choice to Schannel's defaults.
  schannelCred.dwVersion = SCHANNEL_CRED_VERSION;
  schannelCred.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT;
  schannelCred.cSupportedAlgs = 0;
  // The field names differ between the MSVC and the MinGW headers.
#ifdef MSVC
  schannelCred.dwMinimumCipherStrength = 0;
  schannelCred.dwMaximumCipherStrength = 0;
#else
  schannelCred.dwMinimumCypherStrength = 0;
  schannelCred.dwMaximumCypherStrength = 0;
#endif
  schannelCred.dwSessionLifespan = 0;
  schannelCred.dwFlags = SCH_CRED_NO_SERVERNAME_CHECK | SCH_CRED_NO_DEFAULT_CREDS |
                         SCH_CRED_MANUAL_CRED_VALIDATION;

  TimeStamp timeStamp;
  SECURITY_STATUS ret;
  ret = m_securityFunc->AcquireCredentialsHandleA( NULL, UNISP_NAME_A, SECPKG_CRED_OUTBOUND,
                                                   NULL, &schannelCred, NULL,
                                                   NULL, &m_credentials, &timeStamp );
  if( ret != SEC_E_OK )
  {
    // NOTE(review): m_lib stays loaded on this path.
    printf( "AcquireCredentialsHandleA failed\n" );
    return false;
  }

  // Context requirements, reused by handshakeLoop() for every round trip.
  m_sspiFlags = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_EXTENDED_ERROR
      | ISC_REQ_MUTUAL_AUTH | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT
      | ISC_REQ_STREAM;

  SecBufferDesc outBufferDesc;
  SecBuffer outBuffers[1];

  outBuffers[0].BufferType = SECBUFFER_TOKEN;
  outBuffers[0].pvBuffer = NULL;
  outBuffers[0].cbBuffer = 0;

  outBufferDesc.ulVersion = SECBUFFER_VERSION;
  outBufferDesc.cBuffers = 1;
  outBufferDesc.pBuffers = outBuffers;

  // First call (no input context) produces the ClientHello token.
  long unsigned int sspiFlagsOut;
  ret = m_securityFunc->InitializeSecurityContextA( &m_credentials, NULL, NULL, m_sspiFlags, 0,
                                                    SECURITY_NATIVE_DREP, NULL, 0, &m_context,
                                                    &outBufferDesc, &sspiFlagsOut, &timeStamp );
  // NOTE(review): any result other than SEC_I_CONTINUE_NEEDED (including
  // an error) falls through to handshakeLoop() unchecked.
  if( ret == SEC_I_CONTINUE_NEEDED && outBuffers[0].cbBuffer != 0 && outBuffers[0].pvBuffer != NULL )
  {
    printf( "OK: Continue needed: " );

    // NOTE(review): this int shadows the SECURITY_STATUS ret above.
    int ret = ::send( m_socket, (const char*)outBuffers[0].pvBuffer, outBuffers[0].cbBuffer, 0 );
    if( ret == SOCKET_ERROR || ret == 0 )
    {
      m_securityFunc->FreeContextBuffer( outBuffers[0].pvBuffer );
      m_securityFunc->DeleteSecurityContext( &m_context );
      return false;
    }

    // The token was allocated by SSPI (ISC_REQ_ALLOCATE_MEMORY).
    m_securityFunc->FreeContextBuffer( outBuffers[0].pvBuffer );
    outBuffers[0].pvBuffer = NULL;
  }

  // Exchange the remaining handshake messages with the server.
  if( !handshakeLoop() )
  {
    printf( "handshakeLoop failed\n" );
    return false;
  }

  // Record framing sizes used by tls_send()/tls_recv().
  ret = m_securityFunc->QueryContextAttributes( &m_context, SECPKG_ATTR_STREAM_SIZES, &m_streamSizes );
  if( ret != SEC_E_OK )
  {
    printf( "could not read stream attribs (sizes)\n" );
    return false;
  }
  printf( "maximumMessage: %ld\n", m_streamSizes.cbMaximumMessage );
  int maxSize = m_streamSizes.cbHeader + m_streamSizes.cbMaximumMessage + m_streamSizes.cbTrailer;
  m_iBuffer = (char*)malloc( maxSize );
  if( !m_iBuffer )
    return false;

  m_oBuffer = (char*)malloc( maxSize );
  if( !m_oBuffer )
    return false; // NOTE(review): m_iBuffer leaks here

  // Plaintext to be encrypted is staged right after the header space.
  m_bufferOffset = 0;
  m_messageOffset = m_oBuffer + m_streamSizes.cbHeader;

  SecPkgContext_Authority streamAuthority;
  ret = m_securityFunc->QueryContextAttributes( &m_context, SECPKG_ATTR_AUTHORITY, &streamAuthority );
  if( ret != SEC_E_OK )
  {
    // NOTE(review): message copy-pasted from the stream-sizes query above.
    printf( "could not read stream attribs (sizes)\n" );
    return false;
  }
  else
  {
    // NOTE(review): per the SSPI documentation, sAuthorityName is allocated
    // by the package and should be released with FreeContextBuffer(); it is
    // not released here.
    m_certInfo.issuer.assign( streamAuthority.sAuthorityName );
  }

  SecPkgContext_ConnectionInfo streamInfo;
  ret = m_securityFunc->QueryContextAttributes( &m_context, SECPKG_ATTR_CONNECTION_INFO, &streamInfo );
  if( ret != SEC_E_OK )
  {
    // NOTE(review): message copy-pasted from the stream-sizes query above.
    printf( "could not read stream attribs (sizes)\n" );
    return false;
  }
  else
  {
    if( streamInfo.dwProtocol == SP_PROT_TLS1_CLIENT )
      m_certInfo.protocol = "TLS 1.0";
    else
      m_certInfo.protocol = "unknown";

    // Cipher is reported as "<name> <strength in bits>".
    std::ostringstream oss;
    switch( streamInfo.aiCipher )
    {
      case CALG_3DES:
        oss << "3DES";
        break;
      case CALG_AES_128:
        oss << "AES";
        break;
      case CALG_AES_256:
        oss << "AES";
        break;
      case CALG_DES:
        oss << "DES";
        break;
      case CALG_RC2:
        oss << "RC2";
        break;
      case CALG_RC4:
        oss << "RC4";
        break;
      default:
        oss << "unknown";
    }

    oss << " " << streamInfo.dwCipherStrength;
    m_certInfo.cipher = oss.str();
    oss.str( "" );

    // MAC is reported as "<hash> <strength in bits>".
    switch( streamInfo.aiHash )
    {
      case CALG_MD5:
        oss << "MD5";
        break;
      case CALG_SHA:
        oss << "SHA";
        break;
      default:
        oss << "unknown";
    }

    oss << " " << streamInfo.dwHashStrength;
    m_certInfo.mac = oss.str();

    m_certInfo.compression = "unknown";
  }

  m_secure = true;

  return true;
}
00576
00577 bool Connection::handshakeLoop()
00578 {
00579 const int bufsize = 65536;
00580 char *buf = (char*)malloc( bufsize );
00581 if( !buf )
00582 return false;
00583
00584 int bufFilled = 0;
00585 int dataRecv = 0;
00586 bool doRead = true;
00587
00588 SecBufferDesc outBufferDesc, inBufferDesc;
00589 SecBuffer outBuffers[1], inBuffers[2];
00590
00591 SECURITY_STATUS ret = SEC_I_CONTINUE_NEEDED;
00592
00593 while( ret == SEC_I_CONTINUE_NEEDED ||
00594 ret == SEC_E_INCOMPLETE_MESSAGE ||
00595 ret == SEC_I_INCOMPLETE_CREDENTIALS )
00596 {
00597
00598 if( doRead )
00599 {
00600 dataRecv = ::recv( m_socket, buf + bufFilled, bufsize - bufFilled, 0 );
00601
00602 if( dataRecv == SOCKET_ERROR || dataRecv == 0 )
00603 {
00604 break;
00605 }
00606
00607 printf( "%d bytes handshake data received\n", dataRecv );
00608
00609 bufFilled += dataRecv;
00610 }
00611 else
00612 {
00613 doRead = true;
00614 }
00615
00616 outBuffers[0].BufferType = SECBUFFER_TOKEN;
00617 outBuffers[0].pvBuffer = NULL;
00618 outBuffers[0].cbBuffer = 0;
00619
00620 outBufferDesc.ulVersion = SECBUFFER_VERSION;
00621 outBufferDesc.cBuffers = 1;
00622 outBufferDesc.pBuffers = outBuffers;
00623
00624 inBuffers[0].BufferType = SECBUFFER_TOKEN;
00625 inBuffers[0].pvBuffer = buf;
00626 inBuffers[0].cbBuffer = bufFilled;
00627
00628 inBuffers[1].BufferType = SECBUFFER_EMPTY;
00629 inBuffers[1].pvBuffer = NULL;
00630 inBuffers[1].cbBuffer = 0;
00631
00632 inBufferDesc.ulVersion = SECBUFFER_VERSION;
00633 inBufferDesc.cBuffers = 2;
00634 inBufferDesc.pBuffers = inBuffers;
00635
00636 printf( "buffers inited, calling InitializeSecurityContextA\n" );
00637 long unsigned int sspiFlagsOut;
00638 TimeStamp timeStamp;
00639 ret = m_securityFunc->InitializeSecurityContextA( &m_credentials, &m_context, NULL,
00640 m_sspiFlags, 0,
00641 SECURITY_NATIVE_DREP, &inBufferDesc, 0, NULL,
00642 &outBufferDesc, &sspiFlagsOut, &timeStamp );
00643 if( ret == SEC_E_OK || ret == SEC_I_CONTINUE_NEEDED ||
00644 ( FAILED( ret ) && sspiFlagsOut & ISC_RET_EXTENDED_ERROR ) )
00645 {
00646 if( outBuffers[0].cbBuffer != 0 && outBuffers[0].pvBuffer != NULL )
00647 {
00648 printf( "ISCA returned, buffers not empty\n" );
00649 dataRecv = ::send( m_socket, (const char*)outBuffers[0].pvBuffer, outBuffers[0].cbBuffer, 0 );
00650 if( dataRecv == SOCKET_ERROR || dataRecv == 0 )
00651 {
00652 m_securityFunc->FreeContextBuffer( &outBuffers[0].pvBuffer );
00653 m_securityFunc->DeleteSecurityContext( &m_context );
00654 free( buf );
00655 printf( "coudl not send bufer to server, exiting\n" );
00656 return false;
00657 }
00658
00659 m_securityFunc->FreeContextBuffer( outBuffers[0].pvBuffer );
00660 outBuffers[0].pvBuffer = NULL;
00661 }
00662 }
00663
00664 if( ret == SEC_E_INCOMPLETE_MESSAGE )
00665 continue;
00666
00667 if( ret == SEC_E_OK )
00668 {
00669 printf( "handshake successful\n" );
00670 break;
00671 }
00672
00673 if( FAILED( ret ) )
00674 {
00675 printf( "ISC failed: %ld\n", ret );
00676 break;
00677 }
00678
00679 if( ret == SEC_I_INCOMPLETE_CREDENTIALS )
00680 {
00681 printf( "server requested client credentials\n" );
00682 ret = SEC_I_CONTINUE_NEEDED;
00683 continue;
00684 }
00685
00686 if( inBuffers[1].BufferType == SECBUFFER_EXTRA )
00687 {
00688 printf("some xtra mem in inbuf\n" );
00689 MoveMemory( buf, buf + ( bufFilled - inBuffers[1].cbBuffer ),
00690 inBuffers[1].cbBuffer );
00691
00692 bufFilled = inBuffers[1].cbBuffer;
00693 }
00694 else
00695 {
00696 bufFilled = 0;
00697 }
00698 }
00699
00700 if( FAILED( ret ) )
00701 m_securityFunc->DeleteSecurityContext( &m_context );
00702
00703 free( buf );
00704
00705 if( ret == SEC_E_OK )
00706 return true;
00707
00708 return false;
00709 }
00710
00711 inline bool Connection::tls_send( const void *data, size_t len )
00712 {
00713 if( len <= 0 )
00714 return true;
00715
00716 SECURITY_STATUS ret;
00717
00718 m_obuffers[0].BufferType = SECBUFFER_STREAM_HEADER;
00719 m_obuffers[0].pvBuffer = m_oBuffer;
00720 m_obuffers[0].cbBuffer = m_streamSizes.cbHeader;
00721
00722 m_obuffers[1].BufferType = SECBUFFER_DATA;
00723 m_obuffers[1].pvBuffer = m_messageOffset;
00724
00725 m_obuffers[2].BufferType = SECBUFFER_STREAM_TRAILER;
00726 m_obuffers[2].cbBuffer = m_streamSizes.cbTrailer;
00727
00728 m_obuffers[3].BufferType = SECBUFFER_EMPTY;
00729 m_obuffers[3].pvBuffer = NULL;
00730 m_obuffers[3].cbBuffer = 0;
00731
00732 m_omessage.ulVersion = SECBUFFER_VERSION;
00733 m_omessage.cBuffers = 4;
00734 m_omessage.pBuffers = m_obuffers;
00735
00736 while( len > 0 )
00737 {
00738 if( m_streamSizes.cbMaximumMessage < len )
00739 {
00740 memcpy( m_messageOffset, data, m_streamSizes.cbMaximumMessage );
00741 len -= m_streamSizes.cbMaximumMessage;
00742 m_obuffers[1].cbBuffer = m_streamSizes.cbMaximumMessage;
00743 m_obuffers[2].pvBuffer = m_messageOffset + m_streamSizes.cbMaximumMessage;
00744 }
00745 else
00746 {
00747 memcpy( m_messageOffset, data, len );
00748 m_obuffers[1].cbBuffer = len;
00749 m_obuffers[2].pvBuffer = m_messageOffset + len;
00750 len = 0;
00751 }
00752
00753 ret = m_securityFunc->EncryptMessage( &m_context, 0, &m_omessage, 0 );
00754 if( ret != SEC_E_OK )
00755 {
00756 printf( "encryptmessage failed %ld\n", ret );
00757 return false;
00758 }
00759
00760 int t = ::send( m_socket, m_oBuffer,
00761 m_obuffers[0].cbBuffer + m_obuffers[1].cbBuffer + m_obuffers[2].cbBuffer, 0 );
00762 if( t == SOCKET_ERROR || t == 0 )
00763 {
00764 printf( "could not send: %d\n", WSAGetLastError() );
00765 return false;
00766 }
00767 }
00768
00769 return true;
00770 }
00771
/*
 * Reads ciphertext from m_socket into m_iBuffer (appending after any
 * m_bufferOffset bytes already held), decrypts it in place with Schannel,
 * and copies the plaintext into @p data.
 *
 * @return number of plaintext bytes copied into @p data. Socket errors,
 *         connection close, decryption errors and incomplete records all
 *         yield 0.
 *
 * NOTE(review): known problems, left unchanged here:
 *  - a socket error returns 0, which the caller cannot tell apart from an
 *    orderly close;
 *  - SEC_E_INCOMPLETE_MESSAGE on the first pass returns 0 (the bytes stay
 *    buffered in m_iBuffer), and a 0 return also looks like a close;
 *  - `return false` on decryption failure also yields 0;
 *  - SECBUFFER_EXTRA (the start of a further record) is only logged.
 *    m_bufferOffset is reset to 0, so those bytes are dropped and the loop
 *    runs once;
 *  - plaintext beyond @p len is discarded;
 *  - `ret = handshakeLoop()` stores a bool into a SECURITY_STATUS.
 */
inline int Connection::tls_recv( void *data, size_t len )
{
  SECURITY_STATUS ret;
  SecBuffer *dataBuffer = 0;
  int readable = 0;

  int maxLength = m_streamSizes.cbHeader + m_streamSizes.cbMaximumMessage + m_streamSizes.cbTrailer;

  printf( "bufferOffset is %d\n", m_bufferOffset );

  // Blocking read appended after any bytes still buffered from the last call.
  int t = ::recv( m_socket, m_iBuffer + m_bufferOffset, maxLength - m_bufferOffset, 0 );
  if( t == SOCKET_ERROR )
  {
    printf( "got SocketError\n" );
    return 0;
  }
  else if( t == 0 )
  {
    printf( "got connection close\n" );
    return 0;
  }
  else
    m_bufferOffset += t;

  while( m_bufferOffset )
  {
    printf( "continuing with bufferOffset: %d\n", m_bufferOffset );

    // Buffer 0 holds the ciphertext; DecryptMessage() fills the others
    // (header, data, trailer, extra) with pointers into m_iBuffer.
    m_ibuffers[0].pvBuffer = m_iBuffer;
    m_ibuffers[0].cbBuffer = m_bufferOffset;
    m_ibuffers[0].BufferType = SECBUFFER_DATA;

    m_ibuffers[1].BufferType = SECBUFFER_EMPTY;
    m_ibuffers[2].BufferType = SECBUFFER_EMPTY;
    m_ibuffers[3].BufferType = SECBUFFER_EMPTY;

    m_imessage.ulVersion = SECBUFFER_VERSION;
    m_imessage.cBuffers = 4;
    m_imessage.pBuffers = m_ibuffers;

    ret = m_securityFunc->DecryptMessage( &m_context, &m_imessage, 0, NULL );

    if( ret == SEC_E_INCOMPLETE_MESSAGE )
    {
      // Partial record: keep it buffered (m_bufferOffset unchanged).
      printf( "recv'ed incomplete message\n" );
      return readable;
    }

    if( ret != SEC_E_OK && ret != SEC_I_RENEGOTIATE )
    {
      printf( "DecryptMessage returned %ld\n", ret );
      printf( "GetLastError(): %ld\n", GetLastError() );
      printf( "input buffer length: %d, read in this run: %d\n", m_bufferOffset, t );
      return false;
    }

    m_bufferOffset = 0;

    // Locate the decrypted payload and any unprocessed trailing ciphertext.
    for( int i = 1; i < 4; ++i )
    {
      if( dataBuffer == 0 && m_ibuffers[i].BufferType == SECBUFFER_DATA )
      {
        dataBuffer = &m_ibuffers[i];
      }
      if( m_bufferOffset == 0 && m_ibuffers[i].BufferType == SECBUFFER_EXTRA )
      {

        printf( "git exetra buffer, size %ld\n", m_ibuffers[i].cbBuffer );


      }
    }

    if( dataBuffer )
    {
      if( dataBuffer->cbBuffer > len )
      {
        // Truncates: the rest of the plaintext is lost.
        memcpy( data, dataBuffer->pvBuffer, len );
        return len;
      }
      else
      {
        memcpy( data, dataBuffer->pvBuffer, dataBuffer->cbBuffer );
        readable += dataBuffer->cbBuffer;
        // NOTE(review): %s assumes data is NUL-terminated after the copied
        // bytes -- presumably true for the caller's zeroed buffer; confirm.
        printf( "recvbuffer (%d): %s\n", readable, data );
      }
    }

    if( ret == SEC_I_RENEGOTIATE )
    {
      printf( "server requested reneg\n" );
      ret = handshakeLoop();
    }
  }

  return readable;
}
00873
00874 inline bool Connection::tls_dataAvailable()
00875 {
00876 return false;
00877 }
00878
00879 inline void Connection::tls_cleanup()
00880 {
00881 m_securityFunc->DeleteSecurityContext( &m_context );
00882 }
00883 #endif
00884
00885 #ifdef HAVE_ZLIB
00886 bool Connection::initCompression( StreamFeature method )
00887 {
00888 delete m_compression;
00889 m_compression = 0;
00890 m_compression = new Compression( method );
00891 return true;
00892 }
00893
00894 void Connection::enableCompression()
00895 {
00896 if( !m_compression )
00897 return;
00898
00899 m_enableCompression = true;
00900 }
00901 #endif
00902
00903 ConnectionState Connection::connect()
00904 {
00905 if( m_socket != -1 && m_state >= StateConnecting )
00906 {
00907 return m_state;
00908 }
00909
00910 m_state = StateConnecting;
00911
00912 if( m_port == ( unsigned short ) -1 )
00913 m_socket = DNS::connect( m_server, m_logInstance );
00914 else
00915 m_socket = DNS::connect( m_server, m_port, m_logInstance );
00916
00917 if( m_socket < 0 )
00918 {
00919 switch( m_socket )
00920 {
00921 case -DNS::DNS_COULD_NOT_CONNECT:
00922 m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: could not connect" );
00923 break;
00924 case -DNS::DNS_NO_HOSTS_FOUND:
00925 m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: no hosts found" );
00926 break;
00927 case -DNS::DNS_COULD_NOT_RESOLVE:
00928 m_logInstance.log( LogLevelError, LogAreaClassConnection, "connection error: could not resolve" );
00929 break;
00930 }
00931 cleanup();
00932 }
00933 else
00934 m_state = StateConnected;
00935
00936 m_cancel = false;
00937 return m_state;
00938 }
00939
00940 void Connection::disconnect( ConnectionError e )
00941 {
00942 m_disconnect = e;
00943 m_cancel = true;
00944
00945 if( m_fdRequested )
00946 cleanup();
00947 }
00948
00949 int Connection::fileDescriptor()
00950 {
00951 m_fdRequested = true;
00952 return m_socket;
00953 }
00954
00955 bool Connection::dataAvailable( int timeout )
00956 {
00957 if( m_socket < 0 )
00958 return true;
00959
00960 #ifdef HAVE_TLS
00961 if( tls_dataAvailable() )
00962 {
00963 return true;
00964 }
00965 #endif
00966
00967 fd_set fds;
00968 struct timeval tv;
00969
00970 FD_ZERO( &fds );
00971 FD_SET( m_socket, &fds );
00972
00973 tv.tv_sec = timeout / 1000000;
00974 tv.tv_usec = timeout % 1000000;
00975
00976 if( select( m_socket + 1, &fds, 0, 0, timeout == -1 ? 0 : &tv ) > 0 )
00977 {
00978 return FD_ISSET( m_socket, &fds ) ? true : false;
00979 }
00980 return false;
00981 }
00982
00983 ConnectionError Connection::recv( int timeout )
00984 {
00985 if( m_cancel )
00986 {
00987 ConnectionError e = m_disconnect;
00988 cleanup();
00989 return e;
00990 }
00991
00992 if( m_socket == -1 )
00993 return ConnNotConnected;
00994
00995 if( !m_fdRequested && !dataAvailable( timeout ) )
00996 {
00997 return ConnNoError;
00998 }
00999
01000
01001 memset( m_buf, '\0', m_bufsize + 1 );
01002 int size = 0;
01003 #ifdef HAVE_TLS
01004 if( m_secure )
01005 {
01006 size = tls_recv( m_buf, m_bufsize );
01007 }
01008 else
01009 #endif
01010 {
01011 #ifdef SKYOS
01012 size = ::recv( m_socket, (unsigned char*)m_buf, m_bufsize, 0 );
01013 #else
01014 size = ::recv( m_socket, m_buf, m_bufsize, 0 );
01015 #endif
01016 }
01017
01018 if( size < 0 )
01019 {
01020
01021 m_secure = false;
01022 cleanup();
01023 return ConnIoError;
01024 }
01025 else if( size == 0 )
01026 {
01027
01028 m_secure = false;
01029 cleanup();
01030 return ConnUserDisconnected;
01031 }
01032
01033 std::string buf;
01034 buf.assign( m_buf, size );
01035 if( m_compression && m_enableCompression )
01036 buf = m_compression->decompress( buf );
01037
01038 Parser::ParserState ret = m_parser->feed( buf );
01039 if( ret != Parser::PARSER_OK )
01040 {
01041 cleanup();
01042 switch( ret )
01043 {
01044 case Parser::PARSER_BADXML:
01045 m_logInstance.log( LogLevelError, LogAreaClassConnection, "XML parse error" );
01046 break;
01047 case Parser::PARSER_NOMEM:
01048 m_logInstance.log( LogLevelError, LogAreaClassConnection, "memory allocation error" );
01049 break;
01050 default:
01051 m_logInstance.log( LogLevelError, LogAreaClassConnection, "unexpected error" );
01052 break;
01053 }
01054
01055 return ConnIoError;
01056 }
01057
01058 return ConnNoError;
01059 }
01060
01061 ConnectionError Connection::receive()
01062 {
01063 if( m_socket == -1 || !m_parser )
01064 return ConnNotConnected;
01065
01066 while( !m_cancel )
01067 {
01068 ConnectionError r = recv( 1 );
01069 if( r != ConnNoError )
01070 {
01071 return r;
01072 }
01073 }
01074 ConnectionError e = m_disconnect;
01075 cleanup();
01076
01077 return e;
01078 }
01079
01080 bool Connection::send( const std::string& data )
01081 {
01082 if( data.empty() || ( m_socket == -1 ) )
01083 return false;
01084
01085 std::string xml;
01086 if( m_compression && m_enableCompression )
01087 xml = m_compression->compress( data );
01088 else
01089 xml = data;
01090
01091 #ifdef HAVE_TLS
01092 if( m_secure )
01093 {
01094 size_t len = xml.length();
01095 if( tls_send( xml.c_str (), len ) == false )
01096 return false;
01097 }
01098 else
01099 #endif
01100 {
01101 size_t num = 0;
01102 size_t len = xml.length();
01103 while( num < len )
01104 {
01105 #ifdef SKYOS
01106 int sent = ::send( m_socket, (unsigned char*)(xml.c_str()+num), len - num, 0 );
01107 #else
01108 int sent = ::send( m_socket, (xml.c_str()+num), len - num, 0 );
01109 #endif
01110 if( sent == -1 )
01111 return false;
01112
01113 num += sent;
01114 }
01115 }
01116
01117 return true;
01118 }
01119
01120 void Connection::cleanup()
01121 {
01122 #ifdef HAVE_TLS
01123 if( m_secure )
01124 {
01125 tls_cleanup();
01126 }
01127 #endif
01128
01129 if( m_socket != -1 )
01130 {
01131 #ifdef WIN32
01132 closesocket( m_socket );
01133 #else
01134 close( m_socket );
01135 #endif
01136 m_socket = -1;
01137 }
01138 m_state = StateDisconnected;
01139 m_disconnect = ConnNoError;
01140 m_enableCompression = false;
01141 m_secure = false;
01142 m_cancel = true;
01143 m_fdRequested = false;
01144 }
01145
01146 }