00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "kernel/CommWordStringKernel.h"
00013 #include "kernel/SqrtDiagKernelNormalizer.h"
00014 #include "features/StringFeatures.h"
00015 #include "lib/io.h"
00016
00017 CCommWordStringKernel::CCommWordStringKernel(int32_t size, bool s)
00018 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00019 use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00020 {
00021 properties |= KP_LINADD;
00022 init_dictionary(1<<(sizeof(uint16_t)*8));
00023 set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00024 }
00025
00026 CCommWordStringKernel::CCommWordStringKernel(
00027 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, bool s,
00028 int32_t size)
00029 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00030 use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00031 {
00032 properties |= KP_LINADD;
00033
00034 init_dictionary(1<<(sizeof(uint16_t)*8));
00035 set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00036 init(l,r);
00037 }
00038
00039
00040 bool CCommWordStringKernel::init_dictionary(int32_t size)
00041 {
00042 dictionary_size= size;
00043 delete[] dictionary_weights;
00044 dictionary_weights=new float64_t[size];
00045 SG_DEBUG( "using dictionary of %d words\n", size);
00046 clear_normal();
00047
00048 return dictionary_weights!=NULL;
00049 }
00050
00051 CCommWordStringKernel::~CCommWordStringKernel()
00052 {
00053 cleanup();
00054
00055 delete[] dictionary_weights;
00056 delete[] dict_diagonal_optimization ;
00057 }
00058
00059 void CCommWordStringKernel::remove_lhs()
00060 {
00061 delete_optimization();
00062
00063 #ifdef SVMLIGHT
00064 if (lhs)
00065 cache_reset();
00066 #endif
00067
00068 lhs = NULL ;
00069 rhs = NULL ;
00070 }
00071
00072 void CCommWordStringKernel::remove_rhs()
00073 {
00074 #ifdef SVMLIGHT
00075 if (rhs)
00076 cache_reset();
00077 #endif
00078
00079 rhs = lhs;
00080 }
00081
00082 bool CCommWordStringKernel::init(CFeatures* l, CFeatures* r)
00083 {
00084 CStringKernel<uint16_t>::init(l,r);
00085
00086 if (use_dict_diagonal_optimization)
00087 {
00088 delete[] dict_diagonal_optimization ;
00089 dict_diagonal_optimization=new int32_t[int32_t(((CStringFeatures<uint16_t>*)l)->get_num_symbols())];
00090 ASSERT(((CStringFeatures<uint16_t>*)l)->get_num_symbols() == ((CStringFeatures<uint16_t>*)r)->get_num_symbols()) ;
00091 }
00092
00093 return init_normalizer();
00094 }
00095
00096 void CCommWordStringKernel::cleanup()
00097 {
00098 delete_optimization();
00099 clear_normal();
00100 CKernel::cleanup();
00101 }
00102
00103 bool CCommWordStringKernel::load_init(FILE* src)
00104 {
00105 return false;
00106 }
00107
00108 bool CCommWordStringKernel::save_init(FILE* dest)
00109 {
00110 return false;
00111 }
00112
00113 float64_t CCommWordStringKernel::compute_diag(int32_t idx_a)
00114 {
00115 int32_t alen;
00116 CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00117 CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00118
00119 uint16_t* av=l->get_feature_vector(idx_a, alen);
00120
00121 float64_t result=0.0 ;
00122 ASSERT(l==r);
00123 ASSERT(sizeof(uint16_t)<=sizeof(float64_t));
00124 ASSERT((1<<(sizeof(uint16_t)*8)) > alen);
00125
00126 int32_t num_symbols=(int32_t) l->get_num_symbols();
00127 ASSERT(num_symbols<=dictionary_size);
00128
00129 int32_t* dic = dict_diagonal_optimization;
00130 memset(dic, 0, num_symbols*sizeof(int32_t));
00131
00132 for (int32_t i=0; i<alen; i++)
00133 dic[av[i]]++;
00134
00135 if (use_sign)
00136 {
00137 for (int32_t i=0; i<(int32_t) l->get_num_symbols(); i++)
00138 {
00139 if (dic[i]!=0)
00140 result++;
00141 }
00142 }
00143 else
00144 {
00145 for (int32_t i=0; i<num_symbols; i++)
00146 {
00147 if (dic[i]!=0)
00148 result+=dic[i]*dic[i];
00149 }
00150 }
00151
00152 return result;
00153 }
00154
00155 float64_t CCommWordStringKernel::compute_helper(
00156 int32_t idx_a, int32_t idx_b, bool do_sort)
00157 {
00158 int32_t alen, blen;
00159 CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00160 CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00161
00162 uint16_t* av=l->get_feature_vector(idx_a, alen);
00163 uint16_t* bv=r->get_feature_vector(idx_b, blen);
00164
00165 uint16_t* avec=av;
00166 uint16_t* bvec=bv;
00167
00168 if (do_sort)
00169 {
00170 if (alen>0)
00171 {
00172 avec=new uint16_t[alen];
00173 memcpy(avec, av, sizeof(uint16_t)*alen);
00174 CMath::radix_sort(avec, alen);
00175 }
00176 else
00177 avec=NULL;
00178
00179 if (blen>0)
00180 {
00181 bvec=new uint16_t[blen];
00182 memcpy(bvec, bv, sizeof(uint16_t)*blen);
00183 CMath::radix_sort(bvec, blen);
00184 }
00185 else
00186 bvec=NULL;
00187 }
00188 else
00189 {
00190 if ( (l->get_num_preproc() != l->get_num_preprocessed()) ||
00191 (r->get_num_preproc() != r->get_num_preprocessed()))
00192 {
00193 SG_ERROR("not all preprocessors have been applied to training (%d/%d)"
00194 " or test (%d/%d) data\n", l->get_num_preprocessed(), l->get_num_preproc(),
00195 r->get_num_preprocessed(), r->get_num_preproc());
00196 }
00197 }
00198
00199 float64_t result=0;
00200
00201 int32_t left_idx=0;
00202 int32_t right_idx=0;
00203
00204 if (use_sign)
00205 {
00206 while (left_idx < alen && right_idx < blen)
00207 {
00208 if (avec[left_idx]==bvec[right_idx])
00209 {
00210 uint16_t sym=avec[left_idx];
00211
00212 while (left_idx< alen && avec[left_idx]==sym)
00213 left_idx++;
00214
00215 while (right_idx< blen && bvec[right_idx]==sym)
00216 right_idx++;
00217
00218 result++;
00219 }
00220 else if (avec[left_idx]<bvec[right_idx])
00221 left_idx++;
00222 else
00223 right_idx++;
00224 }
00225 }
00226 else
00227 {
00228 while (left_idx < alen && right_idx < blen)
00229 {
00230 if (avec[left_idx]==bvec[right_idx])
00231 {
00232 int32_t old_left_idx=left_idx;
00233 int32_t old_right_idx=right_idx;
00234
00235 uint16_t sym=avec[left_idx];
00236
00237 while (left_idx< alen && avec[left_idx]==sym)
00238 left_idx++;
00239
00240 while (right_idx< blen && bvec[right_idx]==sym)
00241 right_idx++;
00242
00243 result+=((float64_t) (left_idx-old_left_idx))*
00244 ((float64_t) (right_idx-old_right_idx));
00245 }
00246 else if (avec[left_idx]<bvec[right_idx])
00247 left_idx++;
00248 else
00249 right_idx++;
00250 }
00251 }
00252
00253 if (do_sort)
00254 {
00255 delete[] avec;
00256 delete[] bvec;
00257 }
00258
00259 return result;
00260 }
00261
00262 void CCommWordStringKernel::add_to_normal(int32_t vec_idx, float64_t weight)
00263 {
00264 int32_t len=-1;
00265 uint16_t* vec=((CStringFeatures<uint16_t>*) lhs)->
00266 get_feature_vector(vec_idx, len);
00267
00268 if (len>0)
00269 {
00270 int32_t j, last_j=0;
00271 if (use_sign)
00272 {
00273 for (j=1; j<len; j++)
00274 {
00275 if (vec[j]==vec[j-1])
00276 continue;
00277
00278 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00279 normalize_lhs(weight, vec_idx);
00280 }
00281
00282 dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00283 normalize_lhs(weight, vec_idx);
00284 }
00285 else
00286 {
00287 for (j=1; j<len; j++)
00288 {
00289 if (vec[j]==vec[j-1])
00290 continue;
00291
00292 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00293 normalize_lhs(weight*(j-last_j), vec_idx);
00294 last_j = j;
00295 }
00296
00297 dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00298 normalize_lhs(weight*(len-last_j), vec_idx);
00299 }
00300 set_is_initialized(true);
00301 }
00302 }
00303
00304 void CCommWordStringKernel::clear_normal()
00305 {
00306 memset(dictionary_weights, 0, dictionary_size*sizeof(float64_t));
00307 set_is_initialized(false);
00308 }
00309
00310 bool CCommWordStringKernel::init_optimization(
00311 int32_t count, int32_t* IDX, float64_t* weights)
00312 {
00313 delete_optimization();
00314
00315 if (count<=0)
00316 {
00317 set_is_initialized(true);
00318 SG_DEBUG("empty set of SVs\n");
00319 return true;
00320 }
00321
00322 SG_DEBUG("initializing CCommWordStringKernel optimization\n");
00323
00324 for (int32_t i=0; i<count; i++)
00325 {
00326 if ( (i % (count/10+1)) == 0)
00327 SG_PROGRESS(i, 0, count);
00328
00329 add_to_normal(IDX[i], weights[i]);
00330 }
00331
00332 set_is_initialized(true);
00333 return true;
00334 }
00335
00336 bool CCommWordStringKernel::delete_optimization()
00337 {
00338 SG_DEBUG( "deleting CCommWordStringKernel optimization\n");
00339
00340 clear_normal();
00341 return true;
00342 }
00343
00344 float64_t CCommWordStringKernel::compute_optimized(int32_t i)
00345 {
00346 if (!get_is_initialized())
00347 {
00348 SG_ERROR( "CCommWordStringKernel optimization not initialized\n");
00349 return 0 ;
00350 }
00351
00352 float64_t result = 0;
00353 int32_t len = -1;
00354 uint16_t* vec=((CStringFeatures<uint16_t>*) rhs)->
00355 get_feature_vector(i, len);
00356
00357 int32_t j, last_j=0;
00358 if (vec && len>0)
00359 {
00360 if (use_sign)
00361 {
00362 for (j=1; j<len; j++)
00363 {
00364 if (vec[j]==vec[j-1])
00365 continue;
00366
00367 result += dictionary_weights[(int32_t) vec[j-1]];
00368 }
00369
00370 result += dictionary_weights[(int32_t) vec[len-1]];
00371 }
00372 else
00373 {
00374 for (j=1; j<len; j++)
00375 {
00376 if (vec[j]==vec[j-1])
00377 continue;
00378
00379 result += dictionary_weights[(int32_t) vec[j-1]]*(j-last_j);
00380 last_j = j;
00381 }
00382
00383 result += dictionary_weights[(int32_t) vec[len-1]]*(len-last_j);
00384 }
00385
00386 result=normalizer->normalize_rhs(result, i);
00387 }
00388 return result;
00389 }
00390
00391 float64_t* CCommWordStringKernel::compute_scoring(
00392 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, float64_t* target,
00393 int32_t num_suppvec, int32_t* IDX, float64_t* alphas, bool do_init)
00394 {
00395 ASSERT(lhs);
00396 CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00397 num_feat=1;
00398 CAlphabet* alpha=str->get_alphabet();
00399 ASSERT(alpha);
00400 int32_t num_bits=alpha->get_num_bits();
00401 int32_t order=str->get_order();
00402 ASSERT(max_degree<=order);
00403
00404 int32_t num_words=(int32_t) str->get_original_num_symbols();
00405 int32_t offset=0;
00406
00407 num_sym=0;
00408
00409 for (int32_t i=0; i<order; i++)
00410 num_sym+=CMath::pow((int32_t) num_words,i+1);
00411
00412 SG_DEBUG("num_words:%d, order:%d, len:%d sz:%d (len*sz:%d)\n", num_words, order,
00413 num_feat, num_sym, num_feat*num_sym);
00414
00415 if (!target)
00416 target=new float64_t[num_feat*num_sym];
00417 memset(target, 0, num_feat*num_sym*sizeof(float64_t));
00418
00419 if (do_init)
00420 init_optimization(num_suppvec, IDX, alphas);
00421
00422 uint32_t kmer_mask=0;
00423 uint32_t words=CMath::pow((int32_t) num_words,(int32_t) order);
00424
00425 for (int32_t o=0; o<max_degree; o++)
00426 {
00427 float64_t* contrib=&target[offset];
00428 offset+=CMath::pow((int32_t) num_words,(int32_t) o+1);
00429
00430 kmer_mask=(kmer_mask<<(num_bits)) | str->get_masked_symbols(0xffff, 1);
00431
00432 for (int32_t p=-o; p<order; p++)
00433 {
00434 int32_t o_sym=0, m_sym=0, il=0,ir=0, jl=0;
00435 uint32_t imer_mask=kmer_mask;
00436 uint32_t jmer_mask=kmer_mask;
00437
00438 if (p<0)
00439 {
00440 il=-p;
00441 m_sym=order-o-p-1;
00442 o_sym=-p;
00443 }
00444 else if (p<order-o)
00445 {
00446 ir=p;
00447 m_sym=order-o-1;
00448 }
00449 else
00450 {
00451 ir=p;
00452 m_sym=p;
00453 o_sym=p-order+o+1;
00454 jl=order-ir;
00455 imer_mask=(kmer_mask>>(num_bits*o_sym));
00456 jmer_mask=(kmer_mask>>(num_bits*jl));
00457 }
00458
00459 float64_t marginalizer=
00460 1.0/CMath::pow((int32_t) num_words,(int32_t) m_sym);
00461
00462 for (uint32_t i=0; i<words; i++)
00463 {
00464 uint16_t x= ((i << (num_bits*il)) >> (num_bits*ir)) & imer_mask;
00465
00466 if (p>=0 && p<order-o)
00467 {
00468
00469 #ifdef DEBUG_COMMSCORING
00470 SG_PRINT("o=%d/%d p=%d/%d i=0x%x x=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d marg=%g o_sym:%d m_sym:%d weight(",
00471 o,order, p,order, i, x, imer_mask, jmer_mask, kmer_mask, il, ir, marginalizer, o_sym, m_sym);
00472
00473 SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n",
00474 alpha->remap_to_char((x>>(3*num_bits))&0x03), alpha->remap_to_char((x>>(2*num_bits))&0x03),
00475 alpha->remap_to_char((x>>num_bits)&0x03), alpha->remap_to_char(x&0x03),
00476 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00477 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00478 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00479 #endif
00480 contrib[x]+=dictionary_weights[i]*marginalizer;
00481 }
00482 else
00483 {
00484 for (uint32_t j=0; j< (uint32_t) CMath::pow((int32_t) num_words, (int32_t) o_sym); j++)
00485 {
00486 uint32_t c=x | ((j & jmer_mask) << (num_bits*jl));
00487 #ifdef DEBUG_COMMSCORING
00488
00489 SG_PRINT("o=%d/%d p=%d/%d i=0x%x j=0x%x x=0x%x c=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d jl=%d marg=%g o_sym:%d m_sym:%d weight(",
00490 o,order, p,order, i, j, x, c, imer_mask, jmer_mask, kmer_mask, il, ir, jl, marginalizer, o_sym, m_sym);
00491 SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n",
00492 alpha->remap_to_char((c>>(3*num_bits))&0x03), alpha->remap_to_char((c>>(2*num_bits))&0x03),
00493 alpha->remap_to_char((c>>num_bits)&0x03), alpha->remap_to_char(c&0x03),
00494 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00495 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00496 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00497 #endif
00498 contrib[c]+=dictionary_weights[i]*marginalizer;
00499 }
00500 }
00501 }
00502 }
00503 }
00504
00505 for (int32_t i=1; i<num_feat; i++)
00506 memcpy(&target[num_sym*i], target, num_sym*sizeof(float64_t));
00507
00508 SG_UNREF(alpha);
00509
00510 return target;
00511 }
00512
00513
00514 char* CCommWordStringKernel::compute_consensus(
00515 int32_t &result_len, int32_t num_suppvec, int32_t* IDX, float64_t* alphas)
00516 {
00517 ASSERT(lhs);
00518 ASSERT(IDX);
00519 ASSERT(alphas);
00520
00521 CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00522 int32_t num_words=(int32_t) str->get_num_symbols();
00523 int32_t num_feat=str->get_max_vector_length();
00524 int64_t total_len=((int64_t) num_feat) * num_words;
00525 CAlphabet* alpha=((CStringFeatures<uint16_t>*) lhs)->get_alphabet();
00526 ASSERT(alpha);
00527 int32_t num_bits=alpha->get_num_bits();
00528 int32_t order=str->get_order();
00529 int32_t max_idx=-1;
00530 float64_t max_score=0;
00531 result_len=num_feat+order-1;
00532
00533
00534 init_optimization(num_suppvec, IDX, alphas);
00535
00536 char* result=new char[result_len];
00537 int32_t* bt=new int32_t[total_len];
00538 float64_t* score=new float64_t[total_len];
00539
00540 for (int64_t i=0; i<total_len; i++)
00541 {
00542 bt[i]=-1;
00543 score[i]=0;
00544 }
00545
00546 for (int32_t t=0; t<num_words; t++)
00547 score[t]=dictionary_weights[t];
00548
00549
00550 for (int32_t i=1; i<num_feat; i++)
00551 {
00552 for (int32_t t1=0; t1<num_words; t1++)
00553 {
00554 max_idx=-1;
00555 max_score=0;
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565
00566 uint16_t suffix=(uint16_t) t1 >> num_bits;
00567
00568 for (int32_t sym=0; sym<str->get_original_num_symbols(); sym++)
00569 {
00570 uint16_t t=suffix | sym << (num_bits*(order-1));
00571
00572
00573
00574
00575 float64_t sc=score[num_words*(i-1) + t]+dictionary_weights[t1];
00576 if (sc > max_score || max_idx==-1)
00577 {
00578 max_idx=t;
00579 max_score=sc;
00580 }
00581 }
00582 ASSERT(max_idx!=-1);
00583
00584 score[num_words*i + t1]=max_score;
00585 bt[num_words*i + t1]=max_idx;
00586 }
00587 }
00588
00589
00590 max_idx=0;
00591 max_score=score[num_words*(num_feat-1) + 0];
00592 for (int32_t t=1; t<num_words; t++)
00593 {
00594 float64_t sc=score[num_words*(num_feat-1) + t];
00595 if (sc>max_score)
00596 {
00597 max_idx=t;
00598 max_score=sc;
00599 }
00600 }
00601
00602 SG_PRINT("max_idx:%i, max_score:%f\n", max_idx, max_score);
00603
00604 for (int32_t i=result_len-1; i>=num_feat; i--)
00605 result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(result_len-1-i)), 1) );
00606
00607 for (int32_t i=num_feat-1; i>=0; i--)
00608 {
00609 result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(order-1)), 1) );
00610 max_idx=bt[num_words*i + max_idx];
00611 }
00612
00613 delete[] bt;
00614 delete[] score;
00615 SG_UNREF(alpha);
00616 return result;
00617 }