connection.cpp

00001 /*
00002   Copyright (c) 2004-2005 by Jakob Schroeter <js@camaya.net>
00003   This file is part of the gloox library. http://camaya.net/gloox
00004 
00005   This software is distributed under a license. The full license
00006   agreement can be found in the file LICENSE in this distribution.
00007   This software may not be copied, modified, sold or distributed
00008   other than expressed in the named license agreement.
00009 
00010   This software is distributed without any warranty.
00011 */
00012 
00013 
00014 
00015 #include "gloox.h"
00016 
00017 #include "connection.h"
00018 #include "dns.h"
00019 #include "prep.h"
00020 #include "parser.h"
00021 
00022 #ifdef __MINGW32__
00023 #include <winsock.h>
00024 #endif
00025 
00026 #ifndef WIN32
00027 #include <sys/types.h>
00028 #include <sys/socket.h>
00029 #include <sys/select.h>
00030 #include <unistd.h>
00031 #else
00032 #include <winsock.h>
00033 #define strcasecmp stricmp
00034 #endif
00035 
00036 #include <time.h>
00037 
00038 #include <string>
00039 
00040 namespace gloox
00041 {
00042 
00043   static const int BUFSIZE = 1024;
00044 
00045   Connection::Connection( Parser *parser, const std::string& server, int port )
00046     : m_parser( parser ), m_buf( 0 ), m_server( Prep::idna( server ) ), m_port( port ),
00047       m_socket( -1 ), m_compCount( 0 ), m_decompCount( 0 ), m_dataOutCount( 0 ),
00048       m_dataInCount( 0 ), m_cancel( true ), m_secure( false ), m_compression( false ),
00049       m_fdRequested( false ), m_compInited( false )
00050   {
00051     m_buf = (char*)calloc( BUFSIZE, sizeof( char ) );
00052   }
00053 
00054   Connection::~Connection()
00055   {
00056     cleanup();
00057     free( m_buf );
00058     m_buf = 0;
00059 #ifdef HAVE_ZLIB
00060     initCompression( false );
00061 #endif
00062   }
00063 
00064 #if defined( USE_OPENSSL )
00065   bool Connection::tlsHandshake()
00066   {
00067     SSL_library_init();
00068     SSL_CTX *sslCTX = SSL_CTX_new( TLSv1_client_method() );
00069     if( !sslCTX )
00070       return false;
00071 
00072     if( !SSL_CTX_set_cipher_list( sslCTX, "HIGH:MEDIUM:AES:@STRENGTH" ) )
00073       return false;
00074 
00075     StringList::const_iterator it = m_cacerts.begin();
00076     for( ; it != m_cacerts.end(); ++it )
00077       SSL_CTX_load_verify_locations( sslCTX, (*it).c_str(), NULL );
00078 
00079     m_ssl = SSL_new( sslCTX );
00080     SSL_set_connect_state( m_ssl );
00081 
00082     BIO *socketBio = BIO_new_socket( m_socket, BIO_NOCLOSE );
00083     if( !socketBio )
00084       return false;
00085 
00086     SSL_set_bio( m_ssl, socketBio, socketBio );
00087     SSL_set_mode( m_ssl, SSL_MODE_AUTO_RETRY );
00088 
00089     if( !SSL_connect( m_ssl ) )
00090       return false;
00091 
00092     m_secure = true;
00093 
00094     int res = SSL_get_verify_result( m_ssl );
00095     if( res != X509_V_OK )
00096       m_certInfo.status = CERT_INVALID;
00097     else
00098       m_certInfo.status = CERT_OK;
00099 
00100     X509 *peer;
00101     peer = SSL_get_peer_certificate( m_ssl );
00102     if( peer )
00103     {
00104       char peer_CN[256];
00105       X509_NAME_get_text_by_NID( X509_get_issuer_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00106       m_certInfo.issuer = peer_CN;
00107       X509_NAME_get_text_by_NID( X509_get_subject_name( peer ), NID_commonName, peer_CN, sizeof( peer_CN ) );
00108       m_certInfo.server = peer_CN;
00109       if( strcasecmp( peer_CN, m_server.c_str() ) )
00110         m_certInfo.status |= CERT_WRONG_PEER;
00111     }
00112     else
00113     {
00114       m_certInfo.status = CERT_INVALID;
00115     }
00116 
00117     const char *tmp;
00118     tmp = SSL_get_cipher_name( m_ssl );
00119     if( tmp )
00120       m_certInfo.cipher = tmp;
00121 
00122     tmp = SSL_get_cipher_version( m_ssl );
00123     if( tmp )
00124       m_certInfo.protocol = tmp;
00125 
00126     return true;
00127   }
00128 
00129 #elif defined( USE_GNUTLS )
00130   bool Connection::tlsHandshake()
00131   {
00132     const int protocolPriority[] = { GNUTLS_TLS1, GNUTLS_SSL3, 0 };
00133     const int kxPriority[]       = { GNUTLS_KX_RSA, 0 };
00134     const int cipherPriority[]   = { GNUTLS_CIPHER_AES_256_CBC, GNUTLS_CIPHER_AES_128_CBC,
00135                                              GNUTLS_CIPHER_3DES_CBC, GNUTLS_CIPHER_ARCFOUR, 0 };
00136     const int compPriority[]     = { GNUTLS_COMP_ZLIB, GNUTLS_COMP_NULL, 0 };
00137     const int macPriority[]      = { GNUTLS_MAC_SHA, GNUTLS_MAC_MD5, 0 };
00138 
00139     if( gnutls_global_init() != 0 )
00140       return false;
00141 
00142     if( gnutls_certificate_allocate_credentials( &m_credentials ) < 0 )
00143       return false;
00144 
00145     StringList::const_iterator it = m_cacerts.begin();
00146     for( ; it != m_cacerts.end(); ++it )
00147       gnutls_certificate_set_x509_trust_file( m_credentials, (*it).c_str(), GNUTLS_X509_FMT_PEM );
00148 
00149     if( gnutls_init( &m_session, GNUTLS_CLIENT ) != 0 )
00150     {
00151       gnutls_certificate_free_credentials( m_credentials );
00152       return false;
00153     }
00154 
00155     gnutls_protocol_set_priority( m_session, protocolPriority );
00156     gnutls_cipher_set_priority( m_session, cipherPriority );
00157     gnutls_compression_set_priority( m_session, compPriority );
00158     gnutls_kx_set_priority( m_session, kxPriority );
00159     gnutls_mac_set_priority( m_session, macPriority );
00160     gnutls_credentials_set( m_session, GNUTLS_CRD_CERTIFICATE, m_credentials );
00161 
00162     gnutls_transport_set_ptr( m_session, (gnutls_transport_ptr_t)m_socket );
00163     if( gnutls_handshake( m_session ) != 0 )
00164     {
00165       gnutls_deinit( m_session );
00166       gnutls_certificate_free_credentials( m_credentials );
00167       return false;
00168     }
00169     gnutls_certificate_free_ca_names( m_credentials );
00170 
00171     m_secure = true;
00172 
00173     unsigned int status;
00174     bool error = false;
00175 
00176     if( gnutls_certificate_verify_peers2( m_session, &status ) < 0 )
00177       error = true;
00178 
00179     m_certInfo.status = 0;
00180     if( status & GNUTLS_CERT_INVALID )
00181       m_certInfo.status |= CERT_INVALID;
00182     if( status & GNUTLS_CERT_SIGNER_NOT_FOUND )
00183       m_certInfo.status |= CERT_SIGNER_UNKNOWN;
00184     if( status & GNUTLS_CERT_REVOKED )
00185       m_certInfo.status |= CERT_REVOKED;
00186     if( status & GNUTLS_CERT_SIGNER_NOT_CA )
00187       m_certInfo.status |= CERT_SIGNER_NOT_CA;
00188     const gnutls_datum_t* certList = 0;
00189     unsigned int certListSize;
00190     if( !error && ( ( certList = gnutls_certificate_get_peers( m_session, &certListSize ) ) == 0 ) )
00191       error = true;
00192 
00193     gnutls_x509_crt_t *cert = new gnutls_x509_crt_t[certListSize+1];
00194     for( unsigned int i=0; !error && ( i<certListSize ); ++i )
00195     {
00196       if( !error && ( gnutls_x509_crt_init( &cert[i] ) < 0 ) )
00197         error = true;
00198       if( !error && ( gnutls_x509_crt_import( cert[i], &certList[i], GNUTLS_X509_FMT_DER ) < 0 ) )
00199         error = true;
00200     }
00201 
00202     if( ( gnutls_x509_crt_check_issuer( cert[certListSize-1], cert[certListSize-1] ) > 0 )
00203          && certListSize > 0 )
00204       certListSize--;
00205 
00206     bool chain = true;
00207     for( unsigned int i=1; !error && ( i<certListSize ); ++i )
00208     {
00209       chain = error = !verifyAgainst( cert[i-1], cert[i] );
00210     }
00211     if( !chain )
00212       m_certInfo.status |= CERT_INVALID;
00213     m_certInfo.chain = chain;
00214 
00215     m_certInfo.chain = verifyAgainstCAs( cert[certListSize], 0 /*CAList*/, 0 /*CAListSize*/ );
00216 
00217     int t = (int)gnutls_x509_crt_get_expiration_time( cert[0] );
00218     if( t == -1 )
00219       error = true;
00220     else if( t < time( 0 ) )
00221       m_certInfo.status |= CERT_EXPIRED;
00222     m_certInfo.date_from = t;
00223 
00224     t = (int)gnutls_x509_crt_get_activation_time( cert[0] );
00225     if( t == -1 )
00226       error = true;
00227     else if( t > time( 0 ) )
00228       m_certInfo.status |= CERT_NOT_ACTIVE;
00229     m_certInfo.date_to = t;
00230 
00231     char name[64];
00232     size_t nameSize = sizeof( name );
00233     gnutls_x509_crt_get_issuer_dn( cert[0], name, &nameSize );
00234     m_certInfo.issuer = name;
00235 
00236     nameSize = sizeof( name );
00237     gnutls_x509_crt_get_dn( cert[0], name, &nameSize );
00238     m_certInfo.server = name;
00239 
00240     const char* info;
00241     info = gnutls_compression_get_name( gnutls_compression_get( m_session ) );
00242     if( info )
00243       m_certInfo.compression = info;
00244 
00245     info = gnutls_mac_get_name( gnutls_mac_get( m_session ) );
00246     if( info )
00247       m_certInfo.mac = info;
00248 
00249     info = gnutls_cipher_get_name( gnutls_cipher_get( m_session ) );
00250     if( info )
00251       m_certInfo.cipher = info;
00252 
00253     info = gnutls_protocol_get_name( gnutls_protocol_get_version( m_session ) );
00254     if( info )
00255       m_certInfo.protocol = info;
00256 
00257     if( !gnutls_x509_crt_check_hostname( cert[0], m_server.c_str() ) )
00258       m_certInfo.status |= CERT_WRONG_PEER;
00259 
00260     for( unsigned int i=0; i<certListSize; ++i )
00261       gnutls_x509_crt_deinit( cert[i] );
00262 
00263     delete[] cert;
00264 
00265     return true;
00266   }
00267 
00268   bool Connection::verifyAgainst( gnutls_x509_crt_t cert, gnutls_x509_crt_t issuer )
00269   {
00270     unsigned int result;
00271     gnutls_x509_crt_verify( cert, &issuer, 1, 0, &result );
00272     if( result & GNUTLS_CERT_INVALID )
00273       return false;
00274 
00275     if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00276       return false;
00277 
00278     if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00279       return false;
00280 
00281     return true;
00282   }
00283 
00284   bool Connection::verifyAgainstCAs( gnutls_x509_crt_t cert, gnutls_x509_crt_t *CAList, int CAListSize )
00285   {
00286     unsigned int result;
00287     gnutls_x509_crt_verify( cert, CAList, CAListSize, GNUTLS_VERIFY_ALLOW_X509_V1_CA_CRT, &result );
00288     if( result & GNUTLS_CERT_INVALID )
00289       return false;
00290 
00291     if( gnutls_x509_crt_get_expiration_time( cert ) < time( 0 ) )
00292       return false;
00293 
00294     if( gnutls_x509_crt_get_activation_time( cert ) > time( 0 ) )
00295       return false;
00296 
00297     return true;
00298   }
00299 #endif
00300 
00301 #ifdef HAVE_ZLIB
00302   bool Connection::initCompression( bool init )
00303   {
00304     int ret = Z_OK;
00305 
00306     if( init )
00307     {
00308       m_zinflate.zalloc = Z_NULL;
00309       m_zinflate.zfree = Z_NULL;
00310       m_zinflate.opaque = Z_NULL;
00311       ret = inflateInit( &m_zinflate );
00312     }
00313     else if( m_compInited && !init )
00314       inflateEnd( &m_zinflate );
00315 
00316     if( ret == Z_OK )
00317     {
00318       m_compInited = init;
00319       return true;
00320     }
00321     else
00322     {
00323       m_compInited = false;
00324       return false;
00325     }
00326   }
00327 
00328   void Connection::setCompression( bool compression )
00329   {
00330     if( m_compInited )
00331       m_compression = compression;
00332   }
00333 
00334   std::string Connection::compress( const std::string& data )
00335   {
00336     if( data.empty() )
00337       return "";
00338 
00339     int CHUNK = data.length() + ( data.length() / 100 ) + 13;
00340     Bytef *out = new Bytef[CHUNK];
00341     const char *in = data.c_str();
00342 
00343     ::compress( out, (uLongf*)&CHUNK, (Bytef*)in, data.length() );
00344     std::string result;
00345     result.assign( (char*)out, CHUNK );
00346     m_compCount += result.length();
00347     m_dataOutCount += data.length();
00348     delete[] out;
00349 
00350     return result;
00351   }
00352 
00353   std::string Connection::decompress( const std::string& data )
00354   {
00355     if( data.empty() )
00356       return "";
00357 
00358     int CHUNK = data.length() * 10;
00359     char *out = new char[CHUNK];
00360     const char *in = data.c_str();
00361 
00362     m_zinflate.avail_in = data.length();
00363     m_zinflate.next_in = (Bytef*)in;
00364 
00365     int ret;
00366     std::string result, tmp;
00367     do {
00368       m_zinflate.avail_out = CHUNK;
00369       m_zinflate.next_out = (Bytef*)out;
00370 
00371       ret = inflate( &m_zinflate, Z_FINISH );
00372       tmp.assign( out, CHUNK - m_zinflate.avail_out );
00373       result += tmp;
00374     } while( m_zinflate.avail_out == 0 );
00375 
00376     m_decompCount += result.length();
00377     m_dataInCount += data.length();
00378     delete[] out;
00379 
00380     return result;
00381   }
00382 #endif
00383 
00384   void Connection::disconnect( ConnectionError e )
00385   {
00386     m_disconnect = e;
00387     m_cancel = true;
00388 
00389     if( m_fdRequested )
00390       cleanup();
00391   }
00392 
00393   int Connection::fileDescriptor()
00394   {
00395     m_fdRequested = true;
00396     return m_socket;
00397   }
00398 
00399   ConnectionError Connection::recv( int timeout )
00400   {
00401     if( m_cancel )
00402     {
00403       ConnectionError e = m_disconnect;
00404       cleanup();
00405       return e;
00406     }
00407 
00408     if( !m_fdRequested )
00409     {
00410       fd_set fds;
00411       struct timeval tv;
00412 
00413       FD_ZERO( &fds );
00414       FD_SET( m_socket, &fds );
00415 
00416       tv.tv_sec = timeout;
00417       tv.tv_usec = 0;
00418 
00419       if( select( m_socket + 1, &fds, 0, 0, timeout == -1 ? 0 : &tv ) < 0 )
00420         return CONN_IO_ERROR;
00421 
00422       if( !FD_ISSET( m_socket, &fds ) )
00423         return CONN_OK;
00424     }
00425 
00426     // optimize(?): recv returns the size. set size+1 = \0
00427     memset( m_buf, '\0', BUFSIZE );
00428     int size;
00429 #if defined( USE_GNUTLS )
00430     if( m_secure )
00431     {
00432       size = gnutls_record_recv( m_session, m_buf, BUFSIZE );
00433     }
00434     else
00435 #elif defined( USE_OPENSSL )
00436     if( m_secure )
00437     {
00438       size = SSL_read( m_ssl, m_buf, BUFSIZE );
00439     }
00440     else
00441 #endif
00442     {
00443 #ifdef SKYOS
00444       size = ::recv( m_socket, (unsigned char*)m_buf, BUFSIZE - 1, 0 );
00445 #else
00446       size = ::recv( m_socket, m_buf, BUFSIZE - 1, 0 );
00447 #endif
00448     }
00449 
00450     if( size < 0 )
00451     {
00452       // error
00453       return CONN_IO_ERROR;
00454     }
00455     else if( size == 0 )
00456     {
00457       // connection closed
00458       return CONN_USER_DISCONNECTED;
00459     }
00460     else
00461     {
00462       std::string buf;
00463 #ifdef HAVE_ZLIB
00464       if( m_compression )
00465         buf = decompress( m_buf );
00466       else
00467 #endif
00468         buf = m_buf;
00469 
00470       Parser::ParserState ret = m_parser->feed( buf );
00471       if( ret != Parser::PARSER_OK )
00472       {
00473         cleanup();
00474 #ifdef DEBUG
00475         switch( ret )
00476         {
00477           case Parser::PARSER_BADXML:
00478             printf( "XML parse error\n" );
00479             break;
00480           case Parser::PARSER_NOMEM:
00481             printf( "memory allocation error\n" );
00482             break;
00483           default:
00484             break;
00485         }
00486 #endif
00487         return CONN_IO_ERROR;
00488       }
00489     }
00490 
00491     return CONN_OK;
00492   }
00493 
00494   ConnectionError Connection::receive()
00495   {
00496     if( m_socket == -1 || !m_parser )
00497       return CONN_IO_ERROR;
00498 
00499     while( !m_cancel )
00500     {
00501       ConnectionError r = recv();
00502       if( r != CONN_OK )
00503         return r;
00504     }
00505 
00506     return m_disconnect;
00507   }
00508 
00509   void Connection::send( const std::string& data )
00510   {
00511     if( data.empty() )
00512       return;
00513 
00514     char *xml;
00515 #ifdef HAVE_ZLIB
00516     if( m_compression )
00517       xml = strdup( compress( data ).c_str() );
00518     else
00519 #endif
00520       xml = strdup( data.c_str() );
00521 
00522     if( !xml )
00523       return;
00524 
00525 #if defined( USE_GNUTLS )
00526     if( m_secure )
00527     {
00528       int ret;
00529       int len = strlen( xml );
00530       do
00531       {
00532         ret = gnutls_record_send( m_session, xml, len );
00533       }
00534       while( ( ret == GNUTLS_E_AGAIN ) || ( ret == GNUTLS_E_INTERRUPTED ) );
00535     }
00536     else
00537 #elif defined( USE_OPENSSL )
00538     if( m_secure )
00539     {
00540       int ret;
00541       int len = strlen( xml );
00542       ret = SSL_write( m_ssl, xml, len );
00543     }
00544     else
00545 #endif
00546     {
00547       int num = 0;
00548       int len = strlen( xml );
00549       while( num < len )
00550 #ifdef SKYOS
00551         num += ::send( m_socket, (unsigned char*)(xml+num), len - num, 0 );
00552 #else
00553         num += ::send( m_socket, (xml+num), len - num, 0 );
00554 #endif
00555     }
00556 
00557     free( xml );
00558   }
00559 
00560   ConnectionState Connection::connect()
00561   {
00562     if( m_socket != -1 && m_state >= STATE_CONNECTING )
00563       return m_state;
00564 
00565     m_state = STATE_CONNECTING;
00566 
00567     if( m_port == -1 )
00568       m_socket = DNS::connect( m_server );
00569     else
00570       m_socket = DNS::connect( m_server, m_port );
00571 
00572     if( m_socket < 0 )
00573     {
00574 #ifdef DEBUG
00575       switch( m_socket )
00576       {
00577         case -DNS::DNS_COULD_NOT_CONNECT:
00578           printf( "could not connect\n" );
00579           break;
00580         case -DNS::DNS_NO_HOSTS_FOUND:
00581           printf( "no hosts found\n" );
00582           break;
00583         case -DNS::DNS_COULD_NOT_RESOLVE:
00584           printf( "could not resolve\n" );
00585           break;
00586       }
00587       printf( "connection error\n" );
00588 #endif
00589       cleanup();
00590     }
00591     else
00592       m_state = STATE_CONNECTED;
00593 
00594     m_cancel = false;
00595     return m_state;
00596   }
00597 
00598   void Connection::cleanup()
00599   {
00600     if( m_socket != -1 )
00601     {
00602 #ifdef WIN32
00603       closesocket( m_socket );
00604 #else
00605       close( m_socket );
00606 #endif
00607       m_socket = -1;
00608     }
00609     m_state = STATE_DISCONNECTED;
00610     m_disconnect = CONN_OK;
00611 
00612 #if defined( USE_GNUTLS )
00613     if( m_secure )
00614     {
00615       gnutls_bye( m_session, GNUTLS_SHUT_RDWR );
00616       gnutls_deinit( m_session );
00617       gnutls_certificate_free_credentials( m_credentials );
00618       gnutls_global_deinit();
00619     }
00620 #elif defined( USE_OPENSSL )
00621     if( m_secure )
00622     {
00623       SSL_shutdown( m_ssl );
00624       SSL_free( m_ssl );
00625     }
00626 #endif
00627     m_secure = false;
00628     m_cancel = true;
00629     m_fdRequested = false;
00630   }
00631 
00632 }

Generated on Mon Jan 16 16:19:54 2006 for gloox by doxygen 1.4.6