00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/common.h"
00012 #include "kernel/WeightedCommWordStringKernel.h"
00013 #include "features/StringFeatures.h"
00014 #include "lib/io.h"
00015
00016 CWeightedCommWordStringKernel::CWeightedCommWordStringKernel(
00017 INT size, bool us, ENormalizationType n)
00018 : CCommWordStringKernel(size, us, n), degree(0), weights(NULL)
00019 {
00020 init_dictionary(1<<(sizeof(WORD)*9));
00021 ASSERT(us==false);
00022 }
00023
00024 CWeightedCommWordStringKernel::CWeightedCommWordStringKernel(
00025 CStringFeatures<WORD>* l, CStringFeatures<WORD>* r,
00026 bool us, ENormalizationType n, INT size)
00027 : CCommWordStringKernel(size, us, n), degree(0), weights(NULL)
00028 {
00029 init_dictionary(1<<(sizeof(WORD)*9));
00030 ASSERT(us==false);
00031
00032 init(l,r);
00033 }
00034
00035 CWeightedCommWordStringKernel::~CWeightedCommWordStringKernel()
00036 {
00037 delete[] weights;
00038 }
00039
00040 bool CWeightedCommWordStringKernel::init(CFeatures* l, CFeatures* r)
00041 {
00042 ASSERT(((CStringFeatures<WORD>*) l)->get_order() ==
00043 ((CStringFeatures<WORD>*) r)->get_order());
00044 degree=((CStringFeatures<WORD>*) l)->get_order();
00045 set_wd_weights();
00046
00047 return CCommWordStringKernel::init(l,r);
00048 }
00049
00050 void CWeightedCommWordStringKernel::cleanup()
00051 {
00052 delete[] weights;
00053 weights=NULL;
00054
00055 CCommWordStringKernel::cleanup();
00056 }
00057 bool CWeightedCommWordStringKernel::set_wd_weights()
00058 {
00059 SG_DEBUG("WSPEC degree: %d\n", degree);
00060 delete[] weights;
00061 weights=new DREAL[degree];
00062
00063 INT i;
00064 DREAL sum=0;
00065 for (i=0; i<degree; i++)
00066 {
00067 weights[i]=degree-i;
00068 sum+=weights[i];
00069 }
00070 for (i=0; i<degree; i++)
00071 weights[i]/=sum;
00072
00073 return weights!=NULL;
00074 }
00075
00076 bool CWeightedCommWordStringKernel::set_weights(DREAL* w, INT d)
00077 {
00078 ASSERT(d==degree);
00079
00080 delete[] weights;
00081 weights=new DREAL[degree];
00082 for (INT i=0; i<degree; i++)
00083 weights[i]=w[i];
00084 return true;
00085 }
00086
00087 DREAL CWeightedCommWordStringKernel::compute_helper(INT idx_a, INT idx_b, bool do_sort)
00088 {
00089 INT alen, blen;
00090
00091 CStringFeatures<WORD>* l = (CStringFeatures<WORD>*) lhs;
00092 CStringFeatures<WORD>* r = (CStringFeatures<WORD>*) rhs;
00093
00094 WORD* av=l->get_feature_vector(idx_a, alen);
00095 WORD* bv=r->get_feature_vector(idx_b, blen);
00096
00097 WORD* avec=av;
00098 WORD* bvec=bv;
00099
00100 if (do_sort)
00101 {
00102 if (alen>0)
00103 {
00104 avec=new WORD[alen];
00105 memcpy(avec, av, sizeof(WORD)*alen);
00106 CMath::radix_sort(avec, alen);
00107 }
00108 else
00109 avec=NULL;
00110
00111 if (blen>0)
00112 {
00113 bvec=new WORD[blen];
00114 memcpy(bvec, bv, sizeof(WORD)*blen);
00115 CMath::radix_sort(bvec, blen);
00116 }
00117 else
00118 bvec=NULL;
00119 }
00120 else
00121 {
00122 if ( (l->get_num_preproc() != l->get_num_preprocessed()) ||
00123 (r->get_num_preproc() != r->get_num_preprocessed()))
00124 {
00125 SG_ERROR("not all preprocessors have been applied to training (%d/%d)"
00126 " or test (%d/%d) data\n", l->get_num_preprocessed(), l->get_num_preproc(),
00127 r->get_num_preprocessed(), r->get_num_preproc());
00128 }
00129 }
00130
00131 DREAL result=0;
00132 BYTE mask=0;
00133
00134 for (INT d=0; d<degree; d++)
00135 {
00136 mask = mask | (1 << (degree-d-1));
00137 WORD masked=((CStringFeatures<WORD>*) lhs)->get_masked_symbols(0xffff, mask);
00138
00139 INT left_idx=0;
00140 INT right_idx=0;
00141
00142 while (left_idx < alen && right_idx < blen)
00143 {
00144 WORD lsym=avec[left_idx] & masked;
00145 WORD rsym=bvec[right_idx] & masked;
00146
00147 if (lsym == rsym)
00148 {
00149 INT old_left_idx=left_idx;
00150 INT old_right_idx=right_idx;
00151
00152 while (left_idx<alen && (avec[left_idx] & masked) ==lsym)
00153 left_idx++;
00154
00155 while (right_idx<blen && (bvec[right_idx] & masked) ==lsym)
00156 right_idx++;
00157
00158 result+=weights[d]*(left_idx-old_left_idx)*(right_idx-old_right_idx);
00159 }
00160 else if (lsym<rsym)
00161 left_idx++;
00162 else
00163 right_idx++;
00164 }
00165 }
00166
00167 if (do_sort)
00168 {
00169 delete[] avec;
00170 delete[] bvec;
00171 }
00172
00173 if (initialized)
00174 {
00175 switch (normalization)
00176 {
00177 case NO_NORMALIZATION:
00178 return result;
00179 case SQRT_NORMALIZATION:
00180 return result/sqrt(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);
00181 case FULL_NORMALIZATION:
00182 return result/(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);
00183 case SQRTLEN_NORMALIZATION:
00184 return result/sqrt(sqrt(alen*blen));
00185 case LEN_NORMALIZATION:
00186 return result/sqrt(alen*blen);
00187 case SQLEN_NORMALIZATION:
00188 return result/(alen*blen);
00189 default:
00190 SG_ERROR( "Unknown Normalization in use!\n");
00191 return -CMath::INFTY;
00192 }
00193 }
00194 else
00195 return result;
00196 }
00197
00198 void CWeightedCommWordStringKernel::add_to_normal(INT vec_idx, DREAL weight)
00199 {
00200 INT len=-1;
00201 CStringFeatures<WORD>* s=(CStringFeatures<WORD>*) lhs;
00202 WORD* vec=s->get_feature_vector(vec_idx, len);
00203
00204 if (len>0)
00205 {
00206 for (INT j=0; j<len; j++)
00207 {
00208 BYTE mask=0;
00209 INT offs=0;
00210 for (INT d=0; d<degree; d++)
00211 {
00212 mask = mask | (1 << (degree-d-1));
00213 INT idx=s->get_masked_symbols(vec[j], mask);
00214 idx=s->shift_symbol(idx, degree-d-1);
00215 dictionary_weights[offs + idx] += normalize_weight(sqrtdiag_lhs, weight*weights[d], vec_idx, len, normalization);
00216 offs+=s->shift_offset(1,d+1);
00217 }
00218 }
00219
00220 set_is_initialized(true);
00221 }
00222 }
00223
00224 void CWeightedCommWordStringKernel::merge_normal()
00225 {
00226 ASSERT(get_is_initialized());
00227 ASSERT(use_sign==false);
00228
00229 CStringFeatures<WORD>* s=(CStringFeatures<WORD>*) rhs;
00230 UINT num_symbols=(UINT) s->get_num_symbols();
00231 INT dic_size=1<<(sizeof(WORD)*8);
00232 DREAL* dic=new DREAL[dic_size];
00233 memset(dic, 0, sizeof(DREAL)*dic_size);
00234
00235 for (UINT sym=0; sym<num_symbols; sym++)
00236 {
00237 DREAL result=0;
00238 BYTE mask=0;
00239 INT offs=0;
00240 for (INT d=0; d<degree; d++)
00241 {
00242 mask = mask | (1 << (degree-d-1));
00243 INT idx=s->get_masked_symbols(sym, mask);
00244 idx=s->shift_symbol(idx, degree-d-1);
00245 result += dictionary_weights[offs + idx];
00246 offs+=s->shift_offset(1,d+1);
00247 }
00248 dic[sym]=result;
00249 }
00250
00251 init_dictionary(1<<(sizeof(WORD)*8));
00252 memcpy(dictionary_weights, dic, sizeof(DREAL)*dic_size);
00253 delete[] dic;
00254 }
00255
00256 DREAL CWeightedCommWordStringKernel::compute_optimized(INT i)
00257 {
00258 if (!get_is_initialized())
00259 SG_ERROR( "CCommWordStringKernel optimization not initialized\n");
00260
00261 ASSERT(use_sign==false);
00262
00263 DREAL result=0;
00264 INT len=-1;
00265 CStringFeatures<WORD>* s=(CStringFeatures<WORD>*) rhs;
00266 WORD* vec=s->get_feature_vector(i, len);
00267
00268 if (vec && len>0)
00269 {
00270 for (INT j=0; j<len; j++)
00271 {
00272 BYTE mask=0;
00273 INT offs=0;
00274 for (INT d=0; d<degree; d++)
00275 {
00276 mask = mask | (1 << (degree-d-1));
00277 INT idx=s->get_masked_symbols(vec[j], mask);
00278 idx=s->shift_symbol(idx, degree-d-1);
00279 result += dictionary_weights[offs + idx];
00280 offs+=s->shift_offset(1,d+1);
00281 }
00282 }
00283
00284 result=normalize_weight(sqrtdiag_rhs, result, i, len, normalization);
00285 }
00286 return result;
00287 }
00288
00289 DREAL* CWeightedCommWordStringKernel::compute_scoring(INT max_degree, INT& num_feat,
00290 INT& num_sym, DREAL* target, INT num_suppvec, INT* IDX, DREAL* alphas, bool do_init)
00291 {
00292 if (do_init)
00293 CCommWordStringKernel::init_optimization(num_suppvec, IDX, alphas);
00294
00295 INT dic_size=1<<(sizeof(WORD)*9);
00296 DREAL* dic=new DREAL[dic_size];
00297 memcpy(dic, dictionary_weights, sizeof(DREAL)*dic_size);
00298
00299 merge_normal();
00300 DREAL* result=CCommWordStringKernel::compute_scoring(max_degree, num_feat,
00301 num_sym, target, num_suppvec, IDX, alphas, false);
00302
00303 init_dictionary(1<<(sizeof(WORD)*9));
00304 memcpy(dictionary_weights,dic, sizeof(DREAL)*dic_size);
00305 delete[] dic;
00306
00307 return result;
00308 }