CommUlongStringKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "kernel/CommUlongStringKernel.h"
00013 #include "features/StringFeatures.h"
00014 #include "lib/io.h"
00015 
00016 CCommUlongStringKernel::CCommUlongStringKernel(
00017     INT size, bool us, ENormalizationType n)
00018 : CStringKernel<ULONG>(size), sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00019     initialized(false), use_sign(us), normalization(n)
00020 {
00021     properties |= KP_LINADD;
00022     clear_normal();
00023 }
00024 
00025 CCommUlongStringKernel::CCommUlongStringKernel(
00026     CStringFeatures<ULONG>* l, CStringFeatures<ULONG>* r, bool us,
00027     ENormalizationType n, INT size)
00028 : CStringKernel<ULONG>(size), sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00029     initialized(false), use_sign(us), normalization(n)
00030 {
00031     properties |= KP_LINADD;
00032     clear_normal();
00033     init(l,r);
00034 }
00035 
00036 CCommUlongStringKernel::~CCommUlongStringKernel()
00037 {
00038     cleanup();
00039 }
00040 
00041 void CCommUlongStringKernel::remove_lhs()
00042 {
00043     delete_optimization();
00044 
00045 #ifdef SVMLIGHT
00046     if (lhs)
00047         cache_reset();
00048 #endif
00049 
00050     if (sqrtdiag_lhs != sqrtdiag_rhs)
00051         delete[] sqrtdiag_rhs;
00052     delete[] sqrtdiag_lhs;
00053 
00054     lhs = NULL ; 
00055     rhs = NULL ; 
00056     initialized = false;
00057     sqrtdiag_lhs = NULL;
00058     sqrtdiag_rhs = NULL;
00059 }
00060 
00061 void CCommUlongStringKernel::remove_rhs()
00062 {
00063 #ifdef SVMLIGHT
00064     if (rhs)
00065         cache_reset();
00066 #endif
00067 
00068     if (sqrtdiag_lhs != sqrtdiag_rhs)
00069         delete[] sqrtdiag_rhs;
00070     sqrtdiag_rhs = sqrtdiag_lhs;
00071     rhs = lhs;
00072 }
00073 
00074 bool CCommUlongStringKernel::init(CFeatures* l, CFeatures* r)
00075 {
00076     bool result=CStringKernel<ULONG>::init(l,r);
00077     INT i;
00078 
00079     initialized=false;
00080     if (sqrtdiag_lhs!=sqrtdiag_rhs)
00081         delete[] sqrtdiag_rhs;
00082     sqrtdiag_rhs=NULL;
00083     delete[] sqrtdiag_lhs;
00084     sqrtdiag_lhs=new DREAL[lhs->get_num_vectors()];
00085 
00086     for (i=0; i<lhs->get_num_vectors(); i++)
00087         sqrtdiag_lhs[i]=1;
00088 
00089     if (l==r)
00090         sqrtdiag_rhs=sqrtdiag_lhs;
00091     else
00092     {
00093         sqrtdiag_rhs=new DREAL[rhs->get_num_vectors()];
00094         for (i=0; i<rhs->get_num_vectors(); i++)
00095             sqrtdiag_rhs[i]=1;
00096     }
00097 
00098     this->lhs=(CStringFeatures<ULONG>*) l;
00099     this->rhs=(CStringFeatures<ULONG>*) l;
00100 
00101     //compute normalize to 1 values
00102     for (i=0; i<lhs->get_num_vectors(); i++)
00103     {
00104         sqrtdiag_lhs[i]=sqrt(compute(i,i));
00105 
00106         //trap divide by zero exception
00107         if (sqrtdiag_lhs[i]==0)
00108             sqrtdiag_lhs[i]=1e-16;
00109     }
00110 
00111     // if lhs is different from rhs (train/test data)
00112     // compute also the normalization for rhs
00113     if (sqrtdiag_lhs!=sqrtdiag_rhs)
00114     {
00115         this->lhs=(CStringFeatures<ULONG>*) r;
00116         this->rhs=(CStringFeatures<ULONG>*) r;
00117 
00118         //compute normalize to 1 values
00119         for (i=0; i<rhs->get_num_vectors(); i++)
00120         {
00121             sqrtdiag_rhs[i]=sqrt(compute(i,i));
00122 
00123             //trap divide by zero exception
00124             if (sqrtdiag_rhs[i]==0)
00125                 sqrtdiag_rhs[i]=1e-16;
00126         }
00127     }
00128 
00129     this->lhs=(CStringFeatures<ULONG>*) l;
00130     this->rhs=(CStringFeatures<ULONG>*) r;
00131 
00132     initialized = true;
00133     return result;
00134 }
00135 
00136 void CCommUlongStringKernel::cleanup()
00137 {
00138     delete_optimization();
00139     clear_normal();
00140 
00141     initialized=false;
00142 
00143     if (sqrtdiag_lhs != sqrtdiag_rhs)
00144         delete[] sqrtdiag_rhs;
00145 
00146     sqrtdiag_rhs=NULL;
00147 
00148     delete[] sqrtdiag_lhs;
00149     sqrtdiag_lhs=NULL;
00150 
00151     CKernel::cleanup();
00152 }
00153 
00154 bool CCommUlongStringKernel::load_init(FILE* src)
00155 {
00156     return false;
00157 }
00158 
00159 bool CCommUlongStringKernel::save_init(FILE* dest)
00160 {
00161     return false;
00162 }
00163   
00164 DREAL CCommUlongStringKernel::compute(INT idx_a, INT idx_b)
00165 {
00166     INT alen, blen;
00167     ULONG* avec=((CStringFeatures<ULONG>*) lhs)->get_feature_vector(idx_a, alen);
00168     ULONG* bvec=((CStringFeatures<ULONG>*) rhs)->get_feature_vector(idx_b, blen);
00169 
00170     DREAL result=0;
00171 
00172     INT left_idx=0;
00173     INT right_idx=0;
00174 
00175     if (use_sign)
00176     {
00177         while (left_idx < alen && right_idx < blen)
00178         {
00179             if (avec[left_idx]==bvec[right_idx])
00180             {
00181                 ULONG sym=avec[left_idx];
00182 
00183                 while (left_idx< alen && avec[left_idx]==sym)
00184                     left_idx++;
00185 
00186                 while (right_idx< blen && bvec[right_idx]==sym)
00187                     right_idx++;
00188 
00189                 result++;
00190             }
00191             else if (avec[left_idx]<bvec[right_idx])
00192                 left_idx++;
00193             else
00194                 right_idx++;
00195         }
00196     }
00197     else
00198     {
00199         while (left_idx < alen && right_idx < blen)
00200         {
00201             if (avec[left_idx]==bvec[right_idx])
00202             {
00203                 INT old_left_idx=left_idx;
00204                 INT old_right_idx=right_idx;
00205 
00206                 ULONG sym=avec[left_idx];
00207 
00208                 while (left_idx< alen && avec[left_idx]==sym)
00209                     left_idx++;
00210 
00211                 while (right_idx< blen && bvec[right_idx]==sym)
00212                     right_idx++;
00213 
00214                 result+=((DREAL) (left_idx-old_left_idx)) * ((DREAL) (right_idx-old_right_idx));
00215             }
00216             else if (avec[left_idx]<bvec[right_idx])
00217                 left_idx++;
00218             else
00219                 right_idx++;
00220         }
00221     }
00222 
00223     if (initialized)
00224     {
00225         switch (normalization)
00226         {
00227             case NO_NORMALIZATION:
00228                 return result;
00229             case SQRT_NORMALIZATION:
00230                 return result/sqrt(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);
00231             case FULL_NORMALIZATION:
00232                 return result/(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);
00233             case SQRTLEN_NORMALIZATION:
00234                 return result/sqrt(sqrt(alen*blen));
00235             case LEN_NORMALIZATION:
00236                 return result/sqrt(alen*blen);
00237             case SQLEN_NORMALIZATION:
00238                 return result/(alen*blen);
00239             default:
00240             SG_ERROR( "Unknown Normalization in use!\n");
00241                 return -CMath::INFTY;
00242         }
00243     }
00244     else
00245         return result;
00246 }
00247 
00248 void CCommUlongStringKernel::add_to_normal(INT vec_idx, DREAL weight)
00249 {
00250     INT t=0;
00251     INT j=0;
00252     INT k=0;
00253     INT last_j=0;
00254     INT len=-1;
00255     ULONG* vec=((CStringFeatures<ULONG>*) lhs)->get_feature_vector(vec_idx, len);
00256 
00257     if (vec && len>0)
00258     {
00259         //use malloc not new [] as DynamicArray uses it
00260         ULONG* dic= (ULONG*) malloc((len+dictionary.get_num_elements())*sizeof(ULONG));
00261         DREAL* dic_weights= (DREAL*) malloc((len+dictionary.get_num_elements())*sizeof(DREAL));
00262 
00263         if (use_sign)
00264         {
00265             for (j=1; j<len; j++)
00266             {
00267                 if (vec[j]==vec[j-1])
00268                     continue;
00269 
00270                 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight, vec_idx, len, normalization);
00271             }
00272 
00273             merge_dictionaries(t, j, k, vec, dic, dic_weights, weight, vec_idx, len, normalization);
00274 
00275             while (k<dictionary.get_num_elements())
00276             {
00277                 dic[t]=dictionary[k];
00278                 dic_weights[t]=dictionary_weights[k];
00279                 t++;
00280                 k++;
00281             }
00282         }
00283         else
00284         {
00285             for (j=1; j<len; j++)
00286             {
00287                 if (vec[j]==vec[j-1])
00288                     continue;
00289 
00290                 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight*(j-last_j), vec_idx, len, normalization);
00291                 last_j = j;
00292             }
00293 
00294             merge_dictionaries(t, j, k, vec, dic, dic_weights, weight*(j-last_j), vec_idx, len, normalization);
00295 
00296             while (k<dictionary.get_num_elements())
00297             {
00298                 dic[t]=dictionary[k];
00299                 dic_weights[t]=dictionary_weights[k];
00300                 t++;
00301                 k++;
00302             }
00303         }
00304 
00305         dictionary.set_array(dic, t, len+dictionary.get_num_elements());
00306         dictionary_weights.set_array(dic_weights, t, len+dictionary.get_num_elements());
00307     }
00308 
00309     set_is_initialized(true);
00310 }
00311 
00312 void CCommUlongStringKernel::clear_normal()
00313 {
00314     dictionary.resize_array(0);
00315     dictionary_weights.resize_array(0);
00316     set_is_initialized(false);
00317 }
00318 
00319 bool CCommUlongStringKernel::init_optimization(INT count, INT *IDX, DREAL * weights) 
00320 {
00321     clear_normal();
00322 
00323     if (count<=0)
00324     {
00325         set_is_initialized(true);
00326         SG_DEBUG( "empty set of SVs\n");
00327         return true;
00328     }
00329 
00330     SG_DEBUG( "initializing CCommUlongStringKernel optimization\n");
00331 
00332     for (int i=0; i<count; i++)
00333     {
00334         if ( (i % (count/10+1)) == 0)
00335             SG_PROGRESS(i, 0, count);
00336 
00337         add_to_normal(IDX[i], weights[i]);
00338     }
00339 
00340     SG_PRINT( "Done.         \n");
00341     
00342     set_is_initialized(true);
00343     return true;
00344 }
00345 
00346 bool CCommUlongStringKernel::delete_optimization() 
00347 {
00348     SG_DEBUG( "deleting CCommUlongStringKernel optimization\n");
00349     clear_normal();
00350     return true;
00351 }
00352 
00353 // binary search for each feature. trick: as features are sorted save last found idx in old_idx and
00354 // only search in the remainder of the dictionary
00355 DREAL CCommUlongStringKernel::compute_optimized(INT i) 
00356 { 
00357     DREAL result = 0;
00358     INT j, last_j=0;
00359     INT old_idx = 0;
00360 
00361     if (!get_is_initialized())
00362     {
00363       SG_ERROR( "CCommUlongStringKernel optimization not initialized\n");
00364         return 0 ; 
00365     }
00366 
00367 
00368 
00369     INT alen = -1;
00370     ULONG* avec=((CStringFeatures<ULONG>*) rhs)->get_feature_vector(i, alen);
00371 
00372     if (avec && alen>0)
00373     {
00374         if (use_sign)
00375         {
00376             for (j=1; j<alen; j++)
00377             {
00378                 if (avec[j]==avec[j-1])
00379                     continue;
00380 
00381                 INT idx = CMath::binary_search_max_lower_equal(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[j-1]);
00382 
00383                 if (idx!=-1)
00384                 {
00385                     if (dictionary[idx+old_idx] == avec[j-1])
00386                         result += dictionary_weights[idx+old_idx];
00387 
00388                     old_idx+=idx;
00389                 }
00390             }
00391 
00392             INT idx = CMath::binary_search(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[alen-1]);
00393             if (idx!=-1)
00394                 result += dictionary_weights[idx+old_idx];
00395         }
00396         else
00397         {
00398             for (j=1; j<alen; j++)
00399             {
00400                 if (avec[j]==avec[j-1])
00401                     continue;
00402 
00403                 INT idx = CMath::binary_search_max_lower_equal(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[j-1]);
00404 
00405                 if (idx!=-1)
00406                 {
00407                     if (dictionary[idx+old_idx] == avec[j-1])
00408                         result += dictionary_weights[idx+old_idx]*(j-last_j);
00409 
00410                     old_idx+=idx;
00411                 }
00412 
00413                 last_j = j;
00414             }
00415 
00416             INT idx = CMath::binary_search(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[alen-1]);
00417             if (idx!=-1)
00418                 result += dictionary_weights[idx+old_idx]*(alen-last_j);
00419         }
00420 
00421         switch (normalization)
00422         {
00423             case NO_NORMALIZATION:
00424                 return result;
00425             case SQRT_NORMALIZATION:
00426                 return result/sqrt(sqrtdiag_rhs[i]);
00427             case FULL_NORMALIZATION:
00428                 return result/sqrtdiag_rhs[i];
00429             case SQRTLEN_NORMALIZATION:
00430                 return result/sqrt(sqrt(alen));
00431             case LEN_NORMALIZATION:
00432                 return result/sqrt(alen);
00433             case SQLEN_NORMALIZATION:
00434                 return result/alen;
00435             default:
00436             SG_ERROR( "Unknown Normalization in use!\n");
00437                 return -CMath::INFTY;
00438         }
00439     }
00440     return result;
00441 }

SHOGUN Machine Learning Toolbox - Documentation