WeightedCommWordStringKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Soeren Sonnenburg
00008  * Copyright (C) 1999-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "kernel/WeightedCommWordStringKernel.h"
00013 #include "features/StringFeatures.h"
00014 #include "lib/io.h"
00015 
00016 CWeightedCommWordStringKernel::CWeightedCommWordStringKernel(
00017     INT size, bool us, ENormalizationType n)
00018 : CCommWordStringKernel(size, us, n), degree(0), weights(NULL)
00019 {
00020     init_dictionary(1<<(sizeof(WORD)*9));
00021     ASSERT(us==false);
00022 }
00023 
00024 CWeightedCommWordStringKernel::CWeightedCommWordStringKernel(
00025     CStringFeatures<WORD>* l, CStringFeatures<WORD>* r,
00026     bool us, ENormalizationType n, INT size)
00027 : CCommWordStringKernel(size, us, n), degree(0), weights(NULL)
00028 {
00029     init_dictionary(1<<(sizeof(WORD)*9));
00030     ASSERT(us==false);
00031 
00032     init(l,r);
00033 }
00034 
00035 CWeightedCommWordStringKernel::~CWeightedCommWordStringKernel()
00036 {
00037     delete[] weights;
00038 }
00039 
00040 bool CWeightedCommWordStringKernel::init(CFeatures* l, CFeatures* r)
00041 {
00042     ASSERT(((CStringFeatures<WORD>*) l)->get_order() ==
00043             ((CStringFeatures<WORD>*) r)->get_order());
00044     degree=((CStringFeatures<WORD>*) l)->get_order();
00045     set_wd_weights();
00046 
00047     return CCommWordStringKernel::init(l,r);
00048 }
00049 
00050 void CWeightedCommWordStringKernel::cleanup()
00051 {
00052     delete[] weights;
00053     weights=NULL;
00054 
00055     CCommWordStringKernel::cleanup();
00056 }
00057 bool CWeightedCommWordStringKernel::set_wd_weights()
00058 {
00059     SG_DEBUG("WSPEC degree: %d\n", degree);
00060     delete[] weights;
00061     weights=new DREAL[degree];
00062 
00063     INT i;
00064     DREAL sum=0;
00065     for (i=0; i<degree; i++)
00066     {
00067         weights[i]=degree-i;
00068         sum+=weights[i];
00069     }
00070     for (i=0; i<degree; i++)
00071         weights[i]/=sum;
00072 
00073     return weights!=NULL;
00074 }
00075 
00076 bool CWeightedCommWordStringKernel::set_weights(DREAL* w, INT d)
00077 {
00078     ASSERT(d==degree);
00079 
00080     delete[] weights;
00081     weights=new DREAL[degree];
00082     for (INT i=0; i<degree; i++)
00083         weights[i]=w[i];
00084     return true;
00085 }
00086   
00087 DREAL CWeightedCommWordStringKernel::compute_helper(INT idx_a, INT idx_b, bool do_sort)
00088 {
00089     INT alen, blen;
00090 
00091     CStringFeatures<WORD>* l = (CStringFeatures<WORD>*) lhs;
00092     CStringFeatures<WORD>* r = (CStringFeatures<WORD>*) rhs;
00093 
00094     WORD* av=l->get_feature_vector(idx_a, alen);
00095     WORD* bv=r->get_feature_vector(idx_b, blen);
00096 
00097     WORD* avec=av;
00098     WORD* bvec=bv;
00099 
00100     if (do_sort)
00101     {
00102         if (alen>0)
00103         {
00104             avec=new WORD[alen];
00105             memcpy(avec, av, sizeof(WORD)*alen);
00106             CMath::radix_sort(avec, alen);
00107         }
00108         else
00109             avec=NULL;
00110 
00111         if (blen>0)
00112         {
00113             bvec=new WORD[blen];
00114             memcpy(bvec, bv, sizeof(WORD)*blen);
00115             CMath::radix_sort(bvec, blen);
00116         }
00117         else
00118             bvec=NULL;
00119     }
00120     else
00121     {
00122         if ( (l->get_num_preproc() != l->get_num_preprocessed()) ||
00123                 (r->get_num_preproc() != r->get_num_preprocessed()))
00124         {
00125             SG_ERROR("not all preprocessors have been applied to training (%d/%d)"
00126                     " or test (%d/%d) data\n", l->get_num_preprocessed(), l->get_num_preproc(),
00127                     r->get_num_preprocessed(), r->get_num_preproc());
00128         }
00129     }
00130 
00131     DREAL result=0;
00132     BYTE mask=0;
00133 
00134     for (INT d=0; d<degree; d++)
00135     {
00136         mask = mask | (1 << (degree-d-1));
00137         WORD masked=((CStringFeatures<WORD>*) lhs)->get_masked_symbols(0xffff, mask);
00138 
00139         INT left_idx=0;
00140         INT right_idx=0;
00141 
00142         while (left_idx < alen && right_idx < blen)
00143         {
00144             WORD lsym=avec[left_idx] & masked;
00145             WORD rsym=bvec[right_idx] & masked;
00146 
00147             if (lsym == rsym)
00148             {
00149                 INT old_left_idx=left_idx;
00150                 INT old_right_idx=right_idx;
00151 
00152                 while (left_idx<alen && (avec[left_idx] & masked) ==lsym)
00153                     left_idx++;
00154 
00155                 while (right_idx<blen && (bvec[right_idx] & masked) ==lsym)
00156                     right_idx++;
00157 
00158                 result+=weights[d]*(left_idx-old_left_idx)*(right_idx-old_right_idx);
00159             }
00160             else if (lsym<rsym)
00161                 left_idx++;
00162             else
00163                 right_idx++;
00164         }
00165     }
00166 
00167     if (do_sort)
00168     {
00169         delete[] avec;
00170         delete[] bvec;
00171     }
00172 
00173     if (initialized)
00174     {
00175         switch (normalization)
00176         {
00177             case NO_NORMALIZATION:
00178                 return result;
00179             case SQRT_NORMALIZATION:
00180                 return result/sqrt(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);
00181             case FULL_NORMALIZATION:
00182                 return result/(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);
00183             case SQRTLEN_NORMALIZATION:
00184                 return result/sqrt(sqrt(alen*blen));
00185             case LEN_NORMALIZATION:
00186                 return result/sqrt(alen*blen);
00187             case SQLEN_NORMALIZATION:
00188                 return result/(alen*blen);
00189             default:
00190                 SG_ERROR( "Unknown Normalization in use!\n");
00191                 return -CMath::INFTY;
00192         }
00193     }
00194     else
00195         return result;
00196 }
00197 
00198 void CWeightedCommWordStringKernel::add_to_normal(INT vec_idx, DREAL weight)
00199 {
00200     INT len=-1;
00201     CStringFeatures<WORD>* s=(CStringFeatures<WORD>*) lhs;
00202     WORD* vec=s->get_feature_vector(vec_idx, len);
00203 
00204     if (len>0)
00205     {
00206         for (INT j=0; j<len; j++)
00207         {
00208             BYTE mask=0;
00209             INT offs=0;
00210             for (INT d=0; d<degree; d++)
00211             {
00212                 mask = mask | (1 << (degree-d-1));
00213                 INT idx=s->get_masked_symbols(vec[j], mask);
00214                 idx=s->shift_symbol(idx, degree-d-1);
00215                 dictionary_weights[offs + idx] += normalize_weight(sqrtdiag_lhs, weight*weights[d], vec_idx, len, normalization);
00216                 offs+=s->shift_offset(1,d+1);
00217             }
00218         }
00219 
00220         set_is_initialized(true);
00221     }
00222 }
00223 
00224 void CWeightedCommWordStringKernel::merge_normal()
00225 {
00226     ASSERT(get_is_initialized());
00227     ASSERT(use_sign==false);
00228 
00229     CStringFeatures<WORD>* s=(CStringFeatures<WORD>*) rhs;
00230     UINT num_symbols=(UINT) s->get_num_symbols();
00231     INT dic_size=1<<(sizeof(WORD)*8);
00232     DREAL* dic=new DREAL[dic_size];
00233     memset(dic, 0, sizeof(DREAL)*dic_size);
00234 
00235     for (UINT sym=0; sym<num_symbols; sym++)
00236     {
00237         DREAL result=0;
00238         BYTE mask=0;
00239         INT offs=0;
00240         for (INT d=0; d<degree; d++)
00241         {
00242             mask = mask | (1 << (degree-d-1));
00243             INT idx=s->get_masked_symbols(sym, mask);
00244             idx=s->shift_symbol(idx, degree-d-1);
00245             result += dictionary_weights[offs + idx];
00246             offs+=s->shift_offset(1,d+1);
00247         }
00248         dic[sym]=result;
00249     }
00250 
00251     init_dictionary(1<<(sizeof(WORD)*8));
00252     memcpy(dictionary_weights, dic, sizeof(DREAL)*dic_size);
00253     delete[] dic;
00254 }
00255 
00256 DREAL CWeightedCommWordStringKernel::compute_optimized(INT i) 
00257 { 
00258     if (!get_is_initialized())
00259         SG_ERROR( "CCommWordStringKernel optimization not initialized\n");
00260 
00261     ASSERT(use_sign==false);
00262 
00263     DREAL result=0;
00264     INT len=-1;
00265     CStringFeatures<WORD>* s=(CStringFeatures<WORD>*) rhs;
00266     WORD* vec=s->get_feature_vector(i, len);
00267 
00268     if (vec && len>0)
00269     {
00270         for (INT j=0; j<len; j++)
00271         {
00272             BYTE mask=0;
00273             INT offs=0;
00274             for (INT d=0; d<degree; d++)
00275             {
00276                 mask = mask | (1 << (degree-d-1));
00277                 INT idx=s->get_masked_symbols(vec[j], mask);
00278                 idx=s->shift_symbol(idx, degree-d-1);
00279                 result += dictionary_weights[offs + idx];
00280                 offs+=s->shift_offset(1,d+1);
00281             }
00282         }
00283 
00284         result=normalize_weight(sqrtdiag_rhs, result, i, len, normalization);
00285     }
00286     return result;
00287 }
00288 
00289 DREAL* CWeightedCommWordStringKernel::compute_scoring(INT max_degree, INT& num_feat,
00290         INT& num_sym, DREAL* target, INT num_suppvec, INT* IDX, DREAL* alphas, bool do_init)
00291 {
00292     if (do_init)
00293         CCommWordStringKernel::init_optimization(num_suppvec, IDX, alphas);
00294 
00295     INT dic_size=1<<(sizeof(WORD)*9);
00296     DREAL* dic=new DREAL[dic_size];
00297     memcpy(dic, dictionary_weights, sizeof(DREAL)*dic_size);
00298 
00299     merge_normal();
00300     DREAL* result=CCommWordStringKernel::compute_scoring(max_degree, num_feat,
00301             num_sym, target, num_suppvec, IDX, alphas, false);
00302 
00303     init_dictionary(1<<(sizeof(WORD)*9));
00304     memcpy(dictionary_weights,dic,  sizeof(DREAL)*dic_size);
00305     delete[] dic;
00306 
00307     return result;
00308 }

SHOGUN Machine Learning Toolbox - Documentation