CommWordStringKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "kernel/CommWordStringKernel.h"
00013 #include "kernel/SqrtDiagKernelNormalizer.h"
00014 #include "features/StringFeatures.h"
00015 #include "lib/io.h"
00016 
00017 CCommWordStringKernel::CCommWordStringKernel(int32_t size, bool s)
00018 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00019     use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00020 {
00021     properties |= KP_LINADD;
00022     init_dictionary(1<<(sizeof(uint16_t)*8));
00023     set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00024 }
00025 
00026 CCommWordStringKernel::CCommWordStringKernel(
00027     CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, bool s,
00028     int32_t size)
00029 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00030     use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00031 {
00032     properties |= KP_LINADD;
00033 
00034     init_dictionary(1<<(sizeof(uint16_t)*8));
00035     set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00036     init(l,r);
00037 }
00038 
00039 
00040 bool CCommWordStringKernel::init_dictionary(int32_t size)
00041 {
00042     dictionary_size= size;
00043     delete[] dictionary_weights;
00044     dictionary_weights=new float64_t[size];
00045     SG_DEBUG( "using dictionary of %d words\n", size);
00046     clear_normal();
00047 
00048     return dictionary_weights!=NULL;
00049 }
00050 
00051 CCommWordStringKernel::~CCommWordStringKernel() 
00052 {
00053     cleanup();
00054 
00055     delete[] dictionary_weights;
00056     delete[] dict_diagonal_optimization ;
00057 }
00058   
00059 void CCommWordStringKernel::remove_lhs() 
00060 { 
00061     delete_optimization();
00062 
00063 #ifdef SVMLIGHT
00064     if (lhs)
00065         cache_reset();
00066 #endif
00067 
00068     lhs = NULL ; 
00069     rhs = NULL ; 
00070 }
00071 
00072 void CCommWordStringKernel::remove_rhs()
00073 {
00074 #ifdef SVMLIGHT
00075     if (rhs)
00076         cache_reset();
00077 #endif
00078 
00079     rhs = lhs;
00080 }
00081 
00082 bool CCommWordStringKernel::init(CFeatures* l, CFeatures* r)
00083 {
00084     CStringKernel<uint16_t>::init(l,r);
00085 
00086     if (use_dict_diagonal_optimization)
00087     {
00088         delete[] dict_diagonal_optimization ;
00089         dict_diagonal_optimization=new int32_t[int32_t(((CStringFeatures<uint16_t>*)l)->get_num_symbols())];
00090         ASSERT(((CStringFeatures<uint16_t>*)l)->get_num_symbols() == ((CStringFeatures<uint16_t>*)r)->get_num_symbols()) ;
00091     }
00092 
00093     return init_normalizer();
00094 }
00095 
00096 void CCommWordStringKernel::cleanup()
00097 {
00098     delete_optimization();
00099     clear_normal();
00100     CKernel::cleanup();
00101 }
00102 
00103 bool CCommWordStringKernel::load_init(FILE* src)
00104 {
00105     return false;
00106 }
00107 
00108 bool CCommWordStringKernel::save_init(FILE* dest)
00109 {
00110     return false;
00111 }
00112 
00113 float64_t CCommWordStringKernel::compute_diag(int32_t idx_a)
00114 {
00115     int32_t alen;
00116     CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00117     CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00118 
00119     uint16_t* av=l->get_feature_vector(idx_a, alen);
00120 
00121     float64_t result=0.0 ;
00122     ASSERT(l==r);
00123     ASSERT(sizeof(uint16_t)<=sizeof(float64_t));
00124     ASSERT((1<<(sizeof(uint16_t)*8)) > alen);
00125 
00126     int32_t num_symbols=(int32_t) l->get_num_symbols();
00127     ASSERT(num_symbols<=dictionary_size);
00128 
00129     int32_t* dic = dict_diagonal_optimization;
00130     memset(dic, 0, num_symbols*sizeof(int32_t));
00131 
00132     for (int32_t i=0; i<alen; i++)
00133         dic[av[i]]++;
00134 
00135     if (use_sign)
00136     {
00137         for (int32_t i=0; i<(int32_t) l->get_num_symbols(); i++)
00138         {
00139             if (dic[i]!=0)
00140                 result++;
00141         }
00142     }
00143     else
00144     {
00145         for (int32_t i=0; i<num_symbols; i++)
00146         {
00147             if (dic[i]!=0)
00148                 result+=dic[i]*dic[i];
00149         }
00150     }
00151 
00152     return result;
00153 }
00154 
00155 float64_t CCommWordStringKernel::compute_helper(
00156     int32_t idx_a, int32_t idx_b, bool do_sort)
00157 {
00158     int32_t alen, blen;
00159     CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00160     CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00161 
00162     uint16_t* av=l->get_feature_vector(idx_a, alen);
00163     uint16_t* bv=r->get_feature_vector(idx_b, blen);
00164 
00165     uint16_t* avec=av;
00166     uint16_t* bvec=bv;
00167 
00168     if (do_sort)
00169     {
00170         if (alen>0)
00171         {
00172             avec=new uint16_t[alen];
00173             memcpy(avec, av, sizeof(uint16_t)*alen);
00174             CMath::radix_sort(avec, alen);
00175         }
00176         else
00177             avec=NULL;
00178 
00179         if (blen>0)
00180         {
00181             bvec=new uint16_t[blen];
00182             memcpy(bvec, bv, sizeof(uint16_t)*blen);
00183             CMath::radix_sort(bvec, blen);
00184         }
00185         else
00186             bvec=NULL;
00187     }
00188     else
00189     {
00190         if ( (l->get_num_preproc() != l->get_num_preprocessed()) ||
00191                 (r->get_num_preproc() != r->get_num_preprocessed()))
00192         {
00193             SG_ERROR("not all preprocessors have been applied to training (%d/%d)"
00194                     " or test (%d/%d) data\n", l->get_num_preprocessed(), l->get_num_preproc(),
00195                     r->get_num_preprocessed(), r->get_num_preproc());
00196         }
00197     }
00198 
00199     float64_t result=0;
00200 
00201     int32_t left_idx=0;
00202     int32_t right_idx=0;
00203 
00204     if (use_sign)
00205     {
00206         while (left_idx < alen && right_idx < blen)
00207         {
00208             if (avec[left_idx]==bvec[right_idx])
00209             {
00210                 uint16_t sym=avec[left_idx];
00211 
00212                 while (left_idx< alen && avec[left_idx]==sym)
00213                     left_idx++;
00214 
00215                 while (right_idx< blen && bvec[right_idx]==sym)
00216                     right_idx++;
00217 
00218                 result++;
00219             }
00220             else if (avec[left_idx]<bvec[right_idx])
00221                 left_idx++;
00222             else
00223                 right_idx++;
00224         }
00225     }
00226     else
00227     {
00228         while (left_idx < alen && right_idx < blen)
00229         {
00230             if (avec[left_idx]==bvec[right_idx])
00231             {
00232                 int32_t old_left_idx=left_idx;
00233                 int32_t old_right_idx=right_idx;
00234 
00235                 uint16_t sym=avec[left_idx];
00236 
00237                 while (left_idx< alen && avec[left_idx]==sym)
00238                     left_idx++;
00239 
00240                 while (right_idx< blen && bvec[right_idx]==sym)
00241                     right_idx++;
00242 
00243                 result+=((float64_t) (left_idx-old_left_idx))*
00244                     ((float64_t) (right_idx-old_right_idx));
00245             }
00246             else if (avec[left_idx]<bvec[right_idx])
00247                 left_idx++;
00248             else
00249                 right_idx++;
00250         }
00251     }
00252 
00253     if (do_sort)
00254     {
00255         delete[] avec;
00256         delete[] bvec;
00257     }
00258 
00259     return result;
00260 }
00261 
00262 void CCommWordStringKernel::add_to_normal(int32_t vec_idx, float64_t weight)
00263 {
00264     int32_t len=-1;
00265     uint16_t* vec=((CStringFeatures<uint16_t>*) lhs)->
00266         get_feature_vector(vec_idx, len);
00267 
00268     if (len>0)
00269     {
00270         int32_t j, last_j=0;
00271         if (use_sign)
00272         {
00273             for (j=1; j<len; j++)
00274             {
00275                 if (vec[j]==vec[j-1])
00276                     continue;
00277 
00278                 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00279                     normalize_lhs(weight, vec_idx);
00280             }
00281 
00282             dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00283                 normalize_lhs(weight, vec_idx);
00284         }
00285         else
00286         {
00287             for (j=1; j<len; j++)
00288             {
00289                 if (vec[j]==vec[j-1])
00290                     continue;
00291 
00292                 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00293                     normalize_lhs(weight*(j-last_j), vec_idx);
00294                 last_j = j;
00295             }
00296 
00297             dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00298                 normalize_lhs(weight*(len-last_j), vec_idx);
00299         }
00300         set_is_initialized(true);
00301     }
00302 }
00303 
00304 void CCommWordStringKernel::clear_normal()
00305 {
00306     memset(dictionary_weights, 0, dictionary_size*sizeof(float64_t));
00307     set_is_initialized(false);
00308 }
00309 
00310 bool CCommWordStringKernel::init_optimization(
00311     int32_t count, int32_t* IDX, float64_t* weights)
00312 {
00313     delete_optimization();
00314 
00315     if (count<=0)
00316     {
00317         set_is_initialized(true);
00318         SG_DEBUG("empty set of SVs\n");
00319         return true;
00320     }
00321 
00322     SG_DEBUG("initializing CCommWordStringKernel optimization\n");
00323 
00324     for (int32_t i=0; i<count; i++)
00325     {
00326         if ( (i % (count/10+1)) == 0)
00327             SG_PROGRESS(i, 0, count);
00328 
00329         add_to_normal(IDX[i], weights[i]);
00330     }
00331 
00332     set_is_initialized(true);
00333     return true;
00334 }
00335 
00336 bool CCommWordStringKernel::delete_optimization() 
00337 {
00338     SG_DEBUG( "deleting CCommWordStringKernel optimization\n");
00339 
00340     clear_normal();
00341     return true;
00342 }
00343 
00344 float64_t CCommWordStringKernel::compute_optimized(int32_t i)
00345 { 
00346     if (!get_is_initialized())
00347     {
00348       SG_ERROR( "CCommWordStringKernel optimization not initialized\n");
00349         return 0 ; 
00350     }
00351 
00352     float64_t result = 0;
00353     int32_t len = -1;
00354     uint16_t* vec=((CStringFeatures<uint16_t>*) rhs)->
00355         get_feature_vector(i, len);
00356 
00357     int32_t j, last_j=0;
00358     if (vec && len>0)
00359     {
00360         if (use_sign)
00361         {
00362             for (j=1; j<len; j++)
00363             {
00364                 if (vec[j]==vec[j-1])
00365                     continue;
00366 
00367                 result += dictionary_weights[(int32_t) vec[j-1]];
00368             }
00369 
00370             result += dictionary_weights[(int32_t) vec[len-1]];
00371         }
00372         else
00373         {
00374             for (j=1; j<len; j++)
00375             {
00376                 if (vec[j]==vec[j-1])
00377                     continue;
00378 
00379                 result += dictionary_weights[(int32_t) vec[j-1]]*(j-last_j);
00380                 last_j = j;
00381             }
00382 
00383             result += dictionary_weights[(int32_t) vec[len-1]]*(len-last_j);
00384         }
00385 
00386         result=normalizer->normalize_rhs(result, i);
00387     }
00388     return result;
00389 }
00390 
00391 float64_t* CCommWordStringKernel::compute_scoring(
00392     int32_t max_degree, int32_t& num_feat, int32_t& num_sym, float64_t* target,
00393     int32_t num_suppvec, int32_t* IDX, float64_t* alphas, bool do_init)
00394 {
00395     ASSERT(lhs);
00396     CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00397     num_feat=1;//str->get_max_vector_length();
00398     CAlphabet* alpha=str->get_alphabet();
00399     ASSERT(alpha);
00400     int32_t num_bits=alpha->get_num_bits();
00401     int32_t order=str->get_order();
00402     ASSERT(max_degree<=order);
00403     //int32_t num_words=(int32_t) str->get_num_symbols();
00404     int32_t num_words=(int32_t) str->get_original_num_symbols();
00405     int32_t offset=0;
00406 
00407     num_sym=0;
00408     
00409     for (int32_t i=0; i<order; i++)
00410         num_sym+=CMath::pow((int32_t) num_words,i+1);
00411 
00412     SG_DEBUG("num_words:%d, order:%d, len:%d sz:%d (len*sz:%d)\n", num_words, order,
00413             num_feat, num_sym, num_feat*num_sym);
00414 
00415     if (!target)
00416         target=new float64_t[num_feat*num_sym];
00417     memset(target, 0, num_feat*num_sym*sizeof(float64_t));
00418 
00419     if (do_init)
00420         init_optimization(num_suppvec, IDX, alphas);
00421 
00422     uint32_t kmer_mask=0;
00423     uint32_t words=CMath::pow((int32_t) num_words,(int32_t) order);
00424 
00425     for (int32_t o=0; o<max_degree; o++)
00426     {
00427         float64_t* contrib=&target[offset];
00428         offset+=CMath::pow((int32_t) num_words,(int32_t) o+1);
00429 
00430         kmer_mask=(kmer_mask<<(num_bits)) | str->get_masked_symbols(0xffff, 1);
00431 
00432         for (int32_t p=-o; p<order; p++)
00433         {
00434             int32_t o_sym=0, m_sym=0, il=0,ir=0, jl=0;
00435             uint32_t imer_mask=kmer_mask;
00436             uint32_t jmer_mask=kmer_mask;
00437 
00438             if (p<0)
00439             {
00440                 il=-p;
00441                 m_sym=order-o-p-1;
00442                 o_sym=-p;
00443             }
00444             else if (p<order-o)
00445             {
00446                 ir=p;
00447                 m_sym=order-o-1;
00448             }
00449             else
00450             {
00451                 ir=p;
00452                 m_sym=p;
00453                 o_sym=p-order+o+1;
00454                 jl=order-ir;
00455                 imer_mask=(kmer_mask>>(num_bits*o_sym));
00456                 jmer_mask=(kmer_mask>>(num_bits*jl));
00457             }
00458 
00459             float64_t marginalizer=
00460                 1.0/CMath::pow((int32_t) num_words,(int32_t) m_sym);
00461             
00462             for (uint32_t i=0; i<words; i++)
00463             {
00464                 uint16_t x= ((i << (num_bits*il)) >> (num_bits*ir)) & imer_mask;
00465 
00466                 if (p>=0 && p<order-o)
00467                 {
00468 //#define DEBUG_COMMSCORING
00469 #ifdef DEBUG_COMMSCORING
00470                     SG_PRINT("o=%d/%d p=%d/%d i=0x%x x=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d marg=%g o_sym:%d m_sym:%d weight(",
00471                             o,order, p,order, i, x, imer_mask, jmer_mask, kmer_mask, il, ir, marginalizer, o_sym, m_sym);
00472 
00473                     SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n", 
00474                             alpha->remap_to_char((x>>(3*num_bits))&0x03), alpha->remap_to_char((x>>(2*num_bits))&0x03),
00475                             alpha->remap_to_char((x>>num_bits)&0x03), alpha->remap_to_char(x&0x03),
00476                             alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00477                             alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00478                             dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00479 #endif
00480                     contrib[x]+=dictionary_weights[i]*marginalizer;
00481                 }
00482                 else
00483                 {
00484                     for (uint32_t j=0; j< (uint32_t) CMath::pow((int32_t) num_words, (int32_t) o_sym); j++)
00485                     {
00486                         uint32_t c=x | ((j & jmer_mask) << (num_bits*jl));
00487 #ifdef DEBUG_COMMSCORING
00488 
00489                         SG_PRINT("o=%d/%d p=%d/%d i=0x%x j=0x%x x=0x%x c=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d jl=%d marg=%g o_sym:%d m_sym:%d weight(",
00490                                 o,order, p,order, i, j, x, c, imer_mask, jmer_mask, kmer_mask, il, ir, jl, marginalizer, o_sym, m_sym);
00491                         SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n", 
00492                                 alpha->remap_to_char((c>>(3*num_bits))&0x03), alpha->remap_to_char((c>>(2*num_bits))&0x03),
00493                                 alpha->remap_to_char((c>>num_bits)&0x03), alpha->remap_to_char(c&0x03),
00494                                 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00495                                 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00496                                 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00497 #endif
00498                         contrib[c]+=dictionary_weights[i]*marginalizer;
00499                     }
00500                 }
00501             }
00502         }
00503     }
00504 
00505     for (int32_t i=1; i<num_feat; i++)
00506         memcpy(&target[num_sym*i], target, num_sym*sizeof(float64_t));
00507 
00508     SG_UNREF(alpha);
00509 
00510     return target;
00511 }
00512 
00513 
00514 char* CCommWordStringKernel::compute_consensus(
00515     int32_t &result_len, int32_t num_suppvec, int32_t* IDX, float64_t* alphas)
00516 {
00517     ASSERT(lhs);
00518     ASSERT(IDX);
00519     ASSERT(alphas);
00520 
00521     CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00522     int32_t num_words=(int32_t) str->get_num_symbols();
00523     int32_t num_feat=str->get_max_vector_length();
00524     int64_t total_len=((int64_t) num_feat) * num_words;
00525     CAlphabet* alpha=((CStringFeatures<uint16_t>*) lhs)->get_alphabet();
00526     ASSERT(alpha);
00527     int32_t num_bits=alpha->get_num_bits();
00528     int32_t order=str->get_order();
00529     int32_t max_idx=-1;
00530     float64_t max_score=0; 
00531     result_len=num_feat+order-1;
00532 
00533     //init
00534     init_optimization(num_suppvec, IDX, alphas);
00535 
00536     char* result=new char[result_len];
00537     int32_t* bt=new int32_t[total_len];
00538     float64_t* score=new float64_t[total_len];
00539 
00540     for (int64_t i=0; i<total_len; i++)
00541     {
00542         bt[i]=-1;
00543         score[i]=0;
00544     }
00545 
00546     for (int32_t t=0; t<num_words; t++)
00547         score[t]=dictionary_weights[t];
00548 
00549     //dynamic program
00550     for (int32_t i=1; i<num_feat; i++)
00551     {
00552         for (int32_t t1=0; t1<num_words; t1++)
00553         {
00554             max_idx=-1;
00555             max_score=0; 
00556 
00557             /* ignore weights the svm does not care about 
00558              * (has not seen in training). note that this assumes that zero 
00559              * weights are very unlikely to appear elsewise */
00560 
00561             //if (dictionary_weights[t1]==0.0)
00562                 //continue;
00563 
00564             /* iterate over words t ending on t1 and find the highest scoring
00565              * pair */
00566             uint16_t suffix=(uint16_t) t1 >> num_bits;
00567 
00568             for (int32_t sym=0; sym<str->get_original_num_symbols(); sym++)
00569             {
00570                 uint16_t t=suffix | sym << (num_bits*(order-1));
00571 
00572                 //if (dictionary_weights[t]==0.0)
00573                 //  continue;
00574 
00575                 float64_t sc=score[num_words*(i-1) + t]+dictionary_weights[t1];
00576                 if (sc > max_score || max_idx==-1)
00577                 {
00578                     max_idx=t;
00579                     max_score=sc;
00580                 }
00581             }
00582             ASSERT(max_idx!=-1);
00583 
00584             score[num_words*i + t1]=max_score;
00585             bt[num_words*i + t1]=max_idx;
00586         }
00587     }
00588 
00589     //backtracking
00590     max_idx=0;
00591     max_score=score[num_words*(num_feat-1) + 0];
00592     for (int32_t t=1; t<num_words; t++)
00593     {
00594         float64_t sc=score[num_words*(num_feat-1) + t];
00595         if (sc>max_score)
00596         {
00597             max_idx=t;
00598             max_score=sc;
00599         }
00600     }
00601 
00602     SG_PRINT("max_idx:%i, max_score:%f\n", max_idx, max_score);
00603     
00604     for (int32_t i=result_len-1; i>=num_feat; i--)
00605         result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(result_len-1-i)), 1) );
00606 
00607     for (int32_t i=num_feat-1; i>=0; i--)
00608     {
00609         result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(order-1)), 1) );
00610         max_idx=bt[num_words*i + max_idx];
00611     }
00612 
00613     delete[] bt;
00614     delete[] score;
00615     SG_UNREF(alpha);
00616     return result;
00617 }

SHOGUN Machine Learning Toolbox - Documentation