00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "kernel/CommUlongStringKernel.h"
00013 #include "features/StringFeatures.h"
00014 #include "lib/io.h"
00015
00016 CCommUlongStringKernel::CCommUlongStringKernel(
00017 INT size, bool us, ENormalizationType n)
00018 : CStringKernel<ULONG>(size), sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00019 initialized(false), use_sign(us), normalization(n)
00020 {
00021 properties |= KP_LINADD;
00022 clear_normal();
00023 }
00024
00025 CCommUlongStringKernel::CCommUlongStringKernel(
00026 CStringFeatures<ULONG>* l, CStringFeatures<ULONG>* r, bool us,
00027 ENormalizationType n, INT size)
00028 : CStringKernel<ULONG>(size), sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL),
00029 initialized(false), use_sign(us), normalization(n)
00030 {
00031 properties |= KP_LINADD;
00032 clear_normal();
00033 init(l,r);
00034 }
00035
00036 CCommUlongStringKernel::~CCommUlongStringKernel()
00037 {
00038 cleanup();
00039 }
00040
00041 void CCommUlongStringKernel::remove_lhs()
00042 {
00043 delete_optimization();
00044
00045 #ifdef SVMLIGHT
00046 if (lhs)
00047 cache_reset();
00048 #endif
00049
00050 if (sqrtdiag_lhs != sqrtdiag_rhs)
00051 delete[] sqrtdiag_rhs;
00052 delete[] sqrtdiag_lhs;
00053
00054 lhs = NULL ;
00055 rhs = NULL ;
00056 initialized = false;
00057 sqrtdiag_lhs = NULL;
00058 sqrtdiag_rhs = NULL;
00059 }
00060
00061 void CCommUlongStringKernel::remove_rhs()
00062 {
00063 #ifdef SVMLIGHT
00064 if (rhs)
00065 cache_reset();
00066 #endif
00067
00068 if (sqrtdiag_lhs != sqrtdiag_rhs)
00069 delete[] sqrtdiag_rhs;
00070 sqrtdiag_rhs = sqrtdiag_lhs;
00071 rhs = lhs;
00072 }
00073
00074 bool CCommUlongStringKernel::init(CFeatures* l, CFeatures* r)
00075 {
00076 bool result=CStringKernel<ULONG>::init(l,r);
00077 INT i;
00078
00079 initialized=false;
00080 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00081 delete[] sqrtdiag_rhs;
00082 sqrtdiag_rhs=NULL;
00083 delete[] sqrtdiag_lhs;
00084 sqrtdiag_lhs=new DREAL[lhs->get_num_vectors()];
00085
00086 for (i=0; i<lhs->get_num_vectors(); i++)
00087 sqrtdiag_lhs[i]=1;
00088
00089 if (l==r)
00090 sqrtdiag_rhs=sqrtdiag_lhs;
00091 else
00092 {
00093 sqrtdiag_rhs=new DREAL[rhs->get_num_vectors()];
00094 for (i=0; i<rhs->get_num_vectors(); i++)
00095 sqrtdiag_rhs[i]=1;
00096 }
00097
00098 this->lhs=(CStringFeatures<ULONG>*) l;
00099 this->rhs=(CStringFeatures<ULONG>*) l;
00100
00101
00102 for (i=0; i<lhs->get_num_vectors(); i++)
00103 {
00104 sqrtdiag_lhs[i]=sqrt(compute(i,i));
00105
00106
00107 if (sqrtdiag_lhs[i]==0)
00108 sqrtdiag_lhs[i]=1e-16;
00109 }
00110
00111
00112
00113 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00114 {
00115 this->lhs=(CStringFeatures<ULONG>*) r;
00116 this->rhs=(CStringFeatures<ULONG>*) r;
00117
00118
00119 for (i=0; i<rhs->get_num_vectors(); i++)
00120 {
00121 sqrtdiag_rhs[i]=sqrt(compute(i,i));
00122
00123
00124 if (sqrtdiag_rhs[i]==0)
00125 sqrtdiag_rhs[i]=1e-16;
00126 }
00127 }
00128
00129 this->lhs=(CStringFeatures<ULONG>*) l;
00130 this->rhs=(CStringFeatures<ULONG>*) r;
00131
00132 initialized = true;
00133 return result;
00134 }
00135
00136 void CCommUlongStringKernel::cleanup()
00137 {
00138 delete_optimization();
00139 clear_normal();
00140
00141 initialized=false;
00142
00143 if (sqrtdiag_lhs != sqrtdiag_rhs)
00144 delete[] sqrtdiag_rhs;
00145
00146 sqrtdiag_rhs=NULL;
00147
00148 delete[] sqrtdiag_lhs;
00149 sqrtdiag_lhs=NULL;
00150
00151 CKernel::cleanup();
00152 }
00153
00154 bool CCommUlongStringKernel::load_init(FILE* src)
00155 {
00156 return false;
00157 }
00158
00159 bool CCommUlongStringKernel::save_init(FILE* dest)
00160 {
00161 return false;
00162 }
00163
00164 DREAL CCommUlongStringKernel::compute(INT idx_a, INT idx_b)
00165 {
00166 INT alen, blen;
00167 ULONG* avec=((CStringFeatures<ULONG>*) lhs)->get_feature_vector(idx_a, alen);
00168 ULONG* bvec=((CStringFeatures<ULONG>*) rhs)->get_feature_vector(idx_b, blen);
00169
00170 DREAL result=0;
00171
00172 INT left_idx=0;
00173 INT right_idx=0;
00174
00175 if (use_sign)
00176 {
00177 while (left_idx < alen && right_idx < blen)
00178 {
00179 if (avec[left_idx]==bvec[right_idx])
00180 {
00181 ULONG sym=avec[left_idx];
00182
00183 while (left_idx< alen && avec[left_idx]==sym)
00184 left_idx++;
00185
00186 while (right_idx< blen && bvec[right_idx]==sym)
00187 right_idx++;
00188
00189 result++;
00190 }
00191 else if (avec[left_idx]<bvec[right_idx])
00192 left_idx++;
00193 else
00194 right_idx++;
00195 }
00196 }
00197 else
00198 {
00199 while (left_idx < alen && right_idx < blen)
00200 {
00201 if (avec[left_idx]==bvec[right_idx])
00202 {
00203 INT old_left_idx=left_idx;
00204 INT old_right_idx=right_idx;
00205
00206 ULONG sym=avec[left_idx];
00207
00208 while (left_idx< alen && avec[left_idx]==sym)
00209 left_idx++;
00210
00211 while (right_idx< blen && bvec[right_idx]==sym)
00212 right_idx++;
00213
00214 result+=((DREAL) (left_idx-old_left_idx)) * ((DREAL) (right_idx-old_right_idx));
00215 }
00216 else if (avec[left_idx]<bvec[right_idx])
00217 left_idx++;
00218 else
00219 right_idx++;
00220 }
00221 }
00222
00223 if (initialized)
00224 {
00225 switch (normalization)
00226 {
00227 case NO_NORMALIZATION:
00228 return result;
00229 case SQRT_NORMALIZATION:
00230 return result/sqrt(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);
00231 case FULL_NORMALIZATION:
00232 return result/(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);
00233 case SQRTLEN_NORMALIZATION:
00234 return result/sqrt(sqrt(alen*blen));
00235 case LEN_NORMALIZATION:
00236 return result/sqrt(alen*blen);
00237 case SQLEN_NORMALIZATION:
00238 return result/(alen*blen);
00239 default:
00240 SG_ERROR( "Unknown Normalization in use!\n");
00241 return -CMath::INFTY;
00242 }
00243 }
00244 else
00245 return result;
00246 }
00247
00248 void CCommUlongStringKernel::add_to_normal(INT vec_idx, DREAL weight)
00249 {
00250 INT t=0;
00251 INT j=0;
00252 INT k=0;
00253 INT last_j=0;
00254 INT len=-1;
00255 ULONG* vec=((CStringFeatures<ULONG>*) lhs)->get_feature_vector(vec_idx, len);
00256
00257 if (vec && len>0)
00258 {
00259
00260 ULONG* dic= (ULONG*) malloc((len+dictionary.get_num_elements())*sizeof(ULONG));
00261 DREAL* dic_weights= (DREAL*) malloc((len+dictionary.get_num_elements())*sizeof(DREAL));
00262
00263 if (use_sign)
00264 {
00265 for (j=1; j<len; j++)
00266 {
00267 if (vec[j]==vec[j-1])
00268 continue;
00269
00270 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight, vec_idx, len, normalization);
00271 }
00272
00273 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight, vec_idx, len, normalization);
00274
00275 while (k<dictionary.get_num_elements())
00276 {
00277 dic[t]=dictionary[k];
00278 dic_weights[t]=dictionary_weights[k];
00279 t++;
00280 k++;
00281 }
00282 }
00283 else
00284 {
00285 for (j=1; j<len; j++)
00286 {
00287 if (vec[j]==vec[j-1])
00288 continue;
00289
00290 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight*(j-last_j), vec_idx, len, normalization);
00291 last_j = j;
00292 }
00293
00294 merge_dictionaries(t, j, k, vec, dic, dic_weights, weight*(j-last_j), vec_idx, len, normalization);
00295
00296 while (k<dictionary.get_num_elements())
00297 {
00298 dic[t]=dictionary[k];
00299 dic_weights[t]=dictionary_weights[k];
00300 t++;
00301 k++;
00302 }
00303 }
00304
00305 dictionary.set_array(dic, t, len+dictionary.get_num_elements());
00306 dictionary_weights.set_array(dic_weights, t, len+dictionary.get_num_elements());
00307 }
00308
00309 set_is_initialized(true);
00310 }
00311
00312 void CCommUlongStringKernel::clear_normal()
00313 {
00314 dictionary.resize_array(0);
00315 dictionary_weights.resize_array(0);
00316 set_is_initialized(false);
00317 }
00318
00319 bool CCommUlongStringKernel::init_optimization(INT count, INT *IDX, DREAL * weights)
00320 {
00321 clear_normal();
00322
00323 if (count<=0)
00324 {
00325 set_is_initialized(true);
00326 SG_DEBUG( "empty set of SVs\n");
00327 return true;
00328 }
00329
00330 SG_DEBUG( "initializing CCommUlongStringKernel optimization\n");
00331
00332 for (int i=0; i<count; i++)
00333 {
00334 if ( (i % (count/10+1)) == 0)
00335 SG_PROGRESS(i, 0, count);
00336
00337 add_to_normal(IDX[i], weights[i]);
00338 }
00339
00340 SG_PRINT( "Done. \n");
00341
00342 set_is_initialized(true);
00343 return true;
00344 }
00345
00346 bool CCommUlongStringKernel::delete_optimization()
00347 {
00348 SG_DEBUG( "deleting CCommUlongStringKernel optimization\n");
00349 clear_normal();
00350 return true;
00351 }
00352
00353
00354
00355 DREAL CCommUlongStringKernel::compute_optimized(INT i)
00356 {
00357 DREAL result = 0;
00358 INT j, last_j=0;
00359 INT old_idx = 0;
00360
00361 if (!get_is_initialized())
00362 {
00363 SG_ERROR( "CCommUlongStringKernel optimization not initialized\n");
00364 return 0 ;
00365 }
00366
00367
00368
00369 INT alen = -1;
00370 ULONG* avec=((CStringFeatures<ULONG>*) rhs)->get_feature_vector(i, alen);
00371
00372 if (avec && alen>0)
00373 {
00374 if (use_sign)
00375 {
00376 for (j=1; j<alen; j++)
00377 {
00378 if (avec[j]==avec[j-1])
00379 continue;
00380
00381 INT idx = CMath::binary_search_max_lower_equal(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[j-1]);
00382
00383 if (idx!=-1)
00384 {
00385 if (dictionary[idx+old_idx] == avec[j-1])
00386 result += dictionary_weights[idx+old_idx];
00387
00388 old_idx+=idx;
00389 }
00390 }
00391
00392 INT idx = CMath::binary_search(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[alen-1]);
00393 if (idx!=-1)
00394 result += dictionary_weights[idx+old_idx];
00395 }
00396 else
00397 {
00398 for (j=1; j<alen; j++)
00399 {
00400 if (avec[j]==avec[j-1])
00401 continue;
00402
00403 INT idx = CMath::binary_search_max_lower_equal(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[j-1]);
00404
00405 if (idx!=-1)
00406 {
00407 if (dictionary[idx+old_idx] == avec[j-1])
00408 result += dictionary_weights[idx+old_idx]*(j-last_j);
00409
00410 old_idx+=idx;
00411 }
00412
00413 last_j = j;
00414 }
00415
00416 INT idx = CMath::binary_search(&(dictionary.get_array()[old_idx]), dictionary.get_num_elements()-old_idx, avec[alen-1]);
00417 if (idx!=-1)
00418 result += dictionary_weights[idx+old_idx]*(alen-last_j);
00419 }
00420
00421 switch (normalization)
00422 {
00423 case NO_NORMALIZATION:
00424 return result;
00425 case SQRT_NORMALIZATION:
00426 return result/sqrt(sqrtdiag_rhs[i]);
00427 case FULL_NORMALIZATION:
00428 return result/sqrtdiag_rhs[i];
00429 case SQRTLEN_NORMALIZATION:
00430 return result/sqrt(sqrt(alen));
00431 case LEN_NORMALIZATION:
00432 return result/sqrt(alen);
00433 case SQLEN_NORMALIZATION:
00434 return result/alen;
00435 default:
00436 SG_ERROR( "Unknown Normalization in use!\n");
00437 return -CMath::INFTY;
00438 }
00439 }
00440 return result;
00441 }